summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--INSTALL.md261
-rw-r--r--README.rst2
-rw-r--r--changelog.d/8856.misc1
-rw-r--r--changelog.d/8977.bugfix1
-rw-r--r--changelog.d/8980.misc1
-rw-r--r--changelog.d/8987.doc1
-rw-r--r--changelog.d/8998.misc1
-rw-r--r--changelog.d/8999.misc1
-rw-r--r--changelog.d/9002.doc1
-rw-r--r--mypy.ini12
-rw-r--r--synapse/crypto/context_factory.py2
-rw-r--r--synapse/crypto/event_signing.py29
-rw-r--r--synapse/crypto/keyring.py206
-rw-r--r--synapse/federation/transport/server.py2
-rw-r--r--synapse/handlers/cas_handler.py112
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/initial_sync.py4
-rw-r--r--synapse/handlers/sso.py4
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/rest/client/v2_alpha/groups.py48
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py9
-rw-r--r--synapse/storage/__init__.py9
-rw-r--r--synapse/storage/_base.py36
-rw-r--r--synapse/storage/background_updates.py111
-rw-r--r--synapse/storage/databases/main/keys.py10
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/storage/prepare_database.py104
-rw-r--r--synapse/storage/purge_events.py11
-rw-r--r--synapse/storage/relations.py44
-rw-r--r--synapse/storage/state.py35
-rw-r--r--tests/crypto/test_keyring.py10
-rw-r--r--tests/federation/transport/__init__.py0
-rw-r--r--tests/federation/transport/test_server.py39
-rw-r--r--tests/handlers/test_cas.py121
34 files changed, 791 insertions, 446 deletions
diff --git a/INSTALL.md b/INSTALL.md

index eb5f506de9..598ddceb8c 100644 --- a/INSTALL.md +++ b/INSTALL.md
@@ -1,19 +1,44 @@ -- [Choosing your server name](#choosing-your-server-name) -- [Picking a database engine](#picking-a-database-engine) -- [Installing Synapse](#installing-synapse) - - [Installing from source](#installing-from-source) - - [Platform-Specific Instructions](#platform-specific-instructions) - - [Prebuilt packages](#prebuilt-packages) -- [Setting up Synapse](#setting-up-synapse) - - [TLS certificates](#tls-certificates) - - [Client Well-Known URI](#client-well-known-uri) - - [Email](#email) - - [Registering a user](#registering-a-user) - - [Setting up a TURN server](#setting-up-a-turn-server) - - [URL previews](#url-previews) -- [Troubleshooting Installation](#troubleshooting-installation) - -# Choosing your server name +# Installation Instructions + +There are 3 steps to follow under **Installation Instructions**. + +- [Installation Instructions](#installation-instructions) + - [Choosing your server name](#choosing-your-server-name) + - [Installing Synapse](#installing-synapse) + - [Installing from source](#installing-from-source) + - [Platform-Specific Instructions](#platform-specific-instructions) + - [Debian/Ubuntu/Raspbian](#debianubunturaspbian) + - [ArchLinux](#archlinux) + - [CentOS/Fedora](#centosfedora) + - [macOS](#macos) + - [OpenSUSE](#opensuse) + - [OpenBSD](#openbsd) + - [Windows](#windows) + - [Prebuilt packages](#prebuilt-packages) + - [Docker images and Ansible playbooks](#docker-images-and-ansible-playbooks) + - [Debian/Ubuntu](#debianubuntu) + - [Matrix.org packages](#matrixorg-packages) + - [Downstream Debian packages](#downstream-debian-packages) + - [Downstream Ubuntu packages](#downstream-ubuntu-packages) + - [Fedora](#fedora) + - [OpenSUSE](#opensuse-1) + - [SUSE Linux Enterprise Server](#suse-linux-enterprise-server) + - [ArchLinux](#archlinux-1) + - [Void Linux](#void-linux) + - [FreeBSD](#freebsd) + - [OpenBSD](#openbsd-1) + - [NixOS](#nixos) + - [Setting up Synapse](#setting-up-synapse) + - [Using PostgreSQL](#using-postgresql) + - [TLS certificates](#tls-certificates) + - [Client Well-Known URI](#client-well-known-uri) + - [Email](#email) + - [Registering a user](#registering-a-user) + - [Setting up a TURN server](#setting-up-a-turn-server) + - [URL previews](#url-previews) + - [Troubleshooting Installation](#troubleshooting-installation) + +## Choosing your server name It is important to choose the name for your server before you install Synapse, because it cannot be changed later. @@ -29,28 +54,9 @@ that your email address is probably `user@example.com` rather than `user@email.example.com`) - but doing so may require more advanced setup: see [Setting up Federation](docs/federate.md). -# Picking a database engine +## Installing Synapse -Synapse offers two database engines: - * [PostgreSQL](https://www.postgresql.org) - * [SQLite](https://sqlite.org/) - -Almost all installations should opt to use PostgreSQL. Advantages include: - -* significant performance improvements due to the superior threading and - caching model, smarter query optimiser -* allowing the DB to be run on separate hardware - -For information on how to install and use PostgreSQL, please see -[docs/postgres.md](docs/postgres.md) - -By default Synapse uses SQLite and in doing so trades performance for convenience. -SQLite is only recommended in Synapse for testing purposes or for servers with -light workloads. - -# Installing Synapse - -## Installing from source +### Installing from source (Prebuilt packages are available for some platforms - see [Prebuilt packages](#prebuilt-packages).) @@ -68,7 +74,7 @@ these on various platforms. To install the Synapse homeserver run: -``` +```sh mkdir -p ~/synapse virtualenv -p python3 ~/synapse/env source ~/synapse/env/bin/activate @@ -85,7 +91,7 @@ prefer. This Synapse installation can then be later upgraded by using pip again with the update flag: -``` +```sh source ~/synapse/env/bin/activate pip install -U matrix-synapse ``` @@ -93,7 +99,7 @@ pip install -U matrix-synapse Before you can start Synapse, you will need to generate a configuration file. To do this, run (in your virtualenv, as before): -``` +```sh cd ~/synapse python -m synapse.app.homeserver \ --server-name my.domain.name \ @@ -111,45 +117,43 @@ wise to back them up somewhere safe. (If, for whatever reason, you do need to change your homeserver's keys, you may find that other homeserver have the old key cached. If you update the signing key, you should change the name of the key in the `<server name>.signing.key` file (the second word) to something -different. See the -[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) -for more information on key management). +different. See the [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) for more information on key management). To actually run your new homeserver, pick a working directory for Synapse to run (e.g. `~/synapse`), and: -``` +```sh cd ~/synapse source env/bin/activate synctl start ``` -### Platform-Specific Instructions +#### Platform-Specific Instructions -#### Debian/Ubuntu/Raspbian +##### Debian/Ubuntu/Raspbian Installing prerequisites on Ubuntu or Debian: -``` -sudo apt-get install build-essential python3-dev libffi-dev \ +```sh +sudo apt install build-essential python3-dev libffi-dev \ python3-pip python3-setuptools sqlite3 \ libssl-dev virtualenv libjpeg-dev libxslt1-dev ``` -#### ArchLinux +##### ArchLinux Installing prerequisites on ArchLinux: -``` +```sh sudo pacman -S base-devel python python-pip \ python-setuptools python-virtualenv sqlite3 ``` -#### CentOS/Fedora +##### CentOS/Fedora Installing prerequisites on CentOS 8 or Fedora>26: -``` +```sh sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ libwebp-devel tk-devel redhat-rpm-config \ python3-virtualenv libffi-devel openssl-devel @@ -158,7 +162,7 @@ sudo dnf groupinstall "Development Tools" Installing prerequisites on CentOS 7 or Fedora<=25: -``` +```sh sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \ python3-virtualenv libffi-devel openssl-devel @@ -170,11 +174,11 @@ uses SQLite 3.7. You may be able to work around this by installing a more recent SQLite version, but it is recommended that you instead use a Postgres database: see [docs/postgres.md](docs/postgres.md). -#### macOS +##### macOS Installing prerequisites on macOS: -``` +```sh xcode-select --install sudo easy_install pip sudo pip install virtualenv @@ -184,22 +188,22 @@ brew install pkg-config libffi On macOS Catalina (10.15) you may need to explicitly install OpenSSL via brew and inform `pip` about it so that `psycopg2` builds: -``` +```sh brew install openssl@1.1 export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/ ``` -#### OpenSUSE +##### OpenSUSE Installing prerequisites on openSUSE: -``` +```sh sudo zypper in -t pattern devel_basis sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \ python-devel libffi-devel libopenssl-devel libjpeg62-devel ``` -#### OpenBSD +##### OpenBSD A port of Synapse is available under `net/synapse`. The filesystem underlying the homeserver directory (defaults to `/var/synapse`) has to be @@ -213,73 +217,72 @@ mounted with `wxallowed` (cf. `mount(8)`). Creating a `WRKOBJDIR` for building python under `/usr/local` (which on a default OpenBSD installation is mounted with `wxallowed`): -``` +```sh doas mkdir /usr/local/pobj_wxallowed ``` Assuming `PORTS_PRIVSEP=Yes` (cf. `bsd.port.mk(5)`) and `SUDO=doas` are configured in `/etc/mk.conf`: -``` +```sh doas chown _pbuild:_pbuild /usr/local/pobj_wxallowed ``` Setting the `WRKOBJDIR` for building python: -``` +```sh echo WRKOBJDIR_lang/python/3.7=/usr/local/pobj_wxallowed \\nWRKOBJDIR_lang/python/2.7=/usr/local/pobj_wxallowed >> /etc/mk.conf ``` Building Synapse: -``` +```sh cd /usr/ports/net/synapse make install ``` -#### Windows +##### Windows If you wish to run or develop Synapse on Windows, the Windows Subsystem For Linux provides a Linux environment on Windows 10 which is capable of using the Debian, Fedora, or source installation methods. More information about WSL can -be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for -Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server +be found at <https://docs.microsoft.com/en-us/windows/wsl/install-win10> for +Windows 10 and <https://docs.microsoft.com/en-us/windows/wsl/install-on-server> for Windows Server. -## Prebuilt packages +### Prebuilt packages As an alternative to installing from source, prebuilt packages are available for a number of platforms. -### Docker images and Ansible playbooks +#### Docker images and Ansible playbooks There is an offical synapse image available at -https://hub.docker.com/r/matrixdotorg/synapse which can be used with +<https://hub.docker.com/r/matrixdotorg/synapse> which can be used with the docker-compose file available at [contrib/docker](contrib/docker). Further information on this including configuration options is available in the README on hub.docker.com. Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a Dockerfile to automate a synapse server in a single Docker image, at -https://hub.docker.com/r/avhost/docker-matrix/tags/ +<https://hub.docker.com/r/avhost/docker-matrix/tags/> Slavi Pantaleev has created an Ansible playbook, which installs the offical Docker image of Matrix Synapse along with many other Matrix-related services (Postgres database, Element, coturn, ma1sd, SSL support, etc.). For more details, see -https://github.com/spantaleev/matrix-docker-ansible-deploy - +<https://github.com/spantaleev/matrix-docker-ansible-deploy> -### Debian/Ubuntu +#### Debian/Ubuntu -#### Matrix.org packages +##### Matrix.org packages Matrix.org provides Debian/Ubuntu packages of the latest stable version of -Synapse via https://packages.matrix.org/debian/. They are available for Debian +Synapse via <https://packages.matrix.org/debian/>. They are available for Debian 9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them: -``` +```sh sudo apt install -y lsb-release wget apt-transport-https sudo wget -O /usr/share/keyrings/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg echo "deb [signed-by=/usr/share/keyrings/matrix-org-archive-keyring.gpg] https://packages.matrix.org/debian/ $(lsb_release -cs) main" | @@ -299,7 +302,7 @@ The fingerprint of the repository signing key (as shown by `gpg /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`. -#### Downstream Debian packages +##### Downstream Debian packages We do not recommend using the packages from the default Debian `buster` repository at this time, as they are old and suffer from known security @@ -311,49 +314,49 @@ for information on how to use backports. If you are using Debian `sid` or testing, Synapse is available in the default repositories and it should be possible to install it simply with: -``` +```sh sudo apt install matrix-synapse ``` -#### Downstream Ubuntu packages +##### Downstream Ubuntu packages We do not recommend using the packages in the default Ubuntu repository at this time, as they are old and suffer from known security vulnerabilities. The latest version of Synapse can be installed from [our repository](#matrixorg-packages). -### Fedora +#### Fedora Synapse is in the Fedora repositories as `matrix-synapse`: -``` +```sh sudo dnf install matrix-synapse ``` Oleg Girko provides Fedora RPMs at -https://obs.infoserver.lv/project/monitor/matrix-synapse +<https://obs.infoserver.lv/project/monitor/matrix-synapse> -### OpenSUSE +#### OpenSUSE Synapse is in the OpenSUSE repositories as `matrix-synapse`: -``` +```sh sudo zypper install matrix-synapse ``` -### SUSE Linux Enterprise Server +#### SUSE Linux Enterprise Server Unofficial package are built for SLES 15 in the openSUSE:Backports:SLE-15 repository at -https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/ +<https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/> -### ArchLinux +#### ArchLinux The quickest way to get up and running with ArchLinux is probably with the community package -https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in most of +<https://www.archlinux.org/packages/community/any/matrix-synapse/>, which should pull in most of the necessary dependencies. pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ): -``` +```sh sudo pip install --upgrade pip ``` @@ -362,28 +365,28 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly compile it under the right architecture. (This should not be needed if installing under virtualenv): -``` +```sh sudo pip uninstall py-bcrypt sudo pip install py-bcrypt ``` -### Void Linux +#### Void Linux Synapse can be found in the void repositories as 'synapse': -``` +```sh xbps-install -Su xbps-install -S synapse ``` -### FreeBSD +#### FreeBSD Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from: - - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean` - - Packages: `pkg install py37-matrix-synapse` +- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean` +- Packages: `pkg install py37-matrix-synapse` -### OpenBSD +#### OpenBSD As of OpenBSD 6.7 Synapse is available as a pre-compiled binary. The filesystem underlying the homeserver directory (defaults to `/var/synapse`) has to be @@ -392,20 +395,35 @@ and mounting it to `/var/synapse` should be taken into consideration. Installing Synapse: -``` +```sh doas pkg_add synapse ``` -### NixOS +#### NixOS Robin Lambertz has packaged Synapse for NixOS at: -https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix +<https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix> -# Setting up Synapse +## Setting up Synapse Once you have installed synapse as above, you will need to configure it. -## TLS certificates +### Using PostgreSQL + +By default Synapse uses [SQLite](https://sqlite.org/) and in doing so trades performance for convenience. +SQLite is only recommended in Synapse for testing purposes or for servers with +very light workloads. + +Almost all installations should opt to use [PostgreSQL](https://www.postgresql.org). Advantages include: + +- significant performance improvements due to the superior threading and + caching model, smarter query optimiser +- allowing the DB to be run on separate hardware + +For information on how to install and use PostgreSQL in Synapse, please see +[docs/postgres.md](docs/postgres.md) + +### TLS certificates The default configuration exposes a single HTTP port on the local interface: `http://localhost:8008`. It is suitable for local testing, @@ -419,19 +437,19 @@ The recommended way to do so is to set up a reverse proxy on port Alternatively, you can configure Synapse to expose an HTTPS port. To do so, you will need to edit `homeserver.yaml`, as follows: -* First, under the `listeners` section, uncomment the configuration for the +- First, under the `listeners` section, uncomment the configuration for the TLS-enabled listener. (Remove the hash sign (`#`) at the start of each line). The relevant lines are like this: - ``` - - port: 8448 - type: http - tls: true - resources: - - names: [client, federation] +```yaml + - port: 8448 + type: http + tls: true + resources: + - names: [client, federation] ``` -* You will also need to uncomment the `tls_certificate_path` and +- You will also need to uncomment the `tls_certificate_path` and `tls_private_key_path` lines under the `TLS` section. You will need to manage provisioning of these certificates yourself — Synapse had built-in ACME support, but the ACMEv1 protocol Synapse implements is deprecated, not @@ -446,7 +464,7 @@ so, you will need to edit `homeserver.yaml`, as follows: For a more detailed guide to configuring your server for federation, see [federate.md](docs/federate.md). -## Client Well-Known URI +### Client Well-Known URI Setting up the client Well-Known URI is optional but if you set it up, it will allow users to enter their full username (e.g. `@user:<server_name>`) into clients @@ -457,7 +475,7 @@ about the actual homeserver URL you are using. The URL `https://<server_name>/.well-known/matrix/client` should return JSON in the following format. -``` +```json { "m.homeserver": { "base_url": "https://<matrix.example.com>" @@ -467,7 +485,7 @@ the following format. It can optionally contain identity server information as well. -``` +```json { "m.homeserver": { "base_url": "https://<matrix.example.com>" @@ -484,7 +502,8 @@ Cross-Origin Resource Sharing (CORS) headers. A recommended value would be view it. In nginx this would be something like: -``` + +```nginx location /.well-known/matrix/client { return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}'; default_type application/json; @@ -497,11 +516,11 @@ correctly. `public_baseurl` should be set to the URL that clients will use to connect to your server. This is the same URL you put for the `m.homeserver` `base_url` above. -``` +```yaml public_baseurl: "https://<matrix.example.com>" ``` -## Email +### Email It is desirable for Synapse to have the capability to send email. This allows Synapse to send password reset emails, send verifications when an email address @@ -516,7 +535,7 @@ and `notif_from` fields filled out. You may also need to set `smtp_user`, If email is not configured, password reset, registration and notifications via email will be disabled. -## Registering a user +### Registering a user The easiest way to create a new user is to do so from a client like [Element](https://element.io/). @@ -524,7 +543,7 @@ Alternatively you can do so from the command line if you have installed via pip. This can be done as follows: -``` +```sh $ source ~/synapse/env/bin/activate $ synctl start # if not already running $ register_new_matrix_user -c homeserver.yaml http://localhost:8008 @@ -542,12 +561,12 @@ value is generated by `--generate-config`), but it should be kept secret, as anyone with knowledge of it can register users, including admin accounts, on your server even if `enable_registration` is `false`. -## Setting up a TURN server +### Setting up a TURN server For reliable VoIP calls to be routed via this homeserver, you MUST configure a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details. -## URL previews +### URL previews Synapse includes support for previewing URLs, which is disabled by default. To turn it on you must enable the `url_preview_enabled: True` config parameter @@ -561,14 +580,14 @@ This also requires the optional `lxml` python dependency to be installed. This in turn requires the `libxml2` library to be available - on Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for your OS. -# Troubleshooting Installation +### Troubleshooting Installation `pip` seems to leak *lots* of memory during installation. For instance, a Linux host with 512MB of RAM may run out of memory whilst installing Twisted. If this happens, you will have to individually install the dependencies which are failing, e.g.: -``` +```sh pip install twisted ``` diff --git a/README.rst b/README.rst
index d724cf97da..31ae5cc578 100644 --- a/README.rst +++ b/README.rst
@@ -243,6 +243,8 @@ Then update the ``users`` table in the database:: Synapse Development =================== +Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org) + Before setting up a development environment for synapse, make sure you have the system dependencies (such as the python header files) installed - see `Installing from source <INSTALL.md#installing-from-source>`_. diff --git a/changelog.d/8856.misc b/changelog.d/8856.misc new file mode 100644
index 0000000000..1507073e4f --- /dev/null +++ b/changelog.d/8856.misc
@@ -0,0 +1 @@ +Properly store the mapping of external ID to Matrix ID for CAS users. diff --git a/changelog.d/8977.bugfix b/changelog.d/8977.bugfix new file mode 100644
index 0000000000..ae0b6bec14 --- /dev/null +++ b/changelog.d/8977.bugfix
@@ -0,0 +1 @@ +Properly return 400 errors on invalid group IDs. diff --git a/changelog.d/8980.misc b/changelog.d/8980.misc new file mode 100644
index 0000000000..83ef3c5def --- /dev/null +++ b/changelog.d/8980.misc
@@ -0,0 +1 @@ +Add type hints to the base storage code. diff --git a/changelog.d/8987.doc b/changelog.d/8987.doc new file mode 100644
index 0000000000..c6e4932729 --- /dev/null +++ b/changelog.d/8987.doc
@@ -0,0 +1 @@ +Moved instructions for database setup, adjusted heading levels and improved syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by fossterer. diff --git a/changelog.d/8998.misc b/changelog.d/8998.misc new file mode 100644
index 0000000000..81346694bd --- /dev/null +++ b/changelog.d/8998.misc
@@ -0,0 +1 @@ +Fix `tests.federation.transport.RoomDirectoryFederationTests` and ensure it runs in CI. \ No newline at end of file diff --git a/changelog.d/8999.misc b/changelog.d/8999.misc new file mode 100644
index 0000000000..3987204f06 --- /dev/null +++ b/changelog.d/8999.misc
@@ -0,0 +1 @@ +Add type hints to the crypto module. diff --git a/changelog.d/9002.doc b/changelog.d/9002.doc new file mode 100644
index 0000000000..26928c9a93 --- /dev/null +++ b/changelog.d/9002.doc
@@ -0,0 +1 @@ +Link the Synapse developer room to the development section in the docs. diff --git a/mypy.ini b/mypy.ini
index 1e88909d46..6a53abfaa9 100644 --- a/mypy.ini +++ b/mypy.ini
@@ -17,6 +17,7 @@ files = synapse/api, synapse/appservice, synapse/config, + synapse/crypto, synapse/event_auth.py, synapse/events/builder.py, synapse/events/validator.py, @@ -70,16 +71,27 @@ files = synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/__init__.py, + synapse/storage/_base.py, + synapse/storage/background_updates.py, synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/keys.py, synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, synapse/storage/engines, + synapse/storage/keys.py, synapse/storage/persist_events.py, + synapse/storage/prepare_database.py, + synapse/storage/purge_events.py, + synapse/storage/push_rule.py, + synapse/storage/relations.py, + synapse/storage/roommember.py, synapse/storage/state.py, + synapse/storage/types.py, synapse/storage/util, synapse/streams, synapse/types.py, diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 57fd426e87..74b67b230a 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py
@@ -227,7 +227,7 @@ class ConnectionVerifier: # This code is based on twisted.internet.ssl.ClientTLSOptions. - def __init__(self, hostname: bytes, verify_certs): + def __init__(self, hostname: bytes, verify_certs: bool): self._verify_certs = verify_certs _decoded = hostname.decode("ascii") diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 0422c43fab..8fb116ae18 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py
@@ -18,7 +18,7 @@ import collections.abc import hashlib import logging -from typing import Dict +from typing import Any, Callable, Dict, Tuple from canonicaljson import encode_canonical_json from signedjson.sign import sign_json @@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion +from synapse.events import EventBase from synapse.events.utils import prune_event, prune_event_dict from synapse.types import JsonDict logger = logging.getLogger(__name__) +Hasher = Callable[[bytes], "hashlib._Hash"] -def check_event_content_hash(event, hash_algorithm=hashlib.sha256): + +def check_event_content_hash( + event: EventBase, hash_algorithm: Hasher = hashlib.sha256 +) -> bool: """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm) logger.debug( @@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): return message_hash_bytes == expected_hash -def compute_content_hash(event_dict, hash_algorithm): +def compute_content_hash( + event_dict: Dict[str, Any], hash_algorithm: Hasher +) -> Tuple[str, bytes]: """Compute the content hash of an event, which is the hash of the unredacted event. Args: - event_dict (dict): The unredacted event as a dict + event_dict: The unredacted event as a dict hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use to hash the event Returns: - tuple[str, bytes]: A tuple of the name of hash and the hash as raw - bytes. + A tuple of the name of hash and the hash as raw bytes. """ event_dict = dict(event_dict) event_dict.pop("age_ts", None) @@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm): return hashed.name, hashed.digest() -def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): +def compute_event_reference_hash( + event, hash_algorithm: Hasher = hashlib.sha256 +) -> Tuple[str, bytes]: """Computes the event reference hash. This is the hash of the redacted event. Args: - event (FrozenEvent) + event hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use to hash the event Returns: - tuple[str, bytes]: A tuple of the name of hash and the hash as raw - bytes. + A tuple of the name of hash and the hash as raw bytes. """ tmp_event = prune_event(event) event_dict = tmp_event.get_pdu_json() @@ -156,7 +163,7 @@ def add_hashes_and_signatures( event_dict: JsonDict, signature_name: str, signing_key: SigningKey, -): +) -> None: """Add content hash and sign the event Args: diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f23eacc0d7..902128a23c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py
@@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging import urllib from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from signedjson.key import ( @@ -40,6 +42,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.config.key import TrustedKeyServer from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, @@ -47,11 +50,15 @@ from synapse.logging.context import ( run_in_background, ) from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.async_helpers import yieldable_gather_results from synapse.util.metrics import Measure from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -61,16 +68,17 @@ class VerifyJsonRequest: A request to verify a JSON object. Attributes: - server_name(str): The name of the server to verify against. - - key_ids(set[str]): The set of key_ids to that could be used to verify the - JSON object + server_name: The name of the server to verify against. - json_object(dict): The JSON object to verify. + json_object: The JSON object to verify. - minimum_valid_until_ts (int): time at which we require the signing key to + minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) + request_name: The name of the request. + + key_ids: The set of key_ids to that could be used to verify the JSON object + key_ready (Deferred[str, str, nacl.signing.VerifyKey]): A deferred (server_name, key_id, verify_key) tuple that resolves when a verify key has been fetched. The deferreds' callbacks are run with no @@ -80,12 +88,12 @@ class VerifyJsonRequest: errbacks with an M_UNAUTHORIZED SynapseError. """ - server_name = attr.ib() - json_object = attr.ib() - minimum_valid_until_ts = attr.ib() - request_name = attr.ib() - key_ids = attr.ib(init=False) - key_ready = attr.ib(default=attr.Factory(defer.Deferred)) + server_name = attr.ib(type=str) + json_object = attr.ib(type=JsonDict) + minimum_valid_until_ts = attr.ib(type=int) + request_name = attr.ib(type=str) + key_ids = attr.ib(init=False, type=List[str]) + key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) def __attrs_post_init__(self): self.key_ids = signature_ids(self.json_object, self.server_name) @@ -96,7 +104,9 @@ class KeyLookupError(ValueError): class Keyring: - def __init__(self, hs, key_fetchers=None): + def __init__( + self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None + ): self.clock = hs.get_clock() if key_fetchers is None: @@ -112,22 +122,26 @@ class Keyring: # completes. # # These are regular, logcontext-agnostic Deferreds. - self.key_downloads = {} + self.key_downloads = {} # type: Dict[str, defer.Deferred] def verify_json_for_server( - self, server_name, json_object, validity_time, request_name - ): + self, + server_name: str, + json_object: JsonDict, + validity_time: int, + request_name: str, + ) -> defer.Deferred: """Verify that a JSON object has been signed by a given server Args: - server_name (str): name of the server which must have signed this object + server_name: name of the server which must have signed this object - json_object (dict): object to be checked + json_object: object to be checked - validity_time (int): timestamp at which we require the signing key to + validity_time: timestamp at which we require the signing key to be valid. (0 implies we don't care) - request_name (str): an identifier for this json object (eg, an event id) + request_name: an identifier for this json object (eg, an event id) for logging. Returns: @@ -138,12 +152,14 @@ class Keyring: requests = (req,) return make_deferred_yieldable(self._verify_objects(requests)[0]) - def verify_json_objects_for_server(self, server_and_json): + def verify_json_objects_for_server( + self, server_and_json: Iterable[Tuple[str, dict, int, str]] + ) -> List[defer.Deferred]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. Args: - server_and_json (iterable[Tuple[str, dict, int, str]): + server_and_json: Iterable of (server_name, json_object, validity_time, request_name) tuples. @@ -164,13 +180,14 @@ class Keyring: for server_name, json_object, validity_time, request_name in server_and_json ) - def _verify_objects(self, verify_requests): + def _verify_objects( + self, verify_requests: Iterable[VerifyJsonRequest] + ) -> List[defer.Deferred]: """Does the work of verify_json_[objects_]for_server Args: - verify_requests (iterable[VerifyJsonRequest]): - Iterable of verification requests. + verify_requests: Iterable of verification requests. Returns: List<Deferred[None]>: for each input item, a deferred indicating success @@ -182,7 +199,7 @@ class Keyring: key_lookups = [] handle = preserve_fn(_handle_key_deferred) - def process(verify_request): + def process(verify_request: VerifyJsonRequest) -> defer.Deferred: """Process an entry in the request list Adds a key request to key_lookups, and returns a deferred which @@ -222,18 +239,20 @@ class Keyring: return results - async def _start_key_lookups(self, verify_requests): + async def _start_key_lookups( + self, verify_requests: List[VerifyJsonRequest] + ) -> None: """Sets off the key fetches for each verify request Once each fetch completes, verify_request.key_ready will be resolved. Args: - verify_requests (List[VerifyJsonRequest]): + verify_requests: """ try: # map from server name to a set of outstanding request ids - server_to_request_ids = {} + server_to_request_ids = {} # type: Dict[str, Set[int]] for verify_request in verify_requests: server_name = verify_request.server_name @@ -275,11 +294,11 @@ class Keyring: except Exception: logger.exception("Error starting key lookups") - async def wait_for_previous_lookups(self, server_names) -> None: + async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None: """Waits for any previous key lookups for the given servers to finish. Args: - server_names (Iterable[str]): list of servers which we want to look up + server_names: list of servers which we want to look up Returns: Resolves once all key lookups for the given servers have @@ -304,7 +323,7 @@ class Keyring: loop_count += 1 - def _get_server_verify_keys(self, verify_requests): + def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None: """Tries to find at least one key for each verify request For each verify_request, verify_request.key_ready is called back with @@ -312,7 +331,7 @@ class Keyring: with a SynapseError if none of the keys are found. Args: - verify_requests (list[VerifyJsonRequest]): list of verify requests + verify_requests: list of verify requests """ remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} @@ -366,17 +385,19 @@ class Keyring: run_in_background(do_iterations) - async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): + async def _attempt_key_fetches_with_fetcher( + self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest] + ): """Use a key fetcher to attempt to satisfy some key requests Args: - fetcher (KeyFetcher): fetcher to use to fetch the keys - remaining_requests (set[VerifyJsonRequest]): outstanding key requests. + fetcher: fetcher to use to fetch the keys + remaining_requests: outstanding key requests. Any successfully-completed requests will be removed from the list. """ - # dict[str, dict[str, int]]: keys to fetch. + # The keys to fetch. # server_name -> key_id -> min_valid_ts - missing_keys = defaultdict(dict) + missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]] for verify_request in remaining_requests: # any completed requests should already have been removed @@ -438,16 +459,18 @@ class Keyring: remaining_requests.difference_update(completed) -class KeyFetcher: - async def get_keys(self, keys_to_fetch): +class KeyFetcher(metaclass=abc.ABCMeta): + @abc.abstractmethod + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, dict[str, int]]): + keys_to_fetch: the keys to be fetched. server_name -> key_id -> min_valid_ts Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: - map from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult """ raise NotImplementedError @@ -455,31 +478,35 @@ class KeyFetcher: class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """see KeyFetcher.get_keys""" - keys_to_fetch = ( + key_ids_to_fetch = ( (server_name, key_id) for server_name, keys_for_server in keys_to_fetch.items() for key_id in keys_for_server.keys() ) - res = await self.store.get_server_verify_keys(keys_to_fetch) - keys = {} + res = await self.store.get_server_verify_keys(key_ids_to_fetch) + keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key return keys -class BaseV2KeyFetcher: - def __init__(self, hs): +class BaseV2KeyFetcher(KeyFetcher): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.config = hs.get_config() - async def process_v2_response(self, from_server, response_json, time_added_ms): + async def process_v2_response( + self, from_server: str, response_json: JsonDict, time_added_ms: int + ) -> Dict[str, FetchKeyResult]: """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -493,16 +520,16 @@ class BaseV2KeyFetcher: to /_matrix/key/v2/query. Args: - from_server (str): the name of the server producing this result: either + from_server: the name of the server producing this result: either the origin server for a /_matrix/key/v2/server request, or the notary for a /_matrix/key/v2/query. - response_json (dict): the json-decoded Server Keys response object + response_json: the json-decoded Server Keys response object - time_added_ms (int): the timestamp to record in server_keys_json + time_added_ms: the timestamp to record in server_keys_json Returns: - Deferred[dict[str, FetchKeyResult]]: map from key_id to result object + Map from key_id to result object """ ts_valid_until_ms = response_json["valid_until_ts"] @@ -575,21 +602,22 @@ class BaseV2KeyFetcher: class PerspectivesKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the "perspectives" servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """see KeyFetcher.get_keys""" - async def get_key(key_server): + async def get_key(key_server: TrustedKeyServer) -> Dict: try: - result = await self.get_server_verify_key_v2_indirect( + return await self.get_server_verify_key_v2_indirect( keys_to_fetch, key_server ) - return result except KeyLookupError as e: logger.warning( "Key lookup failed from %r: %s", key_server.server_name, e @@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ).addErrback(unwrapFirstError) ) - union_of_keys = {} + union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] for result in results: for server_name, keys in result.items(): union_of_keys.setdefault(server_name, {}).update(keys) return union_of_keys - async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): + async def get_server_verify_key_v2_indirect( + self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, dict[str, int]]): + keys_to_fetch: the keys to be fetched. server_name -> key_id -> min_valid_ts - key_server (synapse.config.key.TrustedKeyServer): notary server to query for - the keys + key_server: notary server to query for the keys Returns: - dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map - from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult Raises: KeyLookupError if there was an error processing the entire response from @@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) - keys = {} - added_keys = [] + keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] + added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]] time_now_ms = self.clock.time_msec() + assert isinstance(query_response, dict) for response in query_response["server_keys"]: # do this first, so that we can give useful errors thereafter server_name = response.get("server_name") @@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return keys - def _validate_perspectives_response(self, key_server, response): + def _validate_perspectives_response( + self, key_server: TrustedKeyServer, response: JsonDict + ) -> None: """Optionally check the signature on the result of a /key/query request Args: - key_server (synapse.config.key.TrustedKeyServer): the notary server that - produced this result + key_server: the notary server that produced this result - response (dict): the json-decoded Server Keys response object + response: the json-decoded Server Keys response object """ perspective_name = key_server.server_name perspective_keys = key_server.verify_keys @@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): class ServerKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the origin servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_federation_http_client() - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, iterable[str]]): + keys_to_fetch: the keys to be fetched. server_name -> key_ids Returns: - dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: - map from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult """ results = {} - async def get_key(key_to_fetch_item): + async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: server_name, key_ids = key_to_fetch_item try: keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) @@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher): await yieldable_gather_results(get_key, keys_to_fetch.items()) return results - async def get_server_verify_key_v2_direct(self, server_name, key_ids): + async def get_server_verify_key_v2_direct( + self, server_name: str, key_ids: Iterable[str] + ) -> Dict[str, FetchKeyResult]: """ Args: - server_name (str): - key_ids (iterable[str]): + server_name: + key_ids: Returns: - dict[str, FetchKeyResult]: map from key ID to lookup result + Map from key ID to lookup result Raises: KeyLookupError if there was a problem making the lookup """ - keys = {} # type: dict[str, FetchKeyResult] + keys = {} # type: Dict[str, FetchKeyResult] for requested_key_id in key_ids: # we may have found this key as a side-effect of asking for another. @@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) + assert isinstance(response, dict) if response["server_name"] != server_name: raise KeyLookupError( "Expected a response for server %r not %r" @@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher): return keys -async def _handle_key_deferred(verify_request) -> None: +async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: """Waits for the key to become available, and then performs a verification Args: - verify_request (VerifyJsonRequest): + verify_request: Raises: SynapseError if there was a problem performing the verification diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2415e08736..6e4a443787 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -146,7 +146,7 @@ class Authenticator: ): raise FederationDeniedError(origin) - if not json_request["signatures"]: + if origin is None or not json_request["signatures"]: raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index e9891e1316..fca210a5a6 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py
@@ -22,6 +22,7 @@ import attr from twisted.web.client import PartialDownloadError from synapse.api.errors import HttpResponseException +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart @@ -62,6 +63,7 @@ class CasHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname + self._store = hs.get_datastore() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -72,6 +74,9 @@ class CasHandler: self._http_client = hs.get_proxied_http_client() + # identifier for the external_ids table + self._auth_provider_id = "cas" + self._sso_handler = hs.get_sso_handler() def _build_service_param(self, args: Dict[str, str]) -> str: @@ -267,6 +272,14 @@ class CasHandler: This should be the UI Auth session id. """ + # first check if we're doing a UIA + if session: + return await self._sso_handler.complete_sso_ui_auth_request( + self._auth_provider_id, cas_response.username, session, request, + ) + + # otherwise, we're handling a login request. + # Ensure that the attributes of the logged in user meet the required # attributes. for required_attribute, required_value in self._cas_required_attributes.items(): @@ -293,54 +306,79 @@ class CasHandler: ) return - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - - # Get the matrix ID from the CAS username. - user_id = await self._map_cas_user_to_matrix_user( - cas_response, user_agent, ip_address - ) + # Call the mapper to register/login the user - if session: - await self._auth_handler.complete_sso_ui_auth( - user_id, session, request, - ) - else: - # If this not a UI auth request than there must be a redirect URL. - assert client_redirect_url + # If this not a UI auth request than there must be a redirect URL. + assert client_redirect_url is not None - await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url - ) + try: + await self._complete_cas_login(cas_response, request, client_redirect_url) + except MappingException as e: + logger.exception("Could not map user") + self._sso_handler.render_error(request, "mapping_error", str(e)) - async def _map_cas_user_to_matrix_user( - self, cas_response: CasResponse, user_agent: str, ip_address: str, - ) -> str: + async def _complete_cas_login( + self, + cas_response: CasResponse, + request: SynapseRequest, + client_redirect_url: str, + ) -> None: """ - Given a CAS username, retrieve the user ID for it and possibly register the user. + Given a CAS response, complete the login flow + + Retrieves the remote user ID, registers the user if necessary, and serves + a redirect back to the client with a login-token. Args: cas_response: The parsed CAS response. - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. + request: The request to respond to + client_redirect_url: The redirect URL passed in by the client. - Returns: - The user ID associated with this response. + Raises: + MappingException if there was a problem mapping the response to a user. + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. """ - + # Note that CAS does not support a mapping provider, so the logic is hard-coded. localpart = map_username_to_mxid_localpart(cas_response.username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) - displayname = cas_response.attributes.get(self._cas_displayname_attribute, None) + async def cas_response_to_user_attributes(failures: int) -> UserAttributes: + """ + Map from CAS attributes to user attributes. + """ + # Due to the grandfathering logic matching any previously registered + # mxids it isn't expected for there to be any failures. + if failures: + raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") + + display_name = cas_response.attributes.get( + self._cas_displayname_attribute, None + ) + + return UserAttributes(localpart=localpart, display_name=display_name) - # If the user does not exist, register it. - if not registered_user_id: - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=displayname, - user_agent_ips=[(user_agent, ip_address)], + async def grandfather_existing_users() -> Optional[str]: + # Since CAS did not always use the user_external_ids table, always + # to attempt to map to existing users. + user_id = UserID(localpart, self._hostname).to_string() + + logger.debug( + "Looking for existing account based on mapped %s", user_id, ) - return registered_user_id + users = await self._store.get_users_by_id_case_insensitive(user_id) + if users: + registered_user_id = list(users.keys())[0] + logger.info("Grandfathering mapping to %s", registered_user_id) + return registered_user_id + + return None + + await self._sso_handler.complete_sso_login_request( + self._auth_provider_id, + cas_response.username, + request, + client_redirect_url, + cas_response_to_user_attributes, + grandfather_existing_users, + ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index abd8d2af44..df29edeb83 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py
@@ -29,7 +29,7 @@ def _create_rerouter(func_name): async def f(self, group_id, *args, **kwargs): if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s was not legal group ID" % (group_id,)) + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) if self.is_mine_id(group_id): return await getattr(self.groups_server_handler, func_name)( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index cb11754bf8..fbd8df9dcc 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler): member_event_id: str, is_peeking: bool, ) -> JsonDict: - room_state = await self.state_store.get_state_for_events([member_event_id]) - - room_state = room_state[member_event_id] + room_state = await self.state_store.get_state_for_event(member_event_id) limit = pagin_config.limit if pagin_config else None if limit is None: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index b0a8c8c7d2..33cd6bc178 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -173,7 +173,7 @@ class SsoHandler: request: SynapseRequest, client_redirect_url: str, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], - grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]], + grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, ) -> None: """ @@ -241,7 +241,7 @@ class SsoHandler: ) # Check for grandfathering of users. - if not user_id and grandfather_existing_users: + if not user_id: user_id = await grandfather_existing_users() if user_id: # Future logins should also match this user ID. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index fe03004a01..e8947e0f9b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -568,7 +568,7 @@ class SyncHandler: event.event_id, state_filter=state_filter ) if event.is_state(): - state_ids = state_ids.copy() + state_ids = dict(state_ids) state_ids[(event.type, event.state_key)] = event.event_id return state_ids diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index a3bb095c2d..5b5da71815 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@ # limitations under the License. import logging +from functools import wraps from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -25,6 +26,22 @@ from ._base import client_patterns logger = logging.getLogger(__name__) +def _validate_group_id(f): + """Wrapper to validate the form of the group ID. + + Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. + """ + + @wraps(f) + def wrapper(self, request, group_id, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) + + return f(self, request, group_id, *args, **kwargs) + + return wrapper + + class GroupServlet(RestServlet): """Get the group profile """ @@ -37,6 +54,7 @@ class GroupServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -47,6 +65,7 @@ class GroupServlet(RestServlet): return 200, group_description + @_validate_group_id async def on_POST(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet): return 200, category + @_validate_group_id async def on_PUT(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet): return 200, category + @_validate_group_id async def on_PUT(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s was not legal group ID" % (group_id,)) - result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet): return 200, result + @_validate_group_id async def on_DELETE(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, room_id, config_key): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id + @_validate_group_id async def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -565,6 +603,7 @@ class GroupSelfLeaveServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -589,6 +628,7 @@ class GroupSelfJoinServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -613,6 +653,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -637,6 +678,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f843f02454..c57ac22e58 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Set +from typing import Dict from signedjson.sign import sign_json @@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() - cache_misses = {} # type: Dict[str, Set[str]] + # Note that the value is unused. + cache_misses = {} # type: Dict[str, Dict[str, int]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 continue if key_id is not None: @@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource): ) if miss: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index bbff3c8d5b..c0d9d1240f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ +from typing import TYPE_CHECKING from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore @@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.state import StateGroupStorage -__all__ = ["DataStores", "DataStore"] +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + + +__all__ = ["Databases", "DataStore"] class Storage: """The high level interfaces for talking to various storage layers. """ - def __init__(self, hs, stores: Databases): + def __init__(self, hs: "HomeServer", stores: Databases): # We include the main data store here mainly so that we don't have to # rewrite all the existing code to split it into high vs low level # interfaces. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2b196ded1b..a25c4093bc 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -17,14 +17,18 @@ import logging import random from abc import ABCMeta -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Iterable, Optional, Union from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool -from synapse.types import Collection, get_domain_from_id +from synapse.storage.types import Connection +from synapse.types import Collection, StreamToken, get_domain_from_id from synapse.util import json_decoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self.db_pool = database self.rand = random.SystemRandom() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: StreamToken, + rows: Iterable[Any], + ) -> None: pass - def _invalidate_state_caches(self, room_id, members_changed): + def _invalidate_state_caches( + self, room_id: str, members_changed: Iterable[str] + ) -> None: """Invalidates caches that are based on the current state, but does not stream invalidations down replication. Args: - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have - changed + room_id: Room where state changed + members_changed: The user_ids of members that have changed """ for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) @@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta): def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] - ): + ) -> None: """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. @@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta): cache.invalidate(tuple(key)) -def db_to_json(db_content): +def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: """ Take some data from a database row and return a JSON-decoded object. Args: - db_content (memoryview|buffer|bytes|bytearray|unicode) + db_content: The JSON-encoded contents from the database. + + Returns: + The object decoded from JSON. """ # psycopg2 on Python 3 returns memoryview objects, which we need to # cast to bytes to decode diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721ebe9..29b8ca676a 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
@@ -12,29 +12,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.types import Connection +from synapse.types import JsonDict from synapse.util import json_encoder from . import engines +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.storage.database import DatabasePool, LoggingTransaction + logger = logging.getLogger(__name__) class BackgroundUpdatePerformance: """Tracks the how long a background update is taking to update its items""" - def __init__(self, name): + def __init__(self, name: str): self.name = name self.total_item_count = 0 - self.total_duration_ms = 0 - self.avg_item_count = 0 - self.avg_duration_ms = 0 + self.total_duration_ms = 0.0 + self.avg_item_count = 0.0 + self.avg_duration_ms = 0.0 - def update(self, item_count, duration_ms): + def update(self, item_count: int, duration_ms: float) -> None: """Update the stats after doing an update""" self.total_item_count += item_count self.total_duration_ms += duration_ms @@ -44,7 +49,7 @@ class BackgroundUpdatePerformance: self.avg_item_count += 0.1 * (item_count - self.avg_item_count) self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms) - def average_items_per_ms(self): + def average_items_per_ms(self) -> Optional[float]: """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -58,7 +63,7 @@ class BackgroundUpdatePerformance: # changes in how long the update process takes. return float(self.avg_item_count) / float(self.avg_duration_ms) - def total_items_per_ms(self): + def total_items_per_ms(self) -> Optional[float]: """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -83,21 +88,25 @@ class BackgroundUpdater: BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, hs, database): + def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database # if a background update is currently running, its name. self._current_background_update = None # type: Optional[str] - self._background_update_performance = {} - self._background_update_handlers = {} + self._background_update_performance = ( + {} + ) # type: Dict[str, BackgroundUpdatePerformance] + self._background_update_handlers = ( + {} + ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]] self._all_done = False - def start_doing_background_updates(self): + def start_doing_background_updates(self) -> None: run_as_background_process("background_updates", self.run_background_updates) - async def run_background_updates(self, sleep=True): + async def run_background_updates(self, sleep: bool = True) -> None: logger.info("Starting background schema updates") while True: if sleep: @@ -148,7 +157,7 @@ class BackgroundUpdater: return False - async def has_completed_background_update(self, update_name) -> bool: + async def has_completed_background_update(self, update_name: str) -> bool: """Check if the given background update has finished running. """ if self._all_done: @@ -173,8 +182,7 @@ class BackgroundUpdater: Returns once some amount of work is done. Args: - desired_duration_ms(float): How long we want to spend - updating. + desired_duration_ms: How long we want to spend updating. Returns: True if we have finished running all the background updates, otherwise False """ @@ -220,6 +228,7 @@ class BackgroundUpdater: return False async def _do_background_update(self, desired_duration_ms: float) -> int: + assert self._current_background_update is not None update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) @@ -273,7 +282,11 @@ class BackgroundUpdater: return len(self._background_update_performance) - def register_background_update_handler(self, update_name, update_handler): + def register_background_update_handler( + self, + update_name: str, + update_handler: Callable[[JsonDict, int], Awaitable[int]], + ): """Register a handler for doing a background update. The handler should take two arguments: @@ -287,12 +300,12 @@ class BackgroundUpdater: The handler is responsible for updating the progress of the update. Args: - update_name(str): The name of the update that this code handles. - update_handler(function): The function that does the update. + update_name: The name of the update that this code handles. + update_handler: The function that does the update. """ self._background_update_handlers[update_name] = update_handler - def register_noop_background_update(self, update_name): + def register_noop_background_update(self, update_name: str) -> None: """Register a noop handler for a background update. This is useful when we previously did a background update, but no @@ -302,10 +315,10 @@ class BackgroundUpdater: also be called to clear the update. Args: - update_name (str): Name of update + update_name: Name of update """ - async def noop_update(progress, batch_size): + async def noop_update(progress: JsonDict, batch_size: int) -> int: await self._end_background_update(update_name) return 1 @@ -313,14 +326,14 @@ class BackgroundUpdater: def register_background_index_update( self, - update_name, - index_name, - table, - columns, - where_clause=None, - unique=False, - psql_only=False, - ): + update_name: str, + index_name: str, + table: str, + columns: Iterable[str], + where_clause: Optional[str] = None, + unique: bool = False, + psql_only: bool = False, + ) -> None: """Helper for store classes to do a background index addition To use: @@ -332,19 +345,19 @@ class BackgroundUpdater: 2. In the Store constructor, call this method Args: - update_name (str): update_name to register for - index_name (str): name of index to add - table (str): table to add index to - columns (list[str]): columns/expressions to include in index - unique (bool): true to make a UNIQUE index + update_name: update_name to register for + index_name: name of index to add + table: table to add index to + columns: columns/expressions to include in index + unique: true to make a UNIQUE index psql_only: true to only create this index on psql databases (useful for virtual sqlite tables) """ - def create_index_psql(conn): + def create_index_psql(conn: Connection) -> None: conn.rollback() # postgres insists on autocommit for the index - conn.set_session(autocommit=True) + conn.set_session(autocommit=True) # type: ignore try: c = conn.cursor() @@ -371,9 +384,9 @@ class BackgroundUpdater: logger.debug("[SQL] %s", sql) c.execute(sql) finally: - conn.set_session(autocommit=False) + conn.set_session(autocommit=False) # type: ignore - def create_index_sqlite(conn): + def create_index_sqlite(conn: Connection) -> None: # Sqlite doesn't support concurrent creation of indexes. # # We don't use partial indices on SQLite as it wasn't introduced @@ -399,7 +412,7 @@ class BackgroundUpdater: c.execute(sql) if isinstance(self.db_pool.engine, engines.PostgresEngine): - runner = create_index_psql + runner = create_index_psql # type: Optional[Callable[[Connection], None]] elif psql_only: runner = None else: @@ -433,7 +446,9 @@ class BackgroundUpdater: "background_updates", keyvalues={"update_name": update_name} ) - async def _background_update_progress(self, update_name: str, progress: dict): + async def _background_update_progress( + self, update_name: str, progress: dict + ) -> None: """Update the progress of a background update Args: @@ -441,20 +456,22 @@ class BackgroundUpdater: progress: The progress of the update. """ - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "background_update_progress", self._background_update_progress_txn, update_name, progress, ) - def _background_update_progress_txn(self, txn, update_name, progress): + def _background_update_progress_txn( + self, txn: "LoggingTransaction", update_name: str, progress: JsonDict + ) -> None: """Update the progress of a background update Args: - txn(cursor): The transaction. - update_name(str): The name of the background update task - progress(dict): The progress of the update. + txn: The transaction. + update_name: The name of the background update task + progress: The progress of the update. """ progress_json = json_encoder.encode(progress) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index f8f4bb9b3f..04ac2d0ced 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore from synapse.storage.keys import FetchKeyResult +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore): ) async def get_server_verify_keys( self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: + ) -> Dict[Tuple[str, str], FetchKeyResult]: """ Args: server_name_and_key_ids: @@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore): """ keys = {} - def _get_keys(txn, batch): + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None: """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) @@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore): # `ts_valid_until_ms`. ts_valid_until_ms = 0 - res = FetchKeyResult( + keys[(server_name, key_id)] = FetchKeyResult( verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), valid_until_ts=ts_valid_until_ms, ) - keys[(server_name, key_id)] = res - def _txn(txn): + def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: for batch in batch_iter(server_name_and_key_ids, 50): _get_keys(txn, batch) return keys diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index afd10f7bae..c03871f393 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py
@@ -17,11 +17,12 @@ import logging import attr +from signedjson.types import VerifyKey logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True) class FetchKeyResult: - verify_key = attr.ib() # VerifyKey: the key itself - valid_until_ts = attr.ib() # int: how long we can use this key for + verify_key = attr.ib(type=VerifyKey) # the key itself + valid_until_ts = attr.ib(type=int) # how long we can use this key for diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754feab..f91a2eae7a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging import os import re from collections import Counter -from typing import Optional, TextIO +from typing import Generator, Iterable, List, Optional, TextIO, Tuple import attr +from typing_extensions import Counter as CounterType from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import LoggingDatabaseConnection @@ -70,7 +71,7 @@ def prepare_database( db_conn: LoggingDatabaseConnection, database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], - databases: Collection[str] = ["main", "state"], + databases: Collection[str] = ("main", "state"), ): """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -155,7 +156,9 @@ def prepare_database( raise -def _setup_new_database(cur, database_engine, databases): +def _setup_new_database( + cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str] +) -> None: """Sets up the physical database by finding a base set of "full schemas" and then applying any necessary deltas, including schemas from the given data stores. @@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases): folder as well those in the data stores specified. Args: - cur (Cursor): a database cursor - database_engine (DatabaseEngine) - databases (list[str]): The names of the databases to instantiate - on the given physical database. + cur: a database cursor + database_engine + databases: The names of the databases to instantiate on the given physical database. """ # We're about to set up a brand new database so we check that its @@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases): database_engine.check_new_database(cur) current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) # First we find the highest full schema version we have valid_versions = [] - for filename in directory_entries: + for filename in os.listdir(current_dir): try: ver = int(filename) except ValueError: @@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases): for database in databases ) - directory_entries = [] + directory_entries = [] # type: List[_DirectoryListing] for directory in directories: directory_entries.extend( _DirectoryListing(file_name, os.path.join(directory, file_name)) @@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases): def _upgrade_existing_database( - cur, - current_version, - applied_delta_files, - upgraded, - database_engine, - config, - databases, - is_empty=False, -): + cur: Cursor, + current_version: int, + applied_delta_files: List[str], + upgraded: bool, + database_engine: BaseDatabaseEngine, + config: Optional[HomeServerConfig], + databases: Collection[str], + is_empty: bool = False, +) -> None: """Upgrades an existing physical database. Delta files can either be SQL stored in *.sql files, or python modules @@ -323,21 +324,20 @@ def _upgrade_existing_database( for a version before applying those in the next version. Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having + cur + current_version: The current version of the schema. + applied_delta_files: A list of deltas that have already been applied. + upgraded: Whether the current version was generated by having applied deltas or from full schema file. If `True` the function will never apply delta files for the given `current_version`, since the current_version wasn't generated by applying those delta files. - database_engine (DatabaseEngine) - config (synapse.config.homeserver.HomeServerConfig|None): + database_engine + config: None if we are initialising a blank database, otherwise the application config - databases (list[str]): The names of the databases to instantiate + databases: The names of the databases to instantiate on the given physical database. - is_empty (bool): Is this a blank database? I.e. do we need to run the + is_empty: Is this a blank database? I.e. do we need to run the upgrade portions of the delta scripts. """ if is_empty: @@ -358,6 +358,7 @@ def _upgrade_existing_database( if not is_empty and "main" in databases: from synapse.storage.databases.main import check_database_before_upgrade + assert config is not None check_database_before_upgrade(cur, database_engine, config) start_ver = current_version @@ -388,10 +389,10 @@ def _upgrade_existing_database( ) # Used to check if we have any duplicate file names - file_name_counter = Counter() + file_name_counter = Counter() # type: CounterType[str] # Now find which directories have anything of interest. - directory_entries = [] + directory_entries = [] # type: List[_DirectoryListing] for directory in directories: logger.debug("Looking for schema deltas in %s", directory) try: @@ -445,11 +446,11 @@ def _upgrade_existing_database( module_name = "synapse.storage.v%d_%s" % (v, root_name) with open(absolute_path) as python_file: - module = imp.load_source(module_name, absolute_path, python_file) + module = imp.load_source(module_name, absolute_path, python_file) # type: ignore logger.info("Running script %s", relative_path) - module.run_create(cur, database_engine) + module.run_create(cur, database_engine) # type: ignore if not is_empty: - module.run_upgrade(cur, database_engine, config=config) + module.run_upgrade(cur, database_engine, config=config) # type: ignore elif ext == ".pyc" or file_name == "__pycache__": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package @@ -497,14 +498,15 @@ def _upgrade_existing_database( logger.info("Schema now up to date") -def _apply_module_schemas(txn, database_engine, config): +def _apply_module_schemas( + txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig +) -> None: """Apply the module schemas for the dynamic modules, if any Args: cur: database cursor - database_engine: synapse database engine class - config (synapse.config.homeserver.HomeServerConfig): - application config + database_engine: + config: application config """ for (mod, _config) in config.password_providers: if not hasattr(mod, "get_db_schema_files"): @@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config): ) -def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): +def _apply_module_schema_files( + cur: Cursor, + database_engine: BaseDatabaseEngine, + modname: str, + names_and_streams: Iterable[Tuple[str, TextIO]], +) -> None: """Apply the module schemas for a single module Args: cur: database cursor database_engine: synapse database engine class - modname (str): fully qualified name of the module - names_and_streams (Iterable[(str, file)]): the names and streams of - schemas to be applied + modname: fully qualified name of the module + names_and_streams: the names and streams of schemas to be applied """ cur.execute( "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), @@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ) -def get_statements(f): +def get_statements(f: Iterable[str]) -> Generator[str, None, None]: statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment @@ -594,17 +600,19 @@ def get_statements(f): statement_buffer = statements[-1].strip() -def executescript(txn, schema_path): +def executescript(txn: Cursor, schema_path: str) -> None: with open(schema_path, "r") as f: execute_statements_from_stream(txn, f) -def execute_statements_from_stream(cur: Cursor, f: TextIO): +def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None: for statement in get_statements(f): cur.execute(statement) -def _get_or_create_schema_state(txn, database_engine): +def _get_or_create_schema_state( + txn: Cursor, database_engine: BaseDatabaseEngine +) -> Optional[Tuple[int, List[str], bool]]: # Bluntly try creating the schema_version tables. schema_path = os.path.join(dir_path, "schema", "schema_version.sql") executescript(txn, schema_path) @@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine): txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None if current_version: txn.execute( @@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine): (current_version,), ) applied_deltas = [d for d, in txn] + upgraded = bool(row[1]) return current_version, applied_deltas, upgraded return None @@ -634,5 +642,5 @@ class _DirectoryListing: `file_name` attr is kept first. """ - file_name = attr.ib() - absolute_path = attr.ib() + file_name = attr.ib(type=str) + absolute_path = attr.ib(type=str) diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index bfa0a9fd06..6c359c1aae 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py
@@ -15,7 +15,12 @@ import itertools import logging -from typing import Set +from typing import TYPE_CHECKING, Set + +from synapse.storage.databases import Databases + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -24,10 +29,10 @@ class PurgeEventsStorage: """High level interface for purging rooms and event history. """ - def __init__(self, hs, stores): + def __init__(self, hs: "HomeServer", stores: Databases): self.stores = stores - async def purge_room(self, room_id: str): + async def purge_room(self, room_id: str) -> None: """Deletes all record of a room """ diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index cec96ad6a7..2564f34b47 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py
@@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Any, Dict, List, Optional, Tuple import attr from synapse.api.errors import SynapseError +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -27,18 +29,18 @@ class PaginationChunk: """Returned by relation pagination APIs. Attributes: - chunk (list): The rows returned by pagination - next_batch (Any|None): Token to fetch next set of results with, if + chunk: The rows returned by pagination + next_batch: Token to fetch next set of results with, if None then there are no more results. - prev_batch (Any|None): Token to fetch previous set of results with, if + prev_batch: Token to fetch previous set of results with, if None then there are no previous results. """ - chunk = attr.ib() - next_batch = attr.ib(default=None) - prev_batch = attr.ib(default=None) + chunk = attr.ib(type=List[JsonDict]) + next_batch = attr.ib(type=Optional[Any], default=None) + prev_batch = attr.ib(type=Optional[Any], default=None) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: d = {"chunk": self.chunk} if self.next_batch: @@ -59,25 +61,25 @@ class RelationPaginationToken: boundaries of the chunk as pagination tokens. Attributes: - topological (int): The topological ordering of the boundary event - stream (int): The stream ordering of the boundary event. + topological: The topological ordering of the boundary event + stream: The stream ordering of the boundary event. """ - topological = attr.ib() - stream = attr.ib() + topological = attr.ib(type=int) + stream = attr.ib(type=int) @staticmethod - def from_string(string): + def from_string(string: str) -> "RelationPaginationToken": try: t, s = string.split("-") return RelationPaginationToken(int(t), int(s)) except ValueError: raise SynapseError(400, "Invalid token") - def to_string(self): + def to_string(self) -> str: return "%d-%d" % (self.topological, self.stream) - def as_tuple(self): + def as_tuple(self) -> Tuple[Any, ...]: return attr.astuple(self) @@ -89,23 +91,23 @@ class AggregationPaginationToken: aggregation groups, we can just use them as our pagination token. Attributes: - count (int): The count of relations in the boundar group. - stream (int): The MAX stream ordering in the boundary group. + count: The count of relations in the boundary group. + stream: The MAX stream ordering in the boundary group. """ - count = attr.ib() - stream = attr.ib() + count = attr.ib(type=int) + stream = attr.ib(type=int) @staticmethod - def from_string(string): + def from_string(string: str) -> "AggregationPaginationToken": try: c, s = string.split("-") return AggregationPaginationToken(int(c), int(s)) except ValueError: raise SynapseError(400, "Invalid token") - def to_string(self): + def to_string(self) -> str: return "%d-%d" % (self.count, self.stream) - def as_tuple(self): + def as_tuple(self) -> Tuple[Any, ...]: return attr.astuple(self) diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..31ccbf23dc 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Awaitable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, +) import attr @@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.storage.databases import Databases + logger = logging.getLogger(__name__) # Used for generic functions below @@ -330,10 +343,12 @@ class StateGroupStorage: """High level interface to fetching state for event. """ - def __init__(self, hs, stores): + def __init__(self, hs: "HomeServer", stores: "Databases"): self.stores = stores - async def get_state_group_delta(self, state_group: int): + async def get_state_group_delta( + self, state_group: int + ) -> Tuple[Optional[int], Optional[StateMap[str]]]: """Given a state group try to return a previous group and a delta between the old and the new. @@ -341,8 +356,8 @@ class StateGroupStorage: state_group: The state group used to retrieve state deltas. Returns: - Tuple[Optional[int], Optional[StateMap[str]]]: - (prev_group, delta_ids) + A tuple of the previous group and a state map of the event IDs which + make up the delta between the old and new state groups. """ return await self.stores.state.get_state_group_delta(state_group) @@ -436,7 +451,7 @@ class StateGroupStorage: async def get_state_for_events( self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -472,7 +487,7 @@ class StateGroupStorage: async def get_state_ids_for_events( self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) @@ -500,7 +515,7 @@ class StateGroupStorage: async def get_state_for_event( self, event_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[EventBase]: """ Get the state dict corresponding to a particular event @@ -516,7 +531,7 @@ class StateGroupStorage: async def get_state_ids_for_event( self, event_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[str]: """ Get the state dict corresponding to a particular event diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index d146f2254f..1d65ea2f9c 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock() kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Tests that we correctly handle key requests for keys we've stored with a null `ts_valid_until_ms` """ - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( @@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher1 = keyring.KeyFetcher() + mock_fetcher1 = Mock() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) - mock_fetcher2 = keyring.KeyFetcher() + mock_fetcher2 = Mock() mock_fetcher2.get_keys = Mock(side_effect=get_keys2) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py new file mode 100644
index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/federation/transport/__init__.py
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 212fb79a00..85500e169c 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,34 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server -from synapse.util.ratelimitutils import FederationRateLimiter - from tests import unittest from tests.unittest import override_config -class RoomDirectoryFederationTests(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - class Authenticator: - def authenticate_request(self, request, content): - return defer.succeed("otherserver.nottld") - - ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) - server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - +class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): @override_config({"allow_public_rooms_over_federation": False}) def test_blocked_public_room_list_over_federation(self): - channel = self.make_request("GET", "/_matrix/federation/v1/publicRooms") + """Test that unauthenticated requests to the public rooms directory 403 when + allow_public_rooms_over_federation is False. + """ + channel = self.make_request( + "GET", + "/_matrix/federation/v1/publicRooms", + federation_auth_origin=b"example.com", + ) self.assertEquals(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) def test_open_public_room_list_over_federation(self): - channel = self.make_request("GET", "/_matrix/federation/v1/publicRooms") + """Test that unauthenticated requests to the public rooms directory 200 when + allow_public_rooms_over_federation is True. + """ + channel = self.make_request( + "GET", + "/_matrix/federation/v1/publicRooms", + federation_auth_origin=b"example.com", + ) self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py new file mode 100644
index 0000000000..bd7a1b6891 --- /dev/null +++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.handlers.cas_handler import CasResponse + +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase + +# These are a few constants that are used as config parameters in the tests. +BASE_URL = "https://synapse/" +SERVER_URL = "https://issuer/" + + +class CasHandlerTestCase(HomeserverTestCase): + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + cas_config = { + "enabled": True, + "server_url": SERVER_URL, + "service_url": BASE_URL, + } + config["cas_config"] = cas_config + + return config + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + self.handler = hs.get_cas_handler() + + # Reduce the number of attempts when generating MXIDs. + sso_handler = hs.get_sso_handler() + sso_handler._MAP_USERNAME_RETRIES = 3 + + return hs + + def test_map_cas_user_to_user(self): + """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + def test_map_cas_user_to_existing_user(self): + """Existing users can log in with CAS account.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # Map a user via SSO. + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + # Subsequent calls should map to the same mxid. + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + def test_map_cas_user_to_invalid_localpart(self): + """CAS automaps invalid characters to base-64 encoding.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + cas_response = CasResponse("föö", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@f=c3=b6=c3=b6:test", request, "redirect_uri", None + ) + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"])