diff --git a/UPGRADE.rst b/UPGRADE.rst
index 4de1bb5841..6825b567e9 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -105,6 +105,28 @@ shown below:
return {"localpart": localpart}
+Removal of historical Synapse Admin API
+----------------------------------------
+
+Historically, the Synapse Admin API has been accessible under:
+
+* ``/_matrix/client/api/v1/admin``
+* ``/_matrix/client/unstable/admin``
+* ``/_matrix/client/r0/admin``
+* ``/_synapse/admin/v1``
+
+The endpoints with ``/_matrix/client/*`` prefixes have been removed as of v1.24.0.
+The Admin API is now only accessible under:
+
+* ``/_synapse/admin/v1``
+
+The only exception is the ``/admin/whois`` endpoint, which is
+`also available via the client-server API <https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid>`_.
+
+The deprecation of the old endpoints was announced with Synapse 1.20.0 (released
+on 2020-09-22) and makes it easier for homeserver admins to lock down external
+access to the Admin API endpoints.
+
Upgrading to v1.23.0
====================
diff --git a/changelog.d/8772.misc b/changelog.d/8772.misc
new file mode 100644
index 0000000000..d74d0a3d5d
--- /dev/null
+++ b/changelog.d/8772.misc
@@ -0,0 +1 @@
+Add a commandline script to sign arbitrary json objects.
diff --git a/changelog.d/8773.misc b/changelog.d/8773.misc
new file mode 100644
index 0000000000..62778ba410
--- /dev/null
+++ b/changelog.d/8773.misc
@@ -0,0 +1 @@
+Minor log line improvements for the SSO mapping code used to generate Matrix IDs from SSO IDs.
diff --git a/changelog.d/8779.doc b/changelog.d/8779.doc
new file mode 100644
index 0000000000..3641ae7f91
--- /dev/null
+++ b/changelog.d/8779.doc
@@ -0,0 +1 @@
+Update `turn-howto.md` with troubleshooting notes.
diff --git a/changelog.d/8784.misc b/changelog.d/8784.misc
new file mode 100644
index 0000000000..18a4263398
--- /dev/null
+++ b/changelog.d/8784.misc
@@ -0,0 +1 @@
+Fix a bug introduced in v1.20.0 where the user-agent and IP address reported during user registration for CAS, OpenID Connect, and SAML were of the wrong form.
diff --git a/changelog.d/8785.removal b/changelog.d/8785.removal
new file mode 100644
index 0000000000..ee8ee32598
--- /dev/null
+++ b/changelog.d/8785.removal
@@ -0,0 +1 @@
+Remove old `/_matrix/client/*/admin` endpoints which were deprecated in Synapse 1.20.0.
\ No newline at end of file
diff --git a/changelog.d/8795.doc b/changelog.d/8795.doc
new file mode 100644
index 0000000000..f97a74efb5
--- /dev/null
+++ b/changelog.d/8795.doc
@@ -0,0 +1 @@
+Improve the documentation for the admin API to list all media in a room with respect to encrypted events.
diff --git a/changelog.d/8798.bugfix b/changelog.d/8798.bugfix
new file mode 100644
index 0000000000..9bdb2b51ea
--- /dev/null
+++ b/changelog.d/8798.bugfix
@@ -0,0 +1 @@
+Fix a bug where synctl could spawn duplicate copies of a worker. Contributed by Waylon Cude.
diff --git a/changelog.d/8801.feature b/changelog.d/8801.feature
new file mode 100644
index 0000000000..77f7fe4e5d
--- /dev/null
+++ b/changelog.d/8801.feature
@@ -0,0 +1 @@
+Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
diff --git a/changelog.d/8806.misc b/changelog.d/8806.misc
new file mode 100644
index 0000000000..ee144846a5
--- /dev/null
+++ b/changelog.d/8806.misc
@@ -0,0 +1 @@
+Add type hints to HTTP abstractions.
diff --git a/changelog.d/8812.misc b/changelog.d/8812.misc
new file mode 100644
index 0000000000..ee144846a5
--- /dev/null
+++ b/changelog.d/8812.misc
@@ -0,0 +1 @@
+Add type hints to HTTP abstractions.
diff --git a/changelog.d/8815.misc b/changelog.d/8815.misc
new file mode 100644
index 0000000000..647edeb568
--- /dev/null
+++ b/changelog.d/8815.misc
@@ -0,0 +1 @@
+Optimise the lookup for an invite from another homeserver when trying to reject it.
\ No newline at end of file
diff --git a/changelog.d/8818.doc b/changelog.d/8818.doc
new file mode 100644
index 0000000000..571b0e3f60
--- /dev/null
+++ b/changelog.d/8818.doc
@@ -0,0 +1 @@
+Update the formatting of the `push` section of the homeserver config file to better align with the [code style guidelines](https://github.com/matrix-org/synapse/blob/develop/docs/code_style.md#configuration-file-format).
\ No newline at end of file
diff --git a/changelog.d/8822.doc b/changelog.d/8822.doc
new file mode 100644
index 0000000000..4299245990
--- /dev/null
+++ b/changelog.d/8822.doc
@@ -0,0 +1 @@
+Improve the documentation on how to configure Prometheus for workers.
\ No newline at end of file
diff --git a/changelog.d/8823.bugfix b/changelog.d/8823.bugfix
new file mode 100644
index 0000000000..74af1c20b6
--- /dev/null
+++ b/changelog.d/8823.bugfix
@@ -0,0 +1 @@
+Fix `register_new_matrix_user` failing with "Bad Request" when trailing slash is included in server URL. Contributed by @angdraug.
diff --git a/contrib/prometheus/README.md b/contrib/prometheus/README.md
index e646cb7ea7..b3f23bcc80 100644
--- a/contrib/prometheus/README.md
+++ b/contrib/prometheus/README.md
@@ -20,6 +20,7 @@ Add a new job to the main prometheus.conf file:
```
### for Prometheus v2
+
Add a new job to the main prometheus.yml file:
```yaml
@@ -29,14 +30,17 @@ Add a new job to the main prometheus.yml file:
scheme: "https"
static_configs:
- - targets: ['SERVER.LOCATION:PORT']
+ - targets: ["my.server.here:port"]
```
+An example of a Prometheus configuration with workers can be found in
+[metrics-howto.md](https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md).
+
To use `synapse.rules` add
```yaml
- rule_files:
- - "/PATH/TO/synapse-v2.rules"
+ rule_files:
+ - "/PATH/TO/synapse-v2.rules"
```
Metrics are disabled by default when running synapse; they must be enabled
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 3994e1f1a9..71137c6dfc 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -1,6 +1,7 @@
# List all media in a room
This API gets a list of known media in a room.
+However, it only shows media from unencrypted events or rooms.
The API is:
```
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 84863296e3..1473a3d4e3 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -176,6 +176,13 @@ The api is::
GET /_synapse/admin/v1/whois/<user_id>
+and::
+
+ GET /_matrix/client/r0/admin/whois/<userId>
+
+See also: `Client Server API Whois
+<https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid>`_
+
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
diff --git a/docs/metrics-howto.md b/docs/metrics-howto.md
index fb71af4911..6b84153274 100644
--- a/docs/metrics-howto.md
+++ b/docs/metrics-howto.md
@@ -13,10 +13,12 @@
can be enabled by adding the \"metrics\" resource to the existing
listener as such:
- resources:
- - names:
- - client
- - metrics
+ ```yaml
+ resources:
+ - names:
+ - client
+ - metrics
+ ```
This provides a simple way of adding metrics to your Synapse
installation, and serves under `/_synapse/metrics`. If you do not
@@ -31,11 +33,13 @@
Add a new listener to homeserver.yaml:
- listeners:
- - type: metrics
- port: 9000
- bind_addresses:
- - '0.0.0.0'
+ ```yaml
+ listeners:
+ - type: metrics
+ port: 9000
+ bind_addresses:
+ - '0.0.0.0'
+ ```
For both options, you will need to ensure that `enable_metrics` is
set to `True`.
@@ -47,10 +51,13 @@
It needs to set the `metrics_path` to a non-default value (under
`scrape_configs`):
- - job_name: "synapse"
- metrics_path: "/_synapse/metrics"
- static_configs:
- - targets: ["my.server.here:port"]
+ ```yaml
+ - job_name: "synapse"
+ scrape_interval: 15s
+ metrics_path: "/_synapse/metrics"
+ static_configs:
+ - targets: ["my.server.here:port"]
+ ```
where `my.server.here` is the IP address of Synapse, and `port` is
the listener port configured with the `metrics` resource.
@@ -60,7 +67,8 @@
1. Restart Prometheus.
-1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/) and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/)
+1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/)
+ and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/)
## Monitoring workers
@@ -76,9 +84,9 @@ To allow collecting metrics from a worker, you need to add a
under `worker_listeners`:
```yaml
- - type: metrics
- bind_address: ''
- port: 9101
+ - type: metrics
+ bind_address: ''
+ port: 9101
```
The `bind_address` and `port` parameters should be set so that
@@ -87,6 +95,38 @@ don't clash with an existing worker.
With this example, the worker's metrics would then be available
on `http://127.0.0.1:9101`.
+Example Prometheus target for Synapse with workers:
+
+```yaml
+ - job_name: "synapse"
+ scrape_interval: 15s
+ metrics_path: "/_synapse/metrics"
+ static_configs:
+ - targets: ["my.server.here:port"]
+ labels:
+ instance: "my.server"
+ job: "master"
+ index: 1
+ - targets: ["my.workerserver.here:port"]
+ labels:
+ instance: "my.server"
+ job: "generic_worker"
+ index: 1
+ - targets: ["my.workerserver.here:port"]
+ labels:
+ instance: "my.server"
+ job: "generic_worker"
+ index: 2
+ - targets: ["my.workerserver.here:port"]
+ labels:
+ instance: "my.server"
+ job: "media_repository"
+ index: 1
+```
+
+Labels (`instance`, `job`, `index`) can be defined as anything.
+The labels are used to group graphs in Grafana.
+
## Renaming of metrics & deprecation of old names in 1.2
Synapse 1.2 updates the Prometheus metrics to match the naming
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index ea32a3d266..11267a77ba 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -2429,20 +2429,25 @@ password_providers:
-# Clients requesting push notifications can either have the body of
-# the message sent in the notification poke along with other details
-# like the sender, or just the event ID and room ID (`event_id_only`).
-# If clients choose the former, this option controls whether the
-# notification request includes the content of the event (other details
-# like the sender are still included). For `event_id_only` push, it
-# has no effect.
-#
-# For modern android devices the notification content will still appear
-# because it is loaded by the app. iPhone, however will send a
-# notification saying only that a message arrived and who it came from.
-#
-#push:
-# include_content: true
+## Push ##
+
+push:
+ # Clients requesting push notifications can either have the body of
+ # the message sent in the notification poke along with other details
+ # like the sender, or just the event ID and room ID (`event_id_only`).
+ # If clients choose the former, this option controls whether the
+ # notification request includes the content of the event (other details
+ # like the sender are still included). For `event_id_only` push, it
+ # has no effect.
+ #
+ # For modern android devices the notification content will still appear
+ # because it is loaded by the app. iPhone, however will send a
+ # notification saying only that a message arrived and who it came from.
+ #
+ # The default value is "true" to include message details. Uncomment to only
+ # include the event ID and room ID in push notification payloads.
+ #
+ #include_content: false
# Spam checkers are third-party modules that can block specific actions
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 707dd73978..dee53b5d40 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -63,13 +63,22 @@ A custom mapping provider must specify the following methods:
information from.
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
-* `map_user_attributes(self, userinfo, token)`
+* `map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
+ - `failures` - An `int` that represents the amount of times the returned
+ mxid localpart mapping has failed. This should be used
+ to create a deduplicated mxid localpart which should be
+ returned instead. For example, if this method returns
+ `john.doe` as the value of `localpart` in the returned
+ dict, and that is already taken on the homeserver, this
+ method will be called again with the same parameters but
+ with failures=1. The method should then return a different
+ `localpart` value, such as `john.doe1`.
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index d4a726be66..a470c274a5 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -42,10 +42,10 @@ This will install and start a systemd service called `coturn`.
./configure
- > You may need to install `libevent2`: if so, you should do so in
- > the way recommended by your operating system. You can ignore
- > warnings about lack of database support: a database is unnecessary
- > for this purpose.
+ You may need to install `libevent2`: if so, you should do so in
+ the way recommended by your operating system. You can ignore
+ warnings about lack of database support: a database is unnecessary
+ for this purpose.
1. Build and install it:
@@ -66,6 +66,19 @@ This will install and start a systemd service called `coturn`.
pwgen -s 64 1
+ A `realm` must be specified, but its value is somewhat arbitrary. (It is
+ sent to clients as part of the authentication flow.) It is conventional to
+ set it to be your server name.
+
+1. You will most likely want to configure coturn to write logs somewhere. The
+ easiest way is normally to send them to the syslog:
+
+ syslog
+
+ (in which case, the logs will be available via `journalctl -u coturn` on a
+ systemd system). Alternatively, coturn can be configured to write to a
+ logfile - check the example config file supplied with coturn.
+
1. Consider your security settings. TURN lets users request a relay which will
connect to arbitrary IP addresses and ports. The following configuration is
suggested as a minimum starting point:
@@ -96,11 +109,31 @@ This will install and start a systemd service called `coturn`.
# TLS private key file
pkey=/path/to/privkey.pem
+ In this case, replace the `turn:` schemes in the `turn_uri` settings below
+ with `turns:`.
+
+ We recommend that you only try to set up TLS/DTLS once you have set up a
+ basic installation and got it working.
+
1. Ensure your firewall allows traffic into the TURN server on the ports
- you've configured it to listen on (By default: 3478 and 5349 for the TURN(s)
+ you've configured it to listen on (By default: 3478 and 5349 for TURN
traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535
for the UDP relay.)
+1. We do not recommend running a TURN server behind NAT, and are not aware of
+ anyone doing so successfully.
+
+ If you want to try it anyway, you will at least need to tell coturn its
+ external IP address:
+
+ external-ip=192.88.99.1
+
+ ... and your NAT gateway must forward all of the relayed ports directly
+ (eg, port 56789 on the external IP must be always be forwarded to port
+ 56789 on the internal IP).
+
+ If you get this working, let us know!
+
1. (Re)start the turn server:
* If you used the Debian package (or have set up a systemd unit yourself):
@@ -137,9 +170,10 @@ Your home server configuration file needs the following extra keys:
without having gone through a CAPTCHA or similar to register a
real account.
-As an example, here is the relevant section of the config file for matrix.org:
+As an example, here is the relevant section of the config file for `matrix.org`. The
+`turn_uris` are appropriate for TURN servers listening on the default ports, with no TLS.
- turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
+ turn_uris: [ "turn:turn.matrix.org?transport=udp", "turn:turn.matrix.org?transport=tcp" ]
turn_shared_secret: "n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons"
turn_user_lifetime: 86400000
turn_allow_guests: True
@@ -155,5 +189,86 @@ After updating the homeserver configuration, you must restart synapse:
```
systemctl restart synapse.service
```
+... and then reload any clients (or wait an hour for them to refresh their
+settings).
+
+## Troubleshooting
+
+The normal symptoms of a misconfigured TURN server are that calls between
+devices on different networks ring, but get stuck at "call
+connecting". Unfortunately, troubleshooting this can be tricky.
+
+Here are a few things to try:
+
+ * Check that your TURN server is not behind NAT. As above, we're not aware of
+ anyone who has successfully set this up.
+
+ * Check that you have opened your firewall to allow TCP and UDP traffic to the
+   TURN ports (normally 3478 and 5349).
+
+ * Check that you have opened your firewall to allow UDP traffic to the UDP
+ relay ports (49152-65535 by default).
+
+ * Some WebRTC implementations (notably, that of Google Chrome) appear to get
+ confused by TURN servers which are reachable over IPv6 (this appears to be
+ an unexpected side-effect of its handling of multiple IP addresses as
+ defined by
+ [`draft-ietf-rtcweb-ip-handling`](https://tools.ietf.org/html/draft-ietf-rtcweb-ip-handling-12)).
+
+ Try removing any AAAA records for your TURN server, so that it is only
+ reachable over IPv4.
+
+ * Enable more verbose logging in coturn via the `verbose` setting:
+
+ ```
+ verbose
+ ```
+
+ ... and then see if there are any clues in its logs.
+
+ * If you are using a browser-based client under Chrome, check
+ `chrome://webrtc-internals/` for insights into the internals of the
+ negotiation. On Firefox, check the "Connection Log" on `about:webrtc`.
+
+ (Understanding the output is beyond the scope of this document!)
+
+ * There is a WebRTC test tool at
+ https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
+ use it, you will need a username/password for your TURN server. You can
+ either:
+
+ * look for the `GET /_matrix/client/r0/voip/turnServer` request made by a
+ matrix client to your homeserver in your browser's network inspector. In
+ the response you should see `username` and `password`. Or:
+
+ * Use the following shell commands:
+
+ ```sh
+ secret=staticAuthSecretHere
+
+ u=$((`date +%s` + 3600)):test
+ p=$(echo -n $u | openssl dgst -hmac $secret -sha1 -binary | base64)
+ echo -e "username: $u\npassword: $p"
+ ```
+
+ Or:
+
+ * Temporarily configure coturn to accept a static username/password. To do
+ this, comment out `use-auth-secret` and `static-auth-secret` and add the
+ following:
+
+ ```
+ lt-cred-mech
+ user=username:password
+ ```
+
+ **Note**: these settings will not take effect unless `use-auth-secret`
+ and `static-auth-secret` are disabled.
+
+ Restart coturn after changing the configuration file.
+
+ Remember to restore the original settings to go back to testing with
+ Matrix clients!
-..and your Home Server now supports VoIP relaying!
+ If the TURN server is working correctly, you should see at least one `relay`
+ entry in the results.
diff --git a/mypy.ini b/mypy.ini
index fc9f8d8050..a5503abe26 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -8,6 +8,7 @@ show_traceback = True
mypy_path = stubs
warn_unreachable = True
files =
+ scripts-dev/sign_json,
synapse/api,
synapse/appservice,
synapse/config,
@@ -37,13 +38,17 @@ files =
synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
+ synapse/handlers/register.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py,
synapse/handlers/sync.py,
synapse/handlers/ui_auth,
+ synapse/http/client.py,
+ synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
+ synapse/http/matrixfederationclient.py,
synapse/http/server.py,
synapse/http/site.py,
synapse/logging,
@@ -105,7 +110,7 @@ ignore_missing_imports = True
[mypy-opentracing]
ignore_missing_imports = True
-[mypy-OpenSSL]
+[mypy-OpenSSL.*]
ignore_missing_imports = True
[mypy-netaddr]
diff --git a/scripts-dev/sign_json b/scripts-dev/sign_json
new file mode 100755
index 0000000000..44553fb79a
--- /dev/null
+++ b/scripts-dev/sign_json
@@ -0,0 +1,127 @@
+#!/usr/bin/env python
+#
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import sys
+from json import JSONDecodeError
+
+import yaml
+from signedjson.key import read_signing_keys
+from signedjson.sign import sign_json
+
+from synapse.util import json_encoder
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="""Adds a signature to a JSON object.
+
+Example usage:
+
+ $ scripts-dev/sign_json.py -N test -k localhost.signing.key "{}"
+ {"signatures":{"test":{"ed25519:a_ZnZh":"LmPnml6iM0iR..."}}}
+""",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ parser.add_argument(
+ "-N",
+ "--server-name",
+ help="Name to give as the local homeserver. If unspecified, will be "
+ "read from the config file.",
+ )
+
+ parser.add_argument(
+ "-k",
+ "--signing-key-path",
+ help="Path to the file containing the private ed25519 key to sign the "
+ "request with.",
+ )
+
+ parser.add_argument(
+ "-c",
+ "--config",
+ default="homeserver.yaml",
+ help=(
+ "Path to synapse config file, from which the server name and/or signing "
+ "key path will be read. Ignored if --server-name and --signing-key-path "
+ "are both given."
+ ),
+ )
+
+ input_args = parser.add_mutually_exclusive_group()
+
+ input_args.add_argument("input_data", nargs="?", help="Raw JSON to be signed.")
+
+ input_args.add_argument(
+ "-i",
+ "--input",
+ type=argparse.FileType("r"),
+ default=sys.stdin,
+ help=(
+ "A file from which to read the JSON to be signed. If neither --input nor "
+ "input_data are given, JSON will be read from stdin."
+ ),
+ )
+
+ parser.add_argument(
+ "-o",
+ "--output",
+ type=argparse.FileType("w"),
+ default=sys.stdout,
+ help="Where to write the signed JSON. Defaults to stdout.",
+ )
+
+ args = parser.parse_args()
+
+ if not args.server_name or not args.signing_key_path:
+ read_args_from_config(args)
+
+ with open(args.signing_key_path) as f:
+ key = read_signing_keys(f)[0]
+
+ json_to_sign = args.input_data
+ if json_to_sign is None:
+ json_to_sign = args.input.read()
+
+ try:
+ obj = json.loads(json_to_sign)
+ except JSONDecodeError as e:
+ print("Unable to parse input as JSON: %s" % e, file=sys.stderr)
+ sys.exit(1)
+
+ if not isinstance(obj, dict):
+ print("Input json was not an object", file=sys.stderr)
+ sys.exit(1)
+
+ sign_json(obj, args.server_name, key)
+ for c in json_encoder.iterencode(obj):
+ args.output.write(c)
+ args.output.write("\n")
+
+
+def read_args_from_config(args: argparse.Namespace) -> None:
+ with open(args.config, "r") as fh:
+ config = yaml.safe_load(fh)
+ if not args.server_name:
+ args.server_name = config["server_name"]
+ if not args.signing_key_path:
+ args.signing_key_path = config["signing_key_path"]
+
+
+if __name__ == "__main__":
+ main()
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index da0996edbc..dfe26dea6d 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -37,7 +37,7 @@ def request_registration(
exit=sys.exit,
):
- url = "%s/_matrix/client/r0/admin/register" % (server_location,)
+ url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# Get the nonce
r = requests.get(url, verify=False)
diff --git a/synapse/config/push.py b/synapse/config/push.py
index a1f3752c8a..a71baac89c 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -21,7 +21,7 @@ class PushConfig(Config):
section = "push"
def read_config(self, config, **kwargs):
- push_config = config.get("push", {})
+ push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
pusher_instances = config.get("pusher_instances") or []
@@ -49,18 +49,23 @@ class PushConfig(Config):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
- # Clients requesting push notifications can either have the body of
- # the message sent in the notification poke along with other details
- # like the sender, or just the event ID and room ID (`event_id_only`).
- # If clients choose the former, this option controls whether the
- # notification request includes the content of the event (other details
- # like the sender are still included). For `event_id_only` push, it
- # has no effect.
- #
- # For modern android devices the notification content will still appear
- # because it is loaded by the app. iPhone, however will send a
- # notification saying only that a message arrived and who it came from.
- #
- #push:
- # include_content: true
+ ## Push ##
+
+ push:
+ # Clients requesting push notifications can either have the body of
+ # the message sent in the notification poke along with other details
+ # like the sender, or just the event ID and room ID (`event_id_only`).
+ # If clients choose the former, this option controls whether the
+ # notification request includes the content of the event (other details
+ # like the sender are still included). For `event_id_only` push, it
+ # has no effect.
+ #
+ # For modern android devices the notification content will still appear
+ # because it is loaded by the app. iPhone, however will send a
+ # notification saying only that a message arrived and who it came from.
+ #
+ # The default value is "true" to include message details. Uncomment to only
+ # include the event ID and room ID in push notification payloads.
+ #
+ #include_content: false
"""
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048a3b3c0b..f4ea0a9767 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError
@@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -31,10 +34,10 @@ class CasHandler:
Utility class for to handle the response from a CAS SSO service.
Args:
- hs (synapse.server.HomeServer)
+ hs
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@@ -200,27 +203,57 @@ class CasHandler:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
- localpart = map_username_to_mxid_localpart(username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.get_user_agent("")
+ ip_address = self.hs.get_ip_from_request(request)
+
+ # Get the matrix ID from the CAS username.
+ user_id = await self._map_cas_user_to_matrix_user(
+ username, user_display_name, user_agent, ip_address
+ )
if session:
await self._auth_handler.complete_sso_ui_auth(
- registered_user_id, session, request,
+ user_id, session, request,
)
-
else:
- if not registered_user_id:
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=user_display_name,
- user_agent_ips=(user_agent, ip_address),
- )
+            # If this is not a UI auth request then there must be a redirect URL.
+ assert client_redirect_url
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url
+ user_id, request, client_redirect_url
)
+
+ async def _map_cas_user_to_matrix_user(
+ self,
+ remote_user_id: str,
+ display_name: Optional[str],
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
+ """
+ Given a CAS username, retrieve the user ID for it and possibly register the user.
+
+ Args:
+ remote_user_id: The username from the CAS response.
+ display_name: The display name from the CAS response.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+
+ Returns:
+ The user ID associated with this response.
+ """
+
+ localpart = map_username_to_mxid_localpart(remote_user_id)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+ # If the user does not exist, register it.
+ if not registered_user_id:
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=display_name,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ return registered_user_id
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 4bfd8d5617..78c4e94a9d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -35,15 +36,10 @@ from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import (
- JsonDict,
- UserID,
- contains_invalid_mxid_characters,
- map_username_to_mxid_localpart,
-)
+from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -869,73 +865,51 @@ class OidcHandler(BaseHandler):
# to be strings.
remote_user_id = str(remote_user_id)
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
- self._auth_provider_id, remote_user_id,
+ # Older mapping providers don't accept the `failures` argument, so we
+ # try and detect support.
+ mapper_signature = inspect.signature(
+ self._user_mapping_provider.map_user_attributes
)
- if previously_registered_user_id:
- return previously_registered_user_id
+ supports_failures = "failures" in mapper_signature.parameters
- # Otherwise, generate a new user.
- try:
- attributes = await self._user_mapping_provider.map_user_attributes(
- userinfo, token
- )
- except Exception as e:
- raise MappingException(
- "Could not extract user attributes from OIDC response: " + str(e)
- )
+ async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Call the mapping provider to map the OIDC userinfo and token to user attributes.
- logger.debug(
- "Retrieved user attributes from user mapping provider: %r", attributes
- )
+ This is a backwards-compatibility wrapper for the SSO handler abstraction.
+ """
+ if supports_failures:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token, failures
+ )
+ else:
+ # If the mapping provider does not support processing failures,
+ # do not continually generate the same Matrix ID since it will
+ # continue to already be in use. Note that the error raised is
+ # arbitrary and will get turned into a MappingException.
+ if failures:
+ raise RuntimeError(
+ "Mapping provider does not support de-duplicating Matrix IDs"
+ )
- localpart = attributes["localpart"]
- if not localpart:
- raise MappingException(
- "Error parsing OIDC response: OIDC mapping provider plugin "
- "did not return a localpart value"
- )
+ attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ userinfo, token
+ )
- user_id = UserID(localpart, self.server_name).to_string()
- users = await self.store.get_users_by_id_case_insensitive(user_id)
- if users:
- if self._allow_existing_users:
- if len(users) == 1:
- registered_user_id = next(iter(users))
- elif user_id in users:
- registered_user_id = user_id
- else:
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, list(users.keys())
- )
- )
- else:
- # This mxid is taken
- raise MappingException("mxid '{}' is already taken".format(user_id))
- else:
- # Since the localpart is provided via a potentially untrusted module,
- # ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(localpart):
- raise MappingException("localpart is invalid: %s" % (localpart,))
-
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
+ return UserAttributes(**attributes)
- await self.store.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id,
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ oidc_response_to_user_attributes,
+ self._allow_existing_users,
)
- return registered_user_id
-UserAttribute = TypedDict(
- "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+ "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
+ failures: How many times a call to this function with this
+ UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
# Ensure only valid characters are included in the MXID.
localpart = map_username_to_mxid_localpart(localpart)
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
+
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- return UserAttribute(localpart=localpart, display_name=display_name)
+ return UserAttributeDict(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3ebf8d52d3..c227c4fe91 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,10 +15,12 @@
"""Contains functions for registering clients."""
import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@@ -74,7 +74,10 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
async def check_username(
- self, localpart, guest_access_token=None, assigned_user_id=None,
+ self,
+ localpart: str,
+ guest_access_token: Optional[str] = None,
+ assigned_user_id: Optional[str] = None,
):
"""
@@ -155,39 +158,45 @@ class RegistrationHandler(BaseHandler):
async def register_user(
self,
- localpart=None,
- password_hash=None,
- guest_access_token=None,
- make_guest=False,
- admin=False,
- threepid=None,
- user_type=None,
- default_display_name=None,
- address=None,
- bind_emails=[],
- by_admin=False,
- user_agent_ips=None,
- ):
+ localpart: Optional[str] = None,
+ password_hash: Optional[str] = None,
+ guest_access_token: Optional[str] = None,
+ make_guest: bool = False,
+ admin: bool = False,
+ threepid: Optional[dict] = None,
+ user_type: Optional[str] = None,
+ default_display_name: Optional[str] = None,
+ address: Optional[str] = None,
+ bind_emails: List[str] = [],
+ by_admin: bool = False,
+ user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+ ) -> str:
"""Registers a new client on the server.
Args:
localpart: The local part of the user ID to register. If None,
one will be generated.
- password_hash (str|None): The hashed password to assign to this user so they can
+ password_hash: The hashed password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
- user_type (str|None): type of user. One of the values from
+ guest_access_token: The access token used when this was a guest
+ account.
+ make_guest: True if the new user should be guest,
+ false to add a regular user account.
+ admin: True if the user should be registered as a server admin.
+ threepid: The threepid used for registering, if any.
+ user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- default_display_name (unicode|None): if set, the new user's displayname
+ default_display_name: if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
- address (str|None): the IP address used to perform the registration.
- bind_emails (List[str]): list of emails to bind to this account.
- by_admin (bool): True if this registration is being made via the
+ address: the IP address used to perform the registration.
+ bind_emails: list of emails to bind to this account.
+ by_admin: True if this registration is being made via the
admin api, otherwise False.
- user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+ user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
Returns:
- str: user_id
+ The registered user_id.
Raises:
SynapseError if there was a problem registering.
"""
@@ -257,8 +266,10 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
- user = None
- while not user:
+ # If a default display name is not given, generate one.
+ generate_display_name = default_display_name is None
+ # This breaks on successful registration *or* errors after 10 failures.
+ while True:
# Fail after being unable to find a suitable ID a few times
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
@@ -267,7 +278,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
- if default_display_name is None:
+ if generate_display_name:
default_display_name = localpart
try:
await self.register_with_store(
@@ -288,8 +299,6 @@ class RegistrationHandler(BaseHandler):
break
except SynapseError:
# if user id is taken, just generate another
- user = None
- user_id = None
fail_count += 1
if not self.hs.config.user_consent_at_registration:
@@ -329,7 +338,7 @@ class RegistrationHandler(BaseHandler):
return user_id
- async def _create_and_join_rooms(self, user_id: str):
+ async def _create_and_join_rooms(self, user_id: str) -> None:
"""
Create the auto-join rooms and join or invite the user to them.
@@ -413,7 +422,7 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- async def _join_rooms(self, user_id: str):
+ async def _join_rooms(self, user_id: str) -> None:
"""
Join or invite the user to the auto-join rooms.
@@ -459,6 +468,9 @@ class RegistrationHandler(BaseHandler):
# Send the invite, if necessary.
if requires_invite:
+ # If an invite is required, there must be an auto-join user ID.
+ assert self.hs.config.registration.auto_join_user_id
+
await room_member_handler.update_membership(
requester=create_requester(
self.hs.config.registration.auto_join_user_id,
@@ -490,7 +502,7 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- async def _auto_join_rooms(self, user_id: str):
+ async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@@ -513,17 +525,17 @@ class RegistrationHandler(BaseHandler):
else:
await self._join_rooms(user_id)
- async def post_consent_actions(self, user_id):
+ async def post_consent_actions(self, user_id: str) -> None:
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
- user_id (str): The user to join
+ user_id: The user to join
"""
await self._auto_join_rooms(user_id)
async def appservice_register(
- self, user_localpart, as_token, password_hash, display_name
+ self, user_localpart: str, as_token: str, password_hash: str, display_name: str
):
# FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname)
@@ -564,7 +576,9 @@ class RegistrationHandler(BaseHandler):
return user_id
- def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+ def check_user_id_not_appservice_exclusive(
+ self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
+ ) -> None:
# don't allow people to register the server notices mxid
if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid:
@@ -619,24 +633,12 @@ class RegistrationHandler(BaseHandler):
},
)
- async def _generate_user_id(self):
- if self._next_generated_user_id is None:
- with await self._generate_user_id_linearizer.queue(()):
- if self._next_generated_user_id is None:
- self._next_generated_user_id = (
- await self.store.find_next_generated_user_id_localpart()
- )
-
- id = self._next_generated_user_id
- self._next_generated_user_id += 1
- return str(id)
-
- def check_registration_ratelimit(self, address):
+ def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
Args:
- address (str|None): the IP address used to perform the registration. If this is
+ address: the IP address used to perform the registration. If this is
None, no ratelimiting will be performed.
Raises:
@@ -647,42 +649,39 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter.ratelimit(address)
- def register_with_store(
+ async def register_with_store(
self,
- user_id,
- password_hash=None,
- was_guest=False,
- make_guest=False,
- appservice_id=None,
- create_profile_with_displayname=None,
- admin=False,
- user_type=None,
- address=None,
- shadow_banned=False,
- ):
+ user_id: str,
+ password_hash: Optional[str] = None,
+ was_guest: bool = False,
+ make_guest: bool = False,
+ appservice_id: Optional[str] = None,
+ create_profile_with_displayname: Optional[str] = None,
+ admin: bool = False,
+ user_type: Optional[str] = None,
+ address: Optional[str] = None,
+ shadow_banned: bool = False,
+ ) -> None:
"""Register user in the datastore.
Args:
- user_id (str): The desired user ID to register.
- password_hash (str|None): Optional. The password hash for this user.
- was_guest (bool): Optional. Whether this is a guest account being
+ user_id: The desired user ID to register.
+ password_hash: Optional. The password hash for this user.
+ was_guest: Optional. Whether this is a guest account being
upgraded to a non-guest account.
- make_guest (boolean): True if the the new user should be guest,
+ make_guest: True if the new user should be guest,
false to add a regular user account.
- appservice_id (str|None): The ID of the appservice registering the user.
- create_profile_with_displayname (unicode|None): Optionally create a
+ appservice_id: The ID of the appservice registering the user.
+ create_profile_with_displayname: Optionally create a
profile for the user, setting their displayname to the given value
- admin (boolean): is an admin user?
- user_type (str|None): type of user. One of the values from
+ admin: is an admin user?
+ user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- address (str|None): the IP address used to perform the registration.
- shadow_banned (bool): Whether to shadow-ban the user
-
- Returns:
- Awaitable
+ address: the IP address used to perform the registration.
+ shadow_banned: Whether to shadow-ban the user
"""
if self.hs.config.worker_app:
- return self._register_client(
+ await self._register_client(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -695,7 +694,7 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
else:
- return self.store.register_user(
+ await self.store.register_user(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -708,22 +707,24 @@ class RegistrationHandler(BaseHandler):
)
async def register_device(
- self, user_id, device_id, initial_display_name, is_guest=False
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ ) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
Args:
- user_id (str): full canonical @user:id
- device_id (str|None): The device ID to check, or None to generate
- a new one.
- initial_display_name (str|None): An optional display name for the
- device.
- is_guest (bool): Whether this is a guest account
+ user_id: full canonical @user:id
+ device_id: The device ID to check, or None to generate a new one.
+ initial_display_name: An optional display name for the device.
+ is_guest: Whether this is a guest account
Returns:
- tuple[str, str]: Tuple of device ID and access token
+ Tuple of device ID and access token
"""
if self.hs.config.worker_app:
@@ -743,7 +744,7 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
- device_id = await self.device_handler.check_device_registered(
+ registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
@@ -753,20 +754,21 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
)
- return (device_id, access_token)
+ return (registered_device_id, access_token)
- async def post_registration_actions(self, user_id, auth_result, access_token):
+ async def post_registration_actions(
+ self, user_id: str, auth_result: dict, access_token: Optional[str]
+ ) -> None:
"""A user has completed registration
Args:
- user_id (str): The user ID that consented
- auth_result (dict): The authenticated credentials of the newly
- registered user.
- access_token (str|None): The access token of the newly logged in
- device, or None if `inhibit_login` enabled.
+ user_id: The user ID that consented
+ auth_result: The authenticated credentials of the newly registered user.
+ access_token: The access token of the newly logged in device, or
+ None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
await self._post_registration_client(
@@ -818,19 +820,20 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.TERMS in auth_result:
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
- async def _on_user_consented(self, user_id, consent_version):
+ async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
"""A user consented to the terms on registration
Args:
- user_id (str): The user ID that consented.
- consent_version (str): version of the policy the user has
- consented to.
+ user_id: The user ID that consented.
+ consent_version: version of the policy the user has consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def register_email_threepid(self, user_id, threepid, token):
+ async def register_email_threepid(
+ self, user_id: str, threepid: dict, token: Optional[str]
+ ) -> None:
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
@@ -839,10 +842,9 @@ class RegistrationHandler(BaseHandler):
Must be called on master.
Args:
- user_id (str): id of user
- threepid (object): m.login.email.identity auth response
- token (str|None): access_token for the user, or None if not logged
- in.
+ user_id: id of user
+ threepid: m.login.email.identity auth response
+ token: access_token for the user, or None if not logged in.
"""
reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
@@ -868,6 +870,8 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
+ # The token better still exist.
+ assert user_tuple
token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
@@ -882,14 +886,14 @@ class RegistrationHandler(BaseHandler):
data={},
)
- async def _register_msisdn_threepid(self, user_id, threepid):
+ async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
"""Add a phone number as a 3pid identifier
Must be called on master.
Args:
- user_id (str): id of user
- threepid (object): m.login.msisdn auth response
+ user_id: id of user
+ threepid: m.login.msisdn auth response
"""
try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index c83d82e81d..9b15dd9951 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -31,7 +31,6 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -552,10 +551,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- invite = await self.store.get_invite_for_local_user_in_room(
- user_id=target.to_string(), room_id=room_id
- ) # type: Optional[RoomsForUser]
- if not invite:
+ (
+ current_membership_type,
+ current_membership_event_id,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ target.to_string(), room_id
+ )
+ if (
+ current_membership_type != Membership.INVITE
+ or not current_membership_event_id
+ ):
logger.info(
"%s sent a leave request to %s, but that is not an active room "
"on this server, and there is no pending invite",
@@ -565,6 +570,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(404, "Not a known room")
+ invite = await self.store.get_event(current_membership_event_id)
logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 5d9b555b13..34db10ffe4 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -25,13 +25,12 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import (
UserID,
- contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
@@ -39,7 +38,7 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -56,7 +55,7 @@ class Saml2SessionData:
class SamlHandler(BaseHandler):
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
@@ -250,14 +249,26 @@ class SamlHandler(BaseHandler):
"Failed to extract remote user id from SAML response"
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
- self._auth_provider_id, remote_user_id,
+ async def saml_response_to_remapped_user_attributes(
+ failures: int,
+ ) -> UserAttributes:
+ """
+ Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+ This is a backwards-compatibility wrapper for the SSO handler abstraction.
+ """
+ # Call the mapping provider.
+ result = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, failures, client_redirect_url
+ )
+ # Remap some of the results.
+ return UserAttributes(
+ localpart=result.get("mxid_localpart"),
+ display_name=result.get("displayname"),
+ emails=result.get("emails"),
)
- if previously_registered_user_id:
- return previously_registered_user_id
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -268,7 +279,8 @@ class SamlHandler(BaseHandler):
user_id = UserID(
map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
- logger.info(
+
+ logger.debug(
"Looking for existing account based on mapped %s %s",
self._grandfathered_mxid_source_attribute,
user_id,
@@ -283,59 +295,13 @@ class SamlHandler(BaseHandler):
)
return registered_user_id
- # Map saml response to user attributes using the configured mapping provider
- for i in range(1000):
- attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i, client_redirect_url=client_redirect_url,
- )
-
- logger.debug(
- "Retrieved SAML attributes from user mapping provider: %s "
- "(attempt %d)",
- attribute_dict,
- i,
- )
-
- localpart = attribute_dict.get("mxid_localpart")
- if not localpart:
- raise MappingException(
- "Error parsing SAML2 response: SAML mapping provider plugin "
- "did not return a mxid_localpart value"
- )
-
- displayname = attribute_dict.get("displayname")
- emails = attribute_dict.get("emails", [])
-
- # Check if this mxid already exists
- if not await self.store.get_users_by_id_case_insensitive(
- UserID(localpart, self.server_name).to_string()
- ):
- # This mxid is free
- break
- else:
- # Unable to generate a username in 1000 iterations
- # Break and return error to the user
- raise MappingException(
- "Unable to generate a Matrix ID from the SAML response"
- )
-
- # Since the localpart is provided via a potentially untrusted module,
- # ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(localpart):
- raise MappingException("localpart is invalid: %s" % (localpart,))
-
- logger.info("Mapped SAML user to local part %s", localpart)
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=displayname,
- bind_emails=emails,
- user_agent_ips=(user_agent, ip_address),
- )
-
- await self.store.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
)
- return registered_user_id
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
@@ -450,11 +416,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
- # a usable mxid
- localpart = base_mxid_localpart + (str(failures) if failures else "")
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 9cb1866a71..d963082210 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -13,10 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+
+import attr
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
+from synapse.types import UserID, contains_invalid_mxid_characters
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -29,9 +32,20 @@ class MappingException(Exception):
"""
+@attr.s
+class UserAttributes:
+ localpart = attr.ib(type=str)
+ display_name = attr.ib(type=Optional[str], default=None)
+ emails = attr.ib(type=List[str], default=attr.Factory(list))
+
+
class SsoHandler(BaseHandler):
+ # The number of attempts to ask the mapping provider for when generating an MXID.
+ _MAP_USERNAME_RETRIES = 1000
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
+ self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
def render_error(
@@ -71,20 +85,165 @@ class SsoHandler(BaseHandler):
Returns:
The mxid of a previously seen user.
"""
- # Check if we already have a mapping for this user.
- logger.info(
+ logger.debug(
"Looking for existing mapping for user %s:%s",
auth_provider_id,
remote_user_id,
)
+
+ # Check if we already have a mapping for this user.
previously_registered_user_id = await self.store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
# A match was found, return the user ID.
if previously_registered_user_id is not None:
- logger.info("Found existing mapping %s", previously_registered_user_id)
+ logger.info(
+ "Found existing mapping for IdP '%s' and remote_user_id '%s': %s",
+ auth_provider_id,
+ remote_user_id,
+ previously_registered_user_id,
+ )
return previously_registered_user_id
# No match.
return None
+
+ async def get_mxid_from_sso(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ allow_existing_users: bool = False,
+ ) -> str:
+ """
+ Given an SSO ID, retrieve the user ID for it and possibly register the user.
+
+ This first checks if the SSO ID has previously been linked to a matrix ID,
+ if it has, that matrix ID is returned regardless of the current mapping
+ logic.
+
+ The mapping function is called (potentially multiple times) to generate
+ a localpart for the user.
+
+ If an unused localpart is generated, the user is registered from the
+ given user-agent and IP address and the SSO ID is linked to this matrix
+ ID for subsequent calls.
+
+ If allow_existing_users is true the mapping function is only called once
+ and results in:
+
+ 1. The use of a previously registered matrix ID. In this case, the
+ SSO ID is linked to the matrix ID. (Note it is possible that
+ other SSO IDs are linked to the same matrix ID.)
+ 2. An unused localpart, in which case the user is registered (as
+ discussed above).
+ 3. An error if the generated localpart matches multiple pre-existing
+ matrix IDs. Generally this should not happen.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+ sso_to_matrix_id_mapper: A callable to generate the user attributes.
+ The only parameter is an integer which represents the number of
+ times the returned mxid localpart mapping has failed.
+ allow_existing_users: True if the localpart returned from the
+ mapping provider can be linked to an existing matrix ID.
+
+ Returns:
+ The user ID associated with the SSO response.
+
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
+
+ """
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+ if previously_registered_user_id:
+ return previously_registered_user_id
+
+ # Otherwise, generate a new user.
+ for i in range(self._MAP_USERNAME_RETRIES):
+ try:
+ attributes = await sso_to_matrix_id_mapper(i)
+ except Exception as e:
+ raise MappingException(
+ "Could not extract user attributes from SSO response: " + str(e)
+ )
+
+ logger.debug(
+ "Retrieved user attributes from user mapping provider: %r (attempt %d)",
+ attributes,
+ i,
+ )
+
+ if not attributes.localpart:
+ raise MappingException(
+ "Error parsing SSO response: SSO mapping provider plugin "
+ "did not return a localpart value"
+ )
+
+ # Check if this mxid already exists
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
+ # Note, if allow_existing_users is true then the loop is guaranteed
+ # to end on the first iteration: either by matching an existing user,
+ # raising an error, or registering a new user. See the docstring for
+ # a more in-depth explanation.
+ if users and allow_existing_users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
+ )
+
+ # Future logins should also match this user ID.
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+
+ return previously_registered_user_id
+
+ elif not users:
+ # This mxid is free
+ break
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise MappingException(
+ "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ # Since the localpart is provided via a potentially untrusted module,
+ # ensure the MXID is valid before registering.
+ if contains_invalid_mxid_characters(attributes.localpart):
+ raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
+
+ logger.debug("Mapped SSO user to local part %s", attributes.localpart)
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=attributes.localpart,
+ default_display_name=attributes.display_name,
+ bind_emails=attributes.emails,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
diff --git a/synapse/http/client.py b/synapse/http/client.py
index f409368802..e5b13593f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
+import urllib.parse
from io import BytesIO
from typing import (
+ TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@@ -31,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
-from netaddr import IPAddress
+from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -39,6 +40,8 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
+ IAddress,
+ IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@@ -53,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
-def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
+def check_against_blacklist(
+ ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
+) -> bool:
"""
+ Compares an IP address to allowed and disallowed IP sets.
+
Args:
- ip_address (netaddr.IPAddress)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ ip_address: The IP address to check
+ ip_whitelist: Allowed IP addresses.
+ ip_blacklist: Disallowed IP addresses.
+
+ Returns:
+ True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
@@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""
- def __init__(self, reactor, ip_whitelist, ip_blacklist):
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
"""
Args:
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ reactor: The twisted reactor.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def resolveHostName(self, recv, hostname, portNumber=0):
+ def resolveHostName(
+ self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
+ ) -> IResolutionReceiver:
r = recv()
- addresses = []
+ addresses = [] # type: List[IAddress]
- def _callback():
+ def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
@@ -161,15 +181,15 @@ class IPBlacklistingResolver:
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
- def resolutionBegan(resolutionInProgress):
+ def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
- def addressResolved(address):
+ def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
- def resolutionComplete():
+ def resolutionComplete() -> None:
_callback()
self._reactor.nameResolver.resolveHostName(
@@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""
- def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ agent: IAgent,
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ ):
"""
Args:
- agent (twisted.web.client.Agent): The Agent to wrap.
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ agent: The Agent to wrap.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
@@ -226,23 +256,23 @@ class SimpleHttpClient:
def __init__(
self,
- hs,
- treq_args={},
- ip_whitelist=None,
- ip_blacklist=None,
- http_proxy=None,
- https_proxy=None,
+ hs: "HomeServer",
+ treq_args: Dict[str, Any] = {},
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ http_proxy: Optional[bytes] = None,
+ https_proxy: Optional[bytes] = None,
):
"""
Args:
- hs (synapse.server.HomeServer)
- treq_args (dict): Extra keyword arguments to be given to treq.request.
- ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
+ hs
+ treq_args: Extra keyword arguments to be given to treq.request.
+ ip_blacklist: The IP addresses that are blacklisted that
we may not request.
- ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
+ ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
- http_proxy (bytes): proxy server to use for http connections. host[:port]
- https_proxy (bytes): proxy server to use for https connections. host[:port]
+ http_proxy: proxy server to use for http connections. host[:port]
+ https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -306,7 +336,6 @@ class SimpleHttpClient:
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
- self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@@ -397,7 +426,7 @@ class SimpleHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
- args: Mapping[str, Union[str, List[str]]] = {},
+ args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@@ -422,9 +451,7 @@ class SimpleHttpClient:
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
- query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
- "utf8"
- )
+ query_bytes = encode_query_args(args)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -432,7 +459,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@@ -479,7 +506,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@@ -495,7 +522,10 @@ class SimpleHttpClient:
)
async def get_json(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@@ -516,7 +546,7 @@ class SimpleHttpClient:
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
@@ -525,7 +555,7 @@ class SimpleHttpClient:
self,
uri: str,
json_body: Any,
- args: QueryParams = {},
+ args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@@ -546,9 +576,9 @@ class SimpleHttpClient:
ValueError: if the response was not JSON
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
json_str = encode_canonical_json(json_body)
@@ -558,7 +588,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@@ -574,7 +604,10 @@ class SimpleHttpClient:
)
async def get_raw(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@@ -592,13 +625,13 @@ class SimpleHttpClient:
HttpResponseException on a non-2xx HTTP response.
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", uri, headers=Headers(actual_headers))
@@ -641,7 +674,7 @@ class SimpleHttpClient:
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", url, headers=Headers(actual_headers))
@@ -649,12 +682,13 @@ class SimpleHttpClient:
if (
b"Content-Length" in resp_headers
+ and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
- logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
+ logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
- "Requested file is too large > %r bytes" % (self.max_size,),
+ "Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
@@ -668,7 +702,7 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- _readBodyToFile(response, output_stream, max_size)
+ readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
-
-
class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
+ def __init__(
+ self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
+ ):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred = defer.Deferred()
self.transport.loseConnection()
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
+def readBodyToFile(
+ response: IResponse, stream: BinaryIO, max_size: Optional[int]
+) -> defer.Deferred:
+ """
+ Read an HTTP response body to a file-object, optionally enforcing a maximum file size.
+ Args:
+ response: The HTTP response to read from.
+ stream: The file-object to write to.
+ max_size: The maximum file size to allow.
+
+ Returns:
+ A Deferred which resolves to the length of the read body.
+ """
-def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
-def encode_urlencode_args(args):
- return {k: encode_urlencode_arg(v) for k, v in args.items()}
+def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
+ """
+ Encodes a map of query arguments to bytes which can be appended to a URL.
+ Args:
+ args: The query arguments, a mapping of string to string or list of strings.
+
+ Returns:
+ The query arguments encoded as bytes.
+ """
+ if args is None:
+ return b""
-def encode_urlencode_arg(arg):
- if isinstance(arg, str):
- return arg.encode("utf-8")
- elif isinstance(arg, list):
- return [encode_urlencode_arg(i) for i in arg]
- else:
- return arg
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, str):
+ vs = [vs]
+ encoded_args[k] = [v.encode("utf8") for v in vs]
+ query_str = urllib.parse.urlencode(encoded_args, True)
-def _print_ex(e):
- if hasattr(e, "reasons") and e.reasons:
- for ex in e.reasons:
- _print_ex(ex)
- else:
- logger.exception(e)
+ return query_str.encode("utf8")
class InsecureInterceptableContextFactory(ssl.ContextFactory):
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 83d6196d4a..e77f9587d0 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,21 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-import urllib
-from typing import List
+import urllib.parse
+from typing import List, Optional
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import Agent, HTTPConnectionPool
+from twisted.internet.interfaces import (
+ IProtocolFactory,
+ IReactorCore,
+ IStreamClientEndpoint,
+)
+from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IAgentEndpointFactory
+from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -44,30 +48,30 @@ class MatrixFederationAgent:
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
- reactor (IReactor): twisted reactor to use for underlying requests
+ reactor: twisted reactor to use for underlying requests
- tls_client_options_factory (FederationPolicyForHTTPS|None):
+ tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
- user_agent (bytes):
+ user_agent:
The user agent header to use for federation requests.
- _srv_resolver (SrvResolver|None):
- SRVResolver impl to use for looking up SRV records. None to use a default
- implementation.
+ _srv_resolver:
+ SrvResolver implementation to use for looking up SRV records. None
+ to use a default implementation.
- _well_known_resolver (WellKnownResolver|None):
+ _well_known_resolver:
WellKnownResolver to use to perform well-known lookups. None to use a
default implementation.
"""
def __init__(
self,
- reactor,
- tls_client_options_factory,
- user_agent,
- _srv_resolver=None,
- _well_known_resolver=None,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ user_agent: bytes,
+ _srv_resolver: Optional[SrvResolver] = None,
+ _well_known_resolver: Optional[WellKnownResolver] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -99,15 +103,20 @@ class MatrixFederationAgent:
self._well_known_resolver = _well_known_resolver
@defer.inlineCallbacks
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
"""
Args:
- method (bytes): HTTP method: GET/POST/etc
- uri (bytes): Absolute URI to be retrieved
- headers (twisted.web.http_headers.Headers|None):
- HTTP headers to send with the request, or None to
- send no extra headers.
- bodyProducer (twisted.web.iweb.IBodyProducer|None):
+ method: HTTP method: GET/POST/etc
+ uri: Absolute URI to be retrieved
+ headers:
+ HTTP headers to send with the request, or None to send no extra headers.
+ bodyProducer:
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
@@ -123,6 +132,9 @@ class MatrixFederationAgent:
# explicit port.
parsed_uri = urllib.parse.urlparse(uri)
+ # There must be a valid hostname.
+ assert parsed_uri.hostname
+
# If this is a matrix:// URI check if the server has delegated matrix
# traffic using well-known delegation.
#
@@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
"""
- def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ srv_resolver: Optional[SrvResolver],
+ ):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
@@ -203,15 +220,20 @@ class MatrixHostnameEndpoint:
resolution (i.e. via SRV). Does not check for well-known delegation.
Args:
- reactor (IReactor)
- tls_client_options_factory (ClientTLSOptionsFactory|None):
+ reactor: twisted reactor to use for underlying requests
+ tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
- srv_resolver (SrvResolver): The SRV resolver to use
- parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
- to connect to.
+ srv_resolver: The SRV resolver to use
+ parsed_uri: The parsed URI that we're wanting to connect to.
"""
- def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ srv_resolver: SrvResolver,
+ parsed_uri: URI,
+ ):
self._reactor = reactor
self._parsed_uri = parsed_uri
@@ -231,13 +253,13 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
- def connect(self, protocol_factory):
+ def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
"""Implements IStreamClientEndpoint interface
"""
return run_in_background(self._do_connect, protocol_factory)
- async def _do_connect(self, protocol_factory):
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
first_exception = None
server_list = await self._resolve_server()
@@ -303,20 +325,20 @@ class MatrixHostnameEndpoint:
return [Server(host, 8448)]
-def _is_ip_literal(host):
+def _is_ip_literal(host: bytes) -> bool:
"""Test if the given host name is either an IPv4 or IPv6 literal.
Args:
- host (bytes)
+ host: The host name to check
Returns:
- bool
+ True if the hostname is an IP address literal.
"""
- host = host.decode("ascii")
+ host_str = host.decode("ascii")
try:
- IPAddress(host)
+ IPAddress(host_str)
return True
except AddrFormatError:
return False
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 1cc666fbf6..5e08ef1664 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import random
import time
@@ -21,10 +20,11 @@ from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@@ -81,11 +81,11 @@ class WellKnownResolver:
def __init__(
self,
- reactor,
- agent,
- user_agent,
- well_known_cache=None,
- had_well_known_cache=None,
+ reactor: IReactorTime,
+ agent: IAgent,
+ user_agent: bytes,
+ well_known_cache: Optional[TTLCache] = None,
+ had_well_known_cache: Optional[TTLCache] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -127,7 +127,7 @@ class WellKnownResolver:
with Measure(self._clock, "get_well_known"):
result, cache_period = await self._fetch_well_known(
server_name
- ) # type: Tuple[Optional[bytes], float]
+ ) # type: Optional[bytes], float
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 7e17cdb73e..4e27f93b7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,8 +17,9 @@ import cgi
import logging
import random
import sys
-import urllib
+import urllib.parse
from io import BytesIO
+from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
import treq
@@ -27,25 +28,27 @@ from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
-from twisted.internet import defer, protocol
+from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
-from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
- Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
- SynapseError,
)
from synapse.http import QuieterFileBodyProducer
-from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ IPBlacklistingResolver,
+ encode_query_args,
+ readBodyToFile,
+)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@@ -54,6 +57,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -76,47 +80,44 @@ MAXINT = sys.maxsize
_next_id = 1
+QueryArgs = Dict[str, Union[str, List[str]]]
+
+
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
- method = attr.ib()
+ method = attr.ib(type=str)
"""HTTP method
- :type: str
"""
- path = attr.ib()
+ path = attr.ib(type=str)
"""HTTP path
- :type: str
"""
- destination = attr.ib()
+ destination = attr.ib(type=str)
"""The remote server to send the HTTP request to.
- :type: str"""
+ """
- json = attr.ib(default=None)
+ json = attr.ib(default=None, type=Optional[JsonDict])
"""JSON to send in the body.
- :type: dict|None
"""
- json_callback = attr.ib(default=None)
+ json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]])
"""A callback to generate the JSON.
- :type: func|None
"""
- query = attr.ib(default=None)
+ query = attr.ib(default=None, type=Optional[dict])
"""Query arguments.
- :type: dict|None
"""
- txn_id = attr.ib(default=None)
+ txn_id = attr.ib(default=None, type=Optional[str])
"""Unique ID for this request (for logging)
- :type: str|None
"""
uri = attr.ib(init=False, type=bytes)
"""The URI of this request
"""
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
global _next_id
txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1)
@@ -136,7 +137,7 @@ class MatrixFederationRequest:
)
object.__setattr__(self, "uri", uri)
- def get_json(self):
+ def get_json(self) -> Optional[JsonDict]:
if self.json_callback:
return self.json_callback()
return self.json
@@ -148,7 +149,7 @@ async def _handle_json_response(
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
-):
+) -> JsonDict:
"""
Reads the JSON body of a response, with a timeout
@@ -160,7 +161,7 @@ async def _handle_json_response(
start_ms: Timestamp when request was made
Returns:
- dict: parsed JSON response
+ The parsed JSON response
"""
try:
check_content_type_is_json(response.headers)
@@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent,
- self.reactor,
- ip_blacklist=hs.config.federation_ip_range_blacklist,
+ self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@@ -266,27 +265,29 @@ class MatrixFederationHttpClient:
self._cooperator = Cooperator(scheduler=schedule)
async def _send_request_with_optional_trailing_slash(
- self, request, try_trailing_slash_on_400=False, **send_request_args
- ):
+ self,
+ request: MatrixFederationRequest,
+ try_trailing_slash_on_400: bool = False,
+ **send_request_args
+ ) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
due to #3622.
Args:
- request (MatrixFederationRequest): details of request to be sent
- try_trailing_slash_on_400 (bool): Whether on receiving a 400
+ request: details of request to be sent
+ try_trailing_slash_on_400: Whether on receiving a 400
'M_UNRECOGNIZED' from the server to retry the request with a
trailing slash appended to the request path.
- send_request_args (Dict): A dictionary of arguments to pass to
- `_send_request()`.
+ send_request_args: A dictionary of arguments to pass to `_send_request()`.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
Returns:
- Dict: Parsed JSON response body.
+ Parsed JSON response body.
"""
try:
response = await self._send_request(request, **send_request_args)
@@ -313,24 +314,26 @@ class MatrixFederationHttpClient:
async def _send_request(
self,
- request,
- retry_on_dns_fail=True,
- timeout=None,
- long_retries=False,
- ignore_backoff=False,
- backoff_on_404=False,
- ):
+ request: MatrixFederationRequest,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ long_retries: bool = False,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ ) -> IResponse:
"""
Sends a request to the given server.
Args:
- request (MatrixFederationRequest): details of request to be sent
+ request: details of request to be sent
+
+ retry_on_dns_fail: true if the request should be retried on DNS failures
- timeout (int|None): number of milliseconds to wait for the response headers
+ timeout: number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
60s by default.
- long_retries (bool): whether to use the long retry algorithm.
+ long_retries: whether to use the long retry algorithm.
The regular retry algorithm makes 4 attempts, with intervals
[0.5s, 1s, 2s].
@@ -346,14 +349,13 @@ class MatrixFederationHttpClient:
NB: the long retry algorithm takes over 20 minutes to complete, with
a default timeout of 60s!
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- backoff_on_404 (bool): Back off if we get a 404
+ backoff_on_404: Back off if we get a 404
Returns:
- twisted.web.client.Response: resolves with the HTTP
- response object on success.
+ Resolves with the HTTP response object on success.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
@@ -404,7 +406,7 @@ class MatrixFederationHttpClient:
)
# Inject the span into the headers
- headers_dict = {}
+ headers_dict = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -435,7 +437,7 @@ class MatrixFederationHttpClient:
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
- )
+ ) # type: Optional[IBodyProducer]
else:
producer = None
auth_headers = self.build_auth_headers(
@@ -524,14 +526,16 @@ class MatrixFederationHttpClient:
)
body = None
- e = HttpResponseException(response.code, response_phrase, body)
+ exc = HttpResponseException(
+ response.code, response_phrase, body
+ )
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
- raise RequestSendFailed(e, can_retry=True) from e
+ raise RequestSendFailed(exc, can_retry=True) from exc
else:
- raise e
+ raise exc
break
except RequestSendFailed as e:
@@ -582,22 +586,27 @@ class MatrixFederationHttpClient:
return response
def build_auth_headers(
- self, destination, method, url_bytes, content=None, destination_is=None
- ):
+ self,
+ destination: Optional[bytes],
+ method: bytes,
+ url_bytes: bytes,
+ content: Optional[JsonDict] = None,
+ destination_is: Optional[bytes] = None,
+ ) -> List[bytes]:
"""
Builds the Authorization headers for a federation request
Args:
- destination (bytes|None): The destination homeserver of the request.
+ destination: The destination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
- method (bytes): The HTTP method of the request
- url_bytes (bytes): The URI path of the request
- content (object): The body of the request
- destination_is (bytes): As 'destination', but if the destination is an
+ method: The HTTP method of the request
+ url_bytes: The URI path of the request
+ content: The body of the request
+ destination_is: As 'destination', but if the destination is an
identity server
Returns:
- list[bytes]: a list of headers to be added as "Authorization:" headers
+ A list of headers to be added as "Authorization:" headers
"""
request = {
"method": method.decode("ascii"),
@@ -629,33 +638,32 @@ class MatrixFederationHttpClient:
async def put_json(
self,
- destination,
- path,
- args={},
- data={},
- json_data_callback=None,
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- backoff_on_404=False,
- try_trailing_slash_on_400=False,
- ):
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ ) -> Union[JsonDict, list]:
""" Sends the specified json data using PUT
Args:
- destination (str): The remote server to send the HTTP request
- to.
- path (str): The HTTP path.
- args (dict): query params
- data (dict): A dict containing the data that will be used as
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path.
+ args: query params
+ data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- json_data_callback (callable): A callable returning the dict to
+ json_data_callback: A callable returning the dict to
use as the request body.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -663,19 +671,19 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- backoff_on_404 (bool): True if we should count a 404 response as
+ backoff_on_404: True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests).
- try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -721,29 +729,28 @@ class MatrixFederationHttpClient:
async def post_json(
self,
- destination,
- path,
- data={},
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- args={},
- ):
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryArgs] = None,
+ ) -> Union[JsonDict, list]:
""" Sends the specified json data using POST
Args:
- destination (str): The remote server to send the HTTP request
- to.
+ destination: The remote server to send the HTTP request to.
- path (str): The HTTP path.
+ path: The HTTP path.
- data (dict): A dict containing the data that will be used as
+ data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -751,10 +758,10 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data and
+ ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
- args (dict): query params
+ args: query params
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -795,26 +802,25 @@ class MatrixFederationHttpClient:
async def get_json(
self,
- destination,
- path,
- args=None,
- retry_on_dns_fail=True,
- timeout=None,
- ignore_backoff=False,
- try_trailing_slash_on_400=False,
- ):
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ ) -> Union[JsonDict, list]:
""" GETs some json from the given host homeserver and path
Args:
- destination (str): The remote server to send the HTTP request
- to.
+ destination: The remote server to send the HTTP request to.
- path (str): The HTTP path.
+ path: The HTTP path.
- args (dict|None): A dictionary used to create query strings, defaults to
+ args: A dictionary used to create query strings, defaults to
None.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -822,14 +828,14 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -870,24 +876,23 @@ class MatrixFederationHttpClient:
async def delete_json(
self,
- destination,
- path,
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- args={},
- ):
+ destination: str,
+ path: str,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryArgs] = None,
+ ) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response
Args:
- destination (str): The remote server to send the HTTP request
- to.
- path (str): The HTTP path.
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -895,12 +900,12 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data and
+ ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
- args (dict): query params
+ args: query params
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -938,25 +943,25 @@ class MatrixFederationHttpClient:
async def get_file(
self,
- destination,
- path,
+ destination: str,
+ path: str,
output_stream,
- args={},
- retry_on_dns_fail=True,
- max_size=None,
- ignore_backoff=False,
- ):
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ max_size: Optional[int] = None,
+ ignore_backoff: bool = False,
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
"""GETs a file from a given homeserver
Args:
- destination (str): The remote server to send the HTTP request to.
- path (str): The HTTP path to GET.
- output_stream (file): File to write the response body to.
- args (dict): Optional dictionary used to create the query string.
- ignore_backoff (bool): true to ignore the historical backoff data
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path to GET.
+ output_stream: File to write the response body to.
+ args: Optional dictionary used to create the query string.
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
- tuple[int, dict]: Resolves with an (int,dict) tuple of
+ Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Raises:
@@ -980,7 +985,7 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
- d = _readBodyToFile(response, output_stream, max_size)
+ d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@@ -1004,40 +1009,6 @@ class MatrixFederationHttpClient:
return (length, headers)
-class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
- self.stream = stream
- self.deferred = deferred
- self.length = 0
- self.max_size = max_size
-
- def dataReceived(self, data):
- self.stream.write(data)
- self.length += len(data)
- if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
- self.deferred = defer.Deferred()
- self.transport.loseConnection()
-
- def connectionLost(self, reason):
- if reason.check(ResponseDone):
- self.deferred.callback(self.length)
- else:
- self.deferred.errback(reason)
-
-
-def _readBodyToFile(response, stream, max_size):
- d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
- return d
-
-
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -1049,13 +1020,13 @@ def _flatten_response_never_received(e):
return repr(e)
-def check_content_type_is_json(headers):
+def check_content_type_is_json(headers: Headers) -> None:
"""
Check that a set of HTTP headers have a Content-Type header, and that it
is application/json.
Args:
- headers (twisted.web.http_headers.Headers): headers to check
+ headers: headers to check
Raises:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
@@ -1078,18 +1049,3 @@ def check_content_type_is_json(headers):
),
can_retry=False,
)
-
-
-def encode_query_args(args):
- if args is None:
- return b""
-
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, str):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.parse.urlencode(encoded_args, True)
-
- return query_bytes.encode("utf8")
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 7a3a5c46ca..55ddebb4fe 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -21,11 +21,7 @@ import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.rest.admin._base import (
- admin_patterns,
- assert_requester_is_admin,
- historical_admin_path_patterns,
-)
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
DeviceRestServlet,
@@ -84,7 +80,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns(
+ PATTERNS = admin_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
@@ -169,9 +165,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns(
- "/purge_history_status/(?P<purge_id>[^/]+)"
- )
+ PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
def __init__(self, hs):
"""
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index db9fea263a..e09234c644 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -22,28 +22,6 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
-def historical_admin_path_patterns(path_regex):
- """Returns the list of patterns for an admin endpoint, including historical ones
-
- This is a backwards-compatibility hack. Previously, the Admin API was exposed at
- various paths under /_matrix/client. This function returns a list of patterns
- matching those paths (as well as the new one), so that existing scripts which rely
- on the endpoints being available there are not broken.
-
- Note that this should only be used for existing endpoints: new ones should just
- register for the /_synapse/admin path.
- """
- return [
- re.compile(prefix + path_regex)
- for prefix in (
- "^/_synapse/admin/v1",
- "^/_matrix/client/api/v1/admin",
- "^/_matrix/client/unstable/admin",
- "^/_matrix/client/r0/admin",
- )
- ]
-
-
def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index 0b54ca09f4..d0c86b204a 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -16,10 +16,7 @@ import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
-from synapse.rest.admin._base import (
- assert_user_is_admin,
- historical_admin_path_patterns,
-)
+from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
logger = logging.getLogger(__name__)
@@ -28,7 +25,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups
"""
- PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
+ PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs):
self.group_server = hs.get_groups_server_handler()
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index ba50cb876d..c82b4f87d6 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -22,7 +22,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
logger = logging.getLogger(__name__)
@@ -34,10 +33,10 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = (
- historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+ admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+
# This path kept around for legacy reasons
- historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
def __init__(self, hs):
@@ -63,9 +62,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
- PATTERNS = historical_admin_path_patterns(
- "/user/(?P<user_id>[^/]+)/media/quarantine"
- )
+ PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -90,7 +87,7 @@ class QuarantineMediaByID(RestServlet):
it via this server.
"""
- PATTERNS = historical_admin_path_patterns(
+ PATTERNS = admin_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
@@ -116,7 +113,7 @@ class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
- PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
+ PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -134,7 +131,7 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/purge_media_cache")
+ PATTERNS = admin_patterns("/purge_media_cache")
def __init__(self, hs):
self.media_repository = hs.get_media_repository()
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ee345e12ce..353151169a 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -29,7 +29,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -44,7 +43,7 @@ class ShutdownRoomRestServlet(RestServlet):
joined to the new room.
"""
- PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
+ PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
def __init__(self, hs):
self.hs = hs
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index fa8d8e6d91..b0ff5e1ead 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -33,8 +33,8 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
@@ -55,7 +55,7 @@ _GET_PUSHERS_ALLOWED_KEYS = {
class UsersRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs):
self.hs = hs
@@ -338,7 +338,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds.
"""
- PATTERNS = historical_admin_path_patterns("/register")
+ PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
@@ -461,7 +461,14 @@ class UserRegisterServlet(RestServlet):
class WhoisRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
+ path_regex = "/whois/(?P<user_id>[^/]*)$"
+ PATTERNS = (
+ admin_patterns(path_regex)
+ +
+ # URL for spec reason
+ # https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid
+ client_patterns("/admin" + path_regex, v1=True)
+ )
def __init__(self, hs):
self.hs = hs
@@ -485,7 +492,7 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@@ -516,7 +523,7 @@ class DeactivateAccountRestServlet(RestServlet):
class AccountValidityRenewServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
+ PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs):
"""
@@ -559,9 +566,7 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error.
"""
- PATTERNS = historical_admin_path_patterns(
- "/reset_password/(?P<target_user_id>[^/]*)"
- )
+ PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -603,7 +608,7 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.hs = hs
diff --git a/synapse/server.py b/synapse/server.py
index 12a783de17..c82d8f9fad 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -27,7 +27,8 @@ import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
-import twisted
+import twisted.internet.base
+import twisted.internet.tcp
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 01d9dbb36f..dcdaf09682 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -350,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
+ async def get_local_current_membership_for_user_in_room(
+ self, user_id: str, room_id: str
+ ) -> Tuple[Optional[str], Optional[str]]:
+ """Retrieve the current local membership state and event ID for a user in a room.
+
+ Args:
+ user_id: The ID of the user.
+ room_id: The ID of the room.
+
+ Returns:
+ A tuple of (membership_type, event_id). Both will be None if a
+ room_id/user_id pair is not found.
+ """
+ # Paranoia check.
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'get_local_current_membership_for_user_in_room' on "
+ "non-local user %s" % (user_id,),
+ )
+
+ results_dict = await self.db_pool.simple_select_one(
+ "local_current_membership",
+ {"room_id": room_id, "user_id": user_id},
+ ("membership", "event_id"),
+ allow_none=True,
+ desc="get_local_current_membership_for_user_in_room",
+ )
+ if not results_dict:
+ return None, None
+
+ return results_dict.get("membership"), results_dict.get("event_id")
+
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
diff --git a/synctl b/synctl
index 9395ebd048..cfa9cec0c4 100755
--- a/synctl
+++ b/synctl
@@ -358,6 +358,13 @@ def main():
for worker in workers:
env = os.environ.copy()
+ # Skip starting a worker if it's already running
+ if os.path.exists(worker.pidfile) and pid_running(
+ int(open(worker.pidfile).read())
+ ):
+ print(worker.app + " already running")
+ continue
+
if worker.cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b4fa02acc4..e880d32be6 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -89,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider):
return {"phone": userinfo["phone"]}
+class TestMappingProviderFailures(TestMappingProvider):
+ async def map_user_attributes(self, userinfo, token, failures):
+ return {
+ "localpart": userinfo["username"] + (str(failures) if failures else ""),
+ "display_name": None,
+ }
+
+
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
@@ -152,6 +160,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error = Mock(return_value=None)
self.handler._sso_handler.render_error = self.render_error
+ # Reduce the number of attempts when generating MXIDs.
+ self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
+
return hs
def metadata_edit(self, values):
@@ -693,7 +704,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
),
MappingException,
)
- self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+ self.assertEqual(
+ str(e.value),
+ "Could not extract user attributes from SSO response: Mapping provider does not support de-duplicating Matrix IDs",
+ )
@override_config({"oidc_config": {"allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
@@ -703,6 +717,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(
store.register_user(user_id=user.to_string(), password_hash=None)
)
+
+ # Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
@@ -715,6 +731,23 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.assertEqual(mxid, "@test_user:test")
+ # Note that a second SSO user can be mapped to the same Matrix ID. (This
+ # requires a unique sub, but something that maps to the same matrix ID,
+ # in this case we'll just use the same username. A more realistic example
+ # would be subs which are email addresses, and mapping from the localpart
+ # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
+ userinfo = {
+ "sub": "test1",
+ "username": "test_user",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
self.get_success(
@@ -762,6 +795,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "föö",
}
token = {}
+
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
@@ -769,3 +803,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderFailures"
+ }
+ }
+ }
+ )
+ def test_map_userinfo_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+ userinfo = {
+ "sub": "test",
+ "username": "test_user",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ # test_user is already taken, so test_user1 gets registered instead.
+ self.assertEqual(mxid, "@test_user1:test")
+
+ # Register all of the potential users for a particular username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+ )
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(
+ str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ )
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 1cac00ea37..4f6a912ac4 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -656,7 +656,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
return user_one, user_two, user_three, user_three_token
def expire(self, user_id_to_expire, admin_tok):
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
request_data = {
"user_id": user_id_to_expire,
"expiration_ts": 0,
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 898e43411e..4f76f8f768 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -100,7 +100,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
# Now delete the group
- url = "/admin/delete_group/" + group_id
+ url = "/_synapse/admin/v1/delete_group/" + group_id
request, channel = self.make_request(
"POST",
url.encode("ascii"),
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 54824a5410..46933a0493 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -78,7 +78,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
)
# Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
+ url = "/_synapse/admin/v1/shutdown_room/" + room_id
request, channel = self.make_request(
"POST",
url.encode("ascii"),
@@ -112,7 +112,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
+ url = "/_synapse/admin/v1/shutdown_room/" + room_id
request, channel = self.make_request(
"POST",
url.encode("ascii"),
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9661af7e79..54d46f4bd3 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -41,7 +41,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- self.url = "/_matrix/client/r0/admin/register"
+ self.url = "/_synapse/admin/v1/register"
self.registration_handler = Mock()
self.identity_handler = Mock()
@@ -1768,3 +1768,111 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# though the MAU limit would stop the user doing so.
puppet_token = self._get_token()
self.helper.join(room_id, user=self.other_user, tok=puppet_token)
+
+
+class WhoisRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user)
+ self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get information of a user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url1, b"{}")
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request("GET", self.url2, b"{}")
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self.register_user("user2", "pass")
+ other_user2_token = self.login("user2", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url1, access_token=other_user2_token,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "GET", self.url2, access_token=other_user2_token,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not local returns a 400
+ """
+ url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+ url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
+
+ request, channel = self.make_request(
+ "GET", url1, access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only whois a local user", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "GET", url2, access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only whois a local user", channel.json_body["error"])
+
+ def test_get_whois_admin(self):
+ """
+ The lookup should succeed for an admin.
+ """
+ request, channel = self.make_request(
+ "GET", self.url1, access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ request, channel = self.make_request(
+ "GET", self.url2, access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ def test_get_whois_user(self):
+ """
+ The lookup should succeed for a normal user looking up their own information.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url1, access_token=other_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ request, channel = self.make_request(
+ "GET", self.url2, access_token=other_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 86f4de1a6a..2272caa048 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -381,7 +381,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {"user_id": user_id}
request_data = json.dumps(params)
request, channel = self.make_request(
@@ -401,7 +401,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
@@ -428,7 +428,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
@@ -492,7 +492,7 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
# Ensure the admin never expires
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": admin_id,
"expiration_ts": 999999999999,
@@ -532,7 +532,7 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(replicated_content)
# Expire the user
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
@@ -565,7 +565,7 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertIsNone(replicated_content)
# Now renew the user, and check they get replicated again to the identity server
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 99999999999,
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 6bdde1a2ba..a69117c5a9 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -416,7 +416,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
self.reactor,
self.site,
"GET",
- "/_matrix/client/r0/admin/users/" + self.user_id,
+ "/_synapse/admin/v1/users/" + self.user_id,
access_token=access_token,
custom_headers=headers1.items(),
**make_request_args,
diff --git a/tests/unittest.py b/tests/unittest.py
index c7c889c405..a9d59e31f7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -554,7 +554,7 @@ class HomeserverTestCase(TestCase):
self.hs.config.registration_shared_secret = "shared"
# Create the user
- request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
+ request, channel = self.make_request("GET", "/_synapse/admin/v1/register")
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
@@ -580,7 +580,7 @@ class HomeserverTestCase(TestCase):
}
)
request, channel = self.make_request(
- "POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
+ "POST", "/_synapse/admin/v1/register", body.encode("utf8")
)
self.assertEqual(channel.code, 200, channel.json_body)
|