diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml
index f7a4ee7c13..67ccc03f6e 100644
--- a/.github/workflows/twisted_trunk.yml
+++ b/.github/workflows/twisted_trunk.yml
@@ -5,6 +5,9 @@ on:
- cron: 0 8 * * *
workflow_dispatch:
+ # NB: inputs are only present when this workflow is dispatched manually.
+ # (The default below is the default field value in the form to trigger
+ # a manual dispatch). Otherwise the inputs will evaluate to null.
inputs:
twisted_ref:
description: Commit, branch or tag to checkout from upstream Twisted.
@@ -49,7 +52,7 @@ jobs:
extras: "all"
- run: |
poetry remove twisted
- poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref }}
+ poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref || 'trunk' }}
poetry install --no-interaction --extras "all test"
- name: Remove warn_unused_ignores from mypy config
run: sed '/warn_unused_ignores = True/d' -i mypy.ini
diff --git a/CHANGES.md b/CHANGES.md
index 95d8227ee0..666cd31ba0 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,8 @@
+# Synapse 1.90.0 (2023-08-15)
+
+No significant changes since 1.90.0rc1.
+
+
# Synapse 1.90.0rc1 (2023-08-08)
### Features
diff --git a/Cargo.lock b/Cargo.lock
index 45e0f116e6..79d9cefcf6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -132,9 +132,9 @@ dependencies = [
[[package]]
name = "log"
-version = "0.4.19"
+version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
+checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "memchr"
diff --git a/changelog.d/15870.feature b/changelog.d/15870.feature
new file mode 100644
index 0000000000..527220d637
--- /dev/null
+++ b/changelog.d/15870.feature
@@ -0,0 +1 @@
+Implements an admin API to lock an user without deactivating them. Based on [MSC3939](https://github.com/matrix-org/matrix-spec-proposals/pull/3939).
diff --git a/changelog.d/16010.misc b/changelog.d/16010.misc
new file mode 100644
index 0000000000..1e1a148069
--- /dev/null
+++ b/changelog.d/16010.misc
@@ -0,0 +1 @@
+Update dehydrated devices implementation.
diff --git a/changelog.d/16052.bugfix b/changelog.d/16052.bugfix
new file mode 100644
index 0000000000..3c7a60f226
--- /dev/null
+++ b/changelog.d/16052.bugfix
@@ -0,0 +1 @@
+Fix long-standing bug where concurrent requests to change a user's push rules could cause a deadlock. Contributed by Nick @ Beeper (@fizzadar).
diff --git a/changelog.d/16061.misc b/changelog.d/16061.misc
new file mode 100644
index 0000000000..37928b670f
--- /dev/null
+++ b/changelog.d/16061.misc
@@ -0,0 +1 @@
+Fix database performance of read/write worker locks.
diff --git a/changelog.d/16080.bugfix b/changelog.d/16080.bugfix
new file mode 100644
index 0000000000..1ad6fb3c52
--- /dev/null
+++ b/changelog.d/16080.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bu in `/sync` where timeout=0 does not skip caching, resulting in slow calls in cases where there are no new changes. Contributed by @PlasmaIntec.
\ No newline at end of file
diff --git a/changelog.d/16085.misc b/changelog.d/16085.misc
new file mode 100644
index 0000000000..7b7a95edd4
--- /dev/null
+++ b/changelog.d/16085.misc
@@ -0,0 +1 @@
+Override global statement timeout when creating indexes in Postgres.
diff --git a/changelog.d/16089.misc b/changelog.d/16089.misc
new file mode 100644
index 0000000000..8c302e6884
--- /dev/null
+++ b/changelog.d/16089.misc
@@ -0,0 +1 @@
+Fix the type annotation on `run_db_interaction` in the Module API.
\ No newline at end of file
diff --git a/changelog.d/16091.doc b/changelog.d/16091.doc
new file mode 100644
index 0000000000..a043df4efd
--- /dev/null
+++ b/changelog.d/16091.doc
@@ -0,0 +1 @@
+Structured logging docs: add a link to explain the ELK stack
diff --git a/changelog.d/16092.misc b/changelog.d/16092.misc
new file mode 100644
index 0000000000..b520807771
--- /dev/null
+++ b/changelog.d/16092.misc
@@ -0,0 +1 @@
+Clean-up the presence code.
diff --git a/changelog.d/16094.feature b/changelog.d/16094.feature
new file mode 100644
index 0000000000..3be71badb9
--- /dev/null
+++ b/changelog.d/16094.feature
@@ -0,0 +1 @@
+Allow customising the IdP display name, icon, and brand for SAML and CAS providers (in addition to OIDC provider).
diff --git a/changelog.d/16110.misc b/changelog.d/16110.misc
new file mode 100644
index 0000000000..68efe86ddc
--- /dev/null
+++ b/changelog.d/16110.misc
@@ -0,0 +1 @@
+Run `pyupgrade` for Python 3.8+.
diff --git a/changelog.d/16112.misc b/changelog.d/16112.misc
new file mode 100644
index 0000000000..05a58c1348
--- /dev/null
+++ b/changelog.d/16112.misc
@@ -0,0 +1 @@
+Rename pagination and purge locks and add comments to explain why they exist and how they work.
diff --git a/changelog.d/16115.misc b/changelog.d/16115.misc
new file mode 100644
index 0000000000..f325d2a31d
--- /dev/null
+++ b/changelog.d/16115.misc
@@ -0,0 +1 @@
+Attempt to fix the twisted trunk job.
diff --git a/changelog.d/16117.misc b/changelog.d/16117.misc
new file mode 100644
index 0000000000..f33fa6dc17
--- /dev/null
+++ b/changelog.d/16117.misc
@@ -0,0 +1 @@
+Cache token introspection response from OIDC provider.
diff --git a/changelog.d/16123.misc b/changelog.d/16123.misc
new file mode 100644
index 0000000000..b7c6b7c2f2
--- /dev/null
+++ b/changelog.d/16123.misc
@@ -0,0 +1 @@
+Add cache to `get_server_keys_json_for_remote`.
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index 895b2a7af1..710fe25699 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -769,7 +769,7 @@ def main(server_url, identity_server_url, username, token, config_path):
global CONFIG_JSON
CONFIG_JSON = config_path # bit cheeky, but just overwrite the global
try:
- with open(config_path, "r") as config:
+ with open(config_path) as config:
syn_cmd.config = json.load(config)
try:
http_client.verbose = "on" == syn_cmd.config["verbose"]
diff --git a/debian/changelog b/debian/changelog
index ed35abc9ee..ad9a4b3c8c 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.90.0) stable; urgency=medium
+
+ * New Synapse release 1.90.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 15 Aug 2023 11:17:34 +0100
+
matrix-synapse-py3 (1.90.0~rc1) stable; urgency=medium
* New Synapse release 1.90.0rc1.
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index dc824038b5..400a7515aa 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -861,7 +861,7 @@ def generate_worker_files(
# Then a worker config file
convert(
"/conf/worker.yaml.j2",
- "/conf/workers/{name}.yaml".format(name=worker_name),
+ f"/conf/workers/{worker_name}.yaml",
**worker_config,
worker_log_config_filepath=log_config_filepath,
using_unix_sockets=using_unix_sockets,
diff --git a/docker/start.py b/docker/start.py
index ebcc599f04..aebc7e4aaa 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -82,7 +82,7 @@ def generate_config_from_template(
with open(filename) as handle:
value = handle.read()
else:
- log("Generating a random secret for {}".format(secret))
+ log(f"Generating a random secret for {secret}")
value = codecs.encode(os.urandom(32), "hex").decode()
with open(filename, "w") as handle:
handle.write(value)
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index ac4f635099..c269ce6af0 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -146,6 +146,7 @@ Body parameters:
- `admin` - **bool**, optional, defaults to `false`. Whether the user is a homeserver administrator,
granting them access to the Admin API, among other things.
- `deactivated` - **bool**, optional. If unspecified, deactivation state will be left unchanged.
+- `locked` - **bool**, optional. If unspecified, locked state will be left unchanged.
Note: the `password` field must also be set if both of the following are true:
- `deactivated` is set to `false` and the user was previously deactivated (you are reactivating this user)
diff --git a/docs/structured_logging.md b/docs/structured_logging.md
index d43dc9eb6e..002565b223 100644
--- a/docs/structured_logging.md
+++ b/docs/structured_logging.md
@@ -3,7 +3,7 @@
A structured logging system can be useful when your logs are destined for a
machine to parse and process. By maintaining its machine-readable characteristics,
it enables more efficient searching and aggregations when consumed by software
-such as the "ELK stack".
+such as the [ELK stack](https://opensource.com/article/18/9/open-source-log-aggregation-tools).
Synapse's structured logging system is configured via the file that Synapse's
`log_config` config option points to. The file should include a formatter which
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 2987c9332d..6601bba9f2 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -3025,6 +3025,16 @@ enable SAML login. You can either put your entire pysaml config inline using the
option, or you can specify a path to a psyaml config file with the sub-option `config_path`.
This setting has the following sub-options:
+* `idp_name`: A user-facing name for this identity provider, which is used to
+ offer the user a choice of login mechanisms.
+* `idp_icon`: An optional icon for this identity provider, which is presented
+ by clients and Synapse's own IdP picker page. If given, must be an
+ MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
+ obtain such an MXC URI is to upload an image to an (unencrypted) room
+ and then copy the "url" from the source of the event.)
+* `idp_brand`: An optional brand for this identity provider, allowing clients
+ to style the login flow according to the identity provider in question.
+ See the [spec](https://spec.matrix.org/latest/) for possible options here.
* `sp_config`: the configuration for the pysaml2 Service Provider. See pysaml2 docs for format of config.
Default values will be used for the `entityid` and `service` settings,
so it is not normally necessary to specify them unless you need to
@@ -3176,7 +3186,7 @@ Options for each entry include:
* `idp_icon`: An optional icon for this identity provider, which is presented
by clients and Synapse's own IdP picker page. If given, must be an
- MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
+ MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
obtain such an MXC URI is to upload an image to an (unencrypted) room
and then copy the "url" from the source of the event.)
@@ -3391,6 +3401,16 @@ Enable Central Authentication Service (CAS) for registration and login.
Has the following sub-options:
* `enabled`: Set this to true to enable authorization against a CAS server.
Defaults to false.
+* `idp_name`: A user-facing name for this identity provider, which is used to
+ offer the user a choice of login mechanisms.
+* `idp_icon`: An optional icon for this identity provider, which is presented
+ by clients and Synapse's own IdP picker page. If given, must be an
+ MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
+ obtain such an MXC URI is to upload an image to an (unencrypted) room
+ and then copy the "url" from the source of the event.)
+* `idp_brand`: An optional brand for this identity provider, allowing clients
+ to style the login flow according to the identity provider in question.
+ See the [spec](https://spec.matrix.org/latest/) for possible options here.
* `server_url`: The URL of the CAS authorization endpoint.
* `displayname_attribute`: The attribute of the CAS response to use as the display name.
If no name is given here, no displayname will be set.
@@ -3631,6 +3651,7 @@ This option has the following sub-options:
* `prefer_local_users`: Defines whether to prefer local users in search query results.
If set to true, local users are more likely to appear above remote users when searching the
user directory. Defaults to false.
+* `show_locked_users`: Defines whether to show locked users in search query results. Defaults to false.
Example configuration:
```yaml
@@ -3638,6 +3659,7 @@ user_directory:
enabled: false
search_all_users: true
prefer_local_users: true
+ show_locked_users: true
```
---
### `user_consent`
diff --git a/mypy.ini b/mypy.ini
index 1038b7d8c7..311a951aa8 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -45,6 +45,13 @@ warn_unused_ignores = False
disallow_untyped_defs = False
disallow_incomplete_defs = False
+[mypy-synapse.util.manhole]
+# This module imports something from Twisted which has a bad annotation in Twisted trunk,
+# but is unannotated in Twisted's latest release. We want to type-ignore the problem
+# in the twisted trunk job, even though it has no effect on normal mypy runs.
+warn_unused_ignores = False
+
+
;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available.
;; The `typeshed` project maintains stubs here:
diff --git a/poetry.lock b/poetry.lock
index 71b47a5805..db1332a04b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -589,13 +589,13 @@ smmap = ">=3.0.1,<6"
[[package]]
name = "gitpython"
-version = "3.1.31"
+version = "3.1.32"
description = "GitPython is a Python library used to interact with Git repositories"
optional = false
python-versions = ">=3.7"
files = [
- {file = "GitPython-3.1.31-py3-none-any.whl", hash = "sha256:f04893614f6aa713a60cbbe1e6a97403ef633103cdd0ef5eb6efe0deb98dbe8d"},
- {file = "GitPython-3.1.31.tar.gz", hash = "sha256:8ce3bcf69adfdf7c7d503e78fd3b1c492af782d58893b650adb2ac8912ddd573"},
+ {file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"},
+ {file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"},
]
[package.dependencies]
@@ -887,17 +887,17 @@ scripts = ["click (>=6.0)", "twisted (>=16.4.0)"]
[[package]]
name = "isort"
-version = "5.11.5"
+version = "5.12.0"
description = "A Python utility / library to sort Python imports."
optional = false
-python-versions = ">=3.7.0"
+python-versions = ">=3.8.0"
files = [
- {file = "isort-5.11.5-py3-none-any.whl", hash = "sha256:ba1d72fb2595a01c7895a5128f9585a5cc4b6d395f1c8d514989b9a7eb2a8746"},
- {file = "isort-5.11.5.tar.gz", hash = "sha256:6be1f76a507cb2ecf16c7cf14a37e41609ca082330be4e3436a18ef74add55db"},
+ {file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"},
+ {file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"},
]
[package.extras]
-colors = ["colorama (>=0.4.3,<0.5.0)"]
+colors = ["colorama (>=0.4.3)"]
pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"]
plugins = ["setuptools"]
requirements-deprecated-finder = ["pip-api", "pipreqs"]
@@ -2921,13 +2921,13 @@ files = [
[[package]]
name = "txredisapi"
-version = "1.4.9"
+version = "1.4.10"
description = "non-blocking redis client for python"
optional = true
python-versions = "*"
files = [
- {file = "txredisapi-1.4.9-py3-none-any.whl", hash = "sha256:72e6ad09cc5fffe3bec2e55e5bfb74407bd357565fc212e6003f7e26ef7d8f78"},
- {file = "txredisapi-1.4.9.tar.gz", hash = "sha256:c9607062d05e4d0b8ef84719eb76a3fe7d5ccd606a2acf024429da51d6e84559"},
+ {file = "txredisapi-1.4.10-py3-none-any.whl", hash = "sha256:0a6ea77f27f8cf092f907654f08302a97b48fa35f24e0ad99dfb74115f018161"},
+ {file = "txredisapi-1.4.10.tar.gz", hash = "sha256:7609a6af6ff4619a3189c0adfb86aeda789afba69eb59fc1e19ac0199e725395"},
]
[package.dependencies]
@@ -2936,13 +2936,13 @@ twisted = "*"
[[package]]
name = "types-bleach"
-version = "6.0.0.3"
+version = "6.0.0.4"
description = "Typing stubs for bleach"
optional = false
python-versions = "*"
files = [
- {file = "types-bleach-6.0.0.3.tar.gz", hash = "sha256:8ce7896d4f658c562768674ffcf07492c7730e128018f03edd163ff912bfadee"},
- {file = "types_bleach-6.0.0.3-py3-none-any.whl", hash = "sha256:d43eaf30a643ca824e16e2dcdb0c87ef9226237e2fa3ac4732a50cb3f32e145f"},
+ {file = "types-bleach-6.0.0.4.tar.gz", hash = "sha256:357b0226f65c4f20ab3b13ca8d78a6b91c78aad256d8ec168d4e90fc3303ebd4"},
+ {file = "types_bleach-6.0.0.4-py3-none-any.whl", hash = "sha256:2b8767eb407c286b7f02803678732e522e04db8d56cbc9f1270bee49627eae92"},
]
[[package]]
@@ -2991,13 +2991,13 @@ files = [
[[package]]
name = "types-pillow"
-version = "10.0.0.1"
+version = "10.0.0.2"
description = "Typing stubs for Pillow"
optional = false
python-versions = "*"
files = [
- {file = "types-Pillow-10.0.0.1.tar.gz", hash = "sha256:834a07a04504f8bf37936679bc6a5802945e7644d0727460c0c4d4307967e2a3"},
- {file = "types_Pillow-10.0.0.1-py3-none-any.whl", hash = "sha256:be576b67418f1cb3b93794cf7946581be1009a33a10085b3c132eb0875a819b4"},
+ {file = "types-Pillow-10.0.0.2.tar.gz", hash = "sha256:fe09380ab22d412ced989a067e9ee4af719fa3a47ba1b53b232b46514a871042"},
+ {file = "types_Pillow-10.0.0.2-py3-none-any.whl", hash = "sha256:29d51a3ce6ef51fabf728a504d33b4836187ff14256b2e86996d55c91ab214b1"},
]
[[package]]
diff --git a/pyproject.toml b/pyproject.toml
index ca532e2c7c..86680cb8e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml"
[tool.poetry]
name = "matrix-synapse"
-version = "1.90.0rc1"
+version = "1.90.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"
diff --git a/scripts-dev/build_debian_packages.py b/scripts-dev/build_debian_packages.py
index bb89ba581c..c03e3418c0 100755
--- a/scripts-dev/build_debian_packages.py
+++ b/scripts-dev/build_debian_packages.py
@@ -47,7 +47,7 @@ can be passed on the commandline for debugging.
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-class Builder(object):
+class Builder:
def __init__(
self,
redirect_stdout: bool = False,
diff --git a/scripts-dev/check_schema_delta.py b/scripts-dev/check_schema_delta.py
index fee4a8bd3d..467be96fdf 100755
--- a/scripts-dev/check_schema_delta.py
+++ b/scripts-dev/check_schema_delta.py
@@ -43,7 +43,7 @@ def main(force_colors: bool) -> None:
diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None)
# Get the schema version of the local file to check against current schema on develop
- with open("synapse/storage/schema/__init__.py", "r") as file:
+ with open("synapse/storage/schema/__init__.py") as file:
local_schema = file.read()
new_locals: Dict[str, Any] = {}
exec(local_schema, new_locals)
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 63f0b25ddd..5ad334b4d8 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -247,7 +247,7 @@ def main() -> None:
def read_args_from_config(args: argparse.Namespace) -> None:
- with open(args.config, "r") as fh:
+ with open(args.config) as fh:
config = yaml.safe_load(fh)
if not args.server_name:
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index 89ffba8d92..4ac8eaa889 100755
--- a/scripts-dev/release.py
+++ b/scripts-dev/release.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python
-# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/scripts-dev/sign_json.py b/scripts-dev/sign_json.py
index bb217799fb..00cbaf68f5 100755
--- a/scripts-dev/sign_json.py
+++ b/scripts-dev/sign_json.py
@@ -145,7 +145,7 @@ Example usage:
def read_args_from_config(args: argparse.Namespace) -> None:
- with open(args.config, "r") as fh:
+ with open(args.config) as fh:
config = yaml.safe_load(fh)
if not args.server_name:
args.server_name = config["server_name"]
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 6c1801862b..2f9c22a833 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -25,7 +25,11 @@ from synapse.util.rust import check_rust_lib_up_to_date
from synapse.util.stringutils import strtobool
# Check that we're not running on an unsupported Python version.
-if sys.version_info < (3, 8):
+#
+# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the
+# if-statement completely.
+py_version = sys.version_info
+if py_version < (3, 8):
print("Synapse requires Python 3.8 or above.")
sys.exit(1)
@@ -78,7 +82,7 @@ try:
except ImportError:
pass
-import synapse.util
+import synapse.util # noqa: E402
__version__ = synapse.util.SYNAPSE_VERSION
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 22c84fbd5b..49242800b8 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -123,7 +123,7 @@ BOOLEAN_COLUMNS = {
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"rooms": ["is_public", "has_auth_chain_index"],
- "users": ["shadow_banned", "approved"],
+ "users": ["shadow_banned", "approved", "locked"],
"un_partial_stated_event_stream": ["rejection_status_changed"],
"users_who_share_rooms": ["share_private"],
"per_user_experimental_features": ["enabled"],
@@ -1205,10 +1205,10 @@ class CursesProgress(Progress):
self.total_processed = 0
self.total_remaining = 0
- super(CursesProgress, self).__init__()
+ super().__init__()
def update(self, table: str, num_done: int) -> None:
- super(CursesProgress, self).update(table, num_done)
+ super().update(table, num_done)
self.total_processed = 0
self.total_remaining = 0
@@ -1304,7 +1304,7 @@ class TerminalProgress(Progress):
"""Just prints progress to the terminal"""
def update(self, table: str, num_done: int) -> None:
- super(TerminalProgress, self).update(table, num_done)
+ super().update(table, num_done)
data = self.tables[table]
diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py
index 0adf94bba6..f97aecf8d5 100644
--- a/synapse/_scripts/update_synapse_database.py
+++ b/synapse/_scripts/update_synapse_database.py
@@ -38,7 +38,7 @@ class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config: HomeServerConfig):
- super(MockHomeserver, self).__init__(
+ super().__init__(
hostname=config.server.server_name,
config=config,
reactor=reactor,
diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py
index 90cfe39d76..bb3f50f2dd 100644
--- a/synapse/api/auth/__init__.py
+++ b/synapse/api/auth/__init__.py
@@ -60,6 +60,7 @@ class Auth(Protocol):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
+ allow_locked: bool = False,
) -> Requester:
"""Get a registered user's ID.
diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py
index e2ae198b19..6a5fd44ec0 100644
--- a/synapse/api/auth/internal.py
+++ b/synapse/api/auth/internal.py
@@ -58,6 +58,7 @@ class InternalAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
+ allow_locked: bool = False,
) -> Requester:
"""Get a registered user's ID.
@@ -79,7 +80,7 @@ class InternalAuth(BaseAuth):
parent_span = active_span()
with start_active_span("get_user_by_req"):
requester = await self._wrapped_get_user_by_req(
- request, allow_guest, allow_expired
+ request, allow_guest, allow_expired, allow_locked
)
if parent_span:
@@ -107,6 +108,7 @@ class InternalAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool,
allow_expired: bool,
+ allow_locked: bool,
) -> Requester:
"""Helper for get_user_by_req
@@ -126,6 +128,17 @@ class InternalAuth(BaseAuth):
access_token, allow_expired=allow_expired
)
+ # Deny the request if the user account is locked.
+ if not allow_locked and await self.store.get_user_locked_status(
+ requester.user.to_string()
+ ):
+ raise AuthError(
+ 401,
+ "User account has been locked",
+ errcode=Codes.USER_LOCKED,
+ additional_fields={"soft_logout": True},
+ )
+
# Deny the request if the user account has expired.
# This check is only done for regular users, not appservice ones.
if not allow_expired:
diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index bd4fc9c0ee..3a516093f5 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -27,6 +27,7 @@ from twisted.web.http_headers import Headers
from synapse.api.auth.base import BaseAuth
from synapse.api.errors import (
AuthError,
+ Codes,
HttpResponseException,
InvalidClientTokenError,
OAuthInsufficientScopeError,
@@ -38,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.caches.expiringcache import ExpiringCache
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -105,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth):
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+ self._clock = hs.get_clock()
+ self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
+ cache_name="introspection_token_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=5 * 60 * 1000,
+ )
+
if isinstance(auth_method, PrivateKeyJWTWithKid):
# Use the JWK as the client secret when using the private_key_jwt method
assert self._config.jwk, "No JWK provided"
@@ -143,6 +153,20 @@ class MSC3861DelegatedAuth(BaseAuth):
Returns:
The introspection response
"""
+ # check the cache before doing a request
+ introspection_token = self._token_cache.get(token, None)
+
+ if introspection_token:
+ # check the expiration field of the token (if it exists)
+ exp = introspection_token.get("exp", None)
+ if exp:
+ time_now = self._clock.time()
+ expired = time_now > exp
+ if not expired:
+ return introspection_token
+ else:
+ return introspection_token
+
metadata = await self._issuer_metadata.get()
introspection_endpoint = metadata.get("introspection_endpoint")
raw_headers: Dict[str, str] = {
@@ -156,7 +180,10 @@ class MSC3861DelegatedAuth(BaseAuth):
# Fill the body/headers with credentials
uri, raw_headers, body = self._client_auth.prepare(
- method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
+ method="POST",
+ uri=introspection_endpoint,
+ headers=raw_headers,
+ body=body,
)
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
@@ -186,7 +213,17 @@ class MSC3861DelegatedAuth(BaseAuth):
"The introspection endpoint returned an invalid JSON response."
)
- return IntrospectionToken(**resp)
+ expiration = resp.get("exp", None)
+ if expiration:
+ if self._clock.time() > expiration:
+ raise InvalidClientTokenError("Token is expired.")
+
+ introspection_token = IntrospectionToken(**resp)
+
+ # add token to cache
+ self._token_cache[token] = introspection_token
+
+ return introspection_token
async def is_server_admin(self, requester: Requester) -> bool:
return "urn:synapse:admin:*" in requester.scope
@@ -196,6 +233,7 @@ class MSC3861DelegatedAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
+ allow_locked: bool = False,
) -> Requester:
access_token = self.get_access_token_from_request(request)
@@ -205,6 +243,17 @@ class MSC3861DelegatedAuth(BaseAuth):
# so that we don't provision the user if they don't have enough permission:
requester = await self.get_user_by_access_token(access_token, allow_expired)
+ # Deny the request if the user account is locked.
+ if not allow_locked and await self.store.get_user_locked_status(
+ requester.user.to_string()
+ ):
+ raise AuthError(
+ 401,
+ "User account has been locked",
+ errcode=Codes.USER_LOCKED,
+ additional_fields={"soft_logout": True},
+ )
+
if not allow_guest and requester.is_guest:
raise OAuthInsufficientScopeError([SCOPE_MATRIX_API])
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index dc32553d0c..bf311b636d 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -18,8 +18,7 @@
"""Contains constants from the specification."""
import enum
-
-from typing_extensions import Final
+from typing import Final
# the max size of a (canonical-json-encoded) event
MAX_PDU_SIZE = 65536
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 3546aaf7c3..7ffd72c42c 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -80,6 +80,8 @@ class Codes(str, Enum):
WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
+ # USER_LOCKED = "M_USER_LOCKED"
+ USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
# Part of MSC3848
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 9152c06bd6..c4e63e7411 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -47,6 +47,10 @@ class CasConfig(Config):
required_attributes
)
+ self.idp_name = cas_config.get("idp_name", "CAS")
+ self.idp_icon = cas_config.get("idp_icon")
+ self.idp_brand = cas_config.get("idp_brand")
+
else:
self.cas_server_url = None
self.cas_service_url = None
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 49ca663dde..c69e24cf26 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -89,8 +89,14 @@ class SAML2Config(Config):
"grandfathered_mxid_source_attribute", "uid"
)
+ # refers to a SAML IdP entity ID
self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
+ # IdP properties for Matrix clients
+ self.idp_name = saml2_config.get("idp_name", "SAML")
+ self.idp_icon = saml2_config.get("idp_icon")
+ self.idp_brand = saml2_config.get("idp_brand")
+
# user_mapping_provider may be None if the key is present but has no value
ump_dict = saml2_config.get("user_mapping_provider") or {}
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c9e18b91e9..f60ec2ea66 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -35,3 +35,4 @@ class UserDirectoryConfig(Config):
self.user_directory_search_prefer_local_users = user_directory_config.get(
"prefer_local_users", False
)
+ self.show_locked_users = user_directory_config.get("show_locked_users", False)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index a90d99c4d6..f9915e5a3f 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -63,7 +63,7 @@ from synapse.federation.federation_base import (
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@@ -1245,7 +1245,7 @@ class FederationServer(FederationBase):
# while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
# lock.
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
await self._federation_event_handler.on_receive_pdu(
origin, event
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 119c7f8384..0e812a6d8b 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -67,6 +67,7 @@ class AdminHandler:
"name",
"admin",
"deactivated",
+ "locked",
"shadow_banned",
"creation_ts",
"appservice_id",
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index fc467bc7c1..5c71637038 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -76,12 +76,13 @@ class CasHandler:
self.idp_id = "cas"
# user-facing name of this auth provider
- self.idp_name = "CAS"
+ self.idp_name = hs.config.cas.idp_name
- # we do not currently support brands/icons for CAS auth, but this is required by
- # the SsoIdentityProvider protocol type.
- self.idp_icon = None
- self.idp_brand = None
+ # MXC URI for icon for this auth provider
+ self.idp_icon = hs.config.cas.idp_icon
+
+ # optional brand identifier for this auth provider
+ self.idp_brand = hs.config.cas.idp_brand
self._sso_handler = hs.get_sso_handler()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b7bf70a72d..5ae427d52c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
+ self.db_pool = hs.get_datastores().main.db_pool
self.device_list_updater = DeviceListUpdater(hs, self)
@@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler):
device_id: Optional[str],
device_data: JsonDict,
initial_device_display_name: Optional[str] = None,
+ keys_for_device: Optional[JsonDict] = None,
) -> str:
- """Store a dehydrated device for a user. If the user had a previous
- dehydrated device, it is removed.
+ """Store a dehydrated device for a user, optionally storing the keys associated with
+ it as well. If the user had a previous dehydrated device, it is removed.
Args:
user_id: the user that we are storing the device for
device_id: device id supplied by client
device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device
+ keys_for_device: keys for the dehydrated device
Returns:
device id of the dehydrated device
"""
@@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler):
device_id,
initial_device_display_name,
)
+
+ time_now = self.clock.time_msec()
+
old_device_id = await self.store.store_dehydrated_device(
- user_id, device_id, device_data
+ user_id, device_id, device_data, time_now, keys_for_device
)
+
if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id])
+
return device_id
async def rehydrate_device(
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 15e94a03cb..17ff8821d9 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -367,19 +367,6 @@ class DeviceMessageHandler:
errcode=Codes.INVALID_PARAM,
)
- # if we have a since token, delete any to-device messages before that token
- # (since we now know that the device has received them)
- deleted = await self.store.delete_messages_for_device(
- user_id, device_id, since_stream_id
- )
- logger.debug(
- "Deleted %d to-device messages up to %d for user_id %s device_id %s",
- deleted,
- since_stream_id,
- user_id,
- device_id,
- )
-
to_token = self.event_sources.get_current_token().to_device_key
messages, stream_id = await self.store.get_messages_for_device(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d485f21e49..a74db1dccf 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -53,7 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -1034,7 +1034,7 @@ class EventCreationHandler:
)
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
return await self._create_and_send_nonmember_event_locked(
requester=requester,
@@ -1978,7 +1978,7 @@ class EventCreationHandler:
for room_id in room_ids:
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
dummy_event_sent = await self._send_dummy_event_for_room(room_id)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index da34658470..1be6ebc6d9 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.room import ShutdownRoomResponse
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.admin._base import assert_user_is_admin
@@ -46,9 +47,10 @@ logger = logging.getLogger(__name__)
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
-PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
-
-DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+# This is used to avoid purging a room several time at the same moment,
+# and also paginating during a purge. Pagination can trigger backfill,
+# which would create old events locally, and would potentially clash with the room delete.
+PURGE_PAGINATION_LOCK_NAME = "purge_pagination_lock"
@attr.s(slots=True, auto_attribs=True)
@@ -363,7 +365,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self._worker_locks.acquire_read_write_lock(
- PURGE_HISTORY_LOCK_NAME, room_id, write=True
+ PURGE_PAGINATION_LOCK_NAME, room_id, write=True
):
await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
@@ -421,7 +423,10 @@ class PaginationHandler:
force: set true to skip checking for joined users.
"""
async with self._worker_locks.acquire_multi_read_write_lock(
- [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
+ [
+ (PURGE_PAGINATION_LOCK_NAME, room_id),
+ (NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id),
+ ],
write=True,
):
# first check that we have no users in this room
@@ -483,7 +488,7 @@ class PaginationHandler:
room_token = from_token.room_key
async with self._worker_locks.acquire_read_write_lock(
- PURGE_HISTORY_LOCK_NAME, room_id, write=False
+ PURGE_PAGINATION_LOCK_NAME, room_id, write=False
):
(membership, member_event_id) = (None, None)
if not use_admin_priviledge:
@@ -761,7 +766,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self._worker_locks.acquire_read_write_lock(
- PURGE_HISTORY_LOCK_NAME, room_id, write=True
+ PURGE_PAGINATION_LOCK_NAME, room_id, write=True
):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index cd7df0525f..e8e9db4b91 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -30,9 +30,9 @@ from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
- Awaitable,
Callable,
Collection,
+ ContextManager,
Dict,
Generator,
Iterable,
@@ -44,7 +44,6 @@ from typing import (
)
from prometheus_client import Counter
-from typing_extensions import ContextManager
import synapse.metrics
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
@@ -54,7 +53,10 @@ from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.replication.http.presence import (
ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState,
@@ -141,6 +143,8 @@ class BasePresenceHandler(abc.ABC):
self.state = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
+ self._presence_enabled = hs.config.server.use_presence
+
self._federation = None
if hs.should_send_federation():
self._federation = hs.get_federation_sender()
@@ -149,6 +153,15 @@ class BasePresenceHandler(abc.ABC):
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
+ self.VALID_PRESENCE: Tuple[str, ...] = (
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ )
+
+ if self._busy_presence_enabled:
+ self.VALID_PRESENCE += (PresenceState.BUSY,)
+
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -395,8 +408,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._presence_writer_instance = hs.config.worker.writers.presence[0]
- self._presence_enabled = hs.config.server.use_presence
-
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
EduTypes.PRESENCE,
@@ -421,8 +432,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
- self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
-
hs.get_reactor().addSystemEventTrigger(
"before",
"shutdown",
@@ -490,7 +499,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
# what the spec wants: see comment in the BasePresenceHandler version
# of this function.
await self.set_state(
- UserID.from_string(user_id), {"presence": presence_state}, True
+ UserID.from_string(user_id),
+ {"presence": presence_state},
+ ignore_status_msg=True,
)
curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
@@ -601,22 +612,13 @@ class WorkerPresenceHandler(BasePresenceHandler):
"""
presence = state["presence"]
- valid_presence = (
- PresenceState.ONLINE,
- PresenceState.UNAVAILABLE,
- PresenceState.OFFLINE,
- PresenceState.BUSY,
- )
-
- if presence not in valid_presence or (
- presence == PresenceState.BUSY and not self._busy_presence_enabled
- ):
+ if presence not in self.VALID_PRESENCE:
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
# If presence is disabled, no-op
- if not self.hs.config.server.use_presence:
+ if not self._presence_enabled:
return
# Proxy request to instance that writes presence
@@ -633,7 +635,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.server.use_presence:
+ if not self._presence_enabled:
return
# Proxy request to instance that writes presence
@@ -649,7 +651,6 @@ class PresenceHandler(BasePresenceHandler):
self.hs = hs
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
- self._presence_enabled = hs.config.server.use_presence
federation_registry = hs.get_federation_registry()
@@ -700,8 +701,6 @@ class PresenceHandler(BasePresenceHandler):
self._on_shutdown,
)
- self._next_serial = 1
-
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
self.user_to_num_current_syncs: Dict[str, int] = {}
@@ -723,21 +722,16 @@ class PresenceHandler(BasePresenceHandler):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
- def run_timeout_handler() -> Awaitable[None]:
- return run_as_background_process(
- "handle_presence_timeouts", self._handle_timeouts
- )
-
self.clock.call_later(
- 30, self.clock.looping_call, run_timeout_handler, 5000
+ 30, self.clock.looping_call, self._handle_timeouts, 5000
)
- def run_persister() -> Awaitable[None]:
- return run_as_background_process(
- "persist_presence_changes", self._persist_unpersisted_changes
- )
-
- self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
+ self.clock.call_later(
+ 60,
+ self.clock.looping_call,
+ self._persist_unpersisted_changes,
+ 60 * 1000,
+ )
LaterGauge(
"synapse_handlers_presence_wheel_timer_size",
@@ -783,6 +777,7 @@ class PresenceHandler(BasePresenceHandler):
)
logger.info("Finished _on_shutdown")
+ @wrap_as_background_process("persist_presence_changes")
async def _persist_unpersisted_changes(self) -> None:
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
@@ -898,6 +893,7 @@ class PresenceHandler(BasePresenceHandler):
states, [destination]
)
+ @wrap_as_background_process("handle_presence_timeouts")
async def _handle_timeouts(self) -> None:
"""Checks the presence of users that have timed out and updates as
appropriate.
@@ -955,7 +951,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.server.use_presence:
+ if not self._presence_enabled:
return
user_id = user.to_string()
@@ -990,56 +986,51 @@ class PresenceHandler(BasePresenceHandler):
client that is being used by a user.
presence_state: The presence state indicated in the sync request
"""
- # Override if it should affect the user's presence, if presence is
- # disabled.
- if not self.hs.config.server.use_presence:
- affect_presence = False
+ if not affect_presence or not self._presence_enabled:
+ return _NullContextManager()
- if affect_presence:
- curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
- self.user_to_num_current_syncs[user_id] = curr_sync + 1
+ curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
+ self.user_to_num_current_syncs[user_id] = curr_sync + 1
- prev_state = await self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
- # If they're busy then they don't stop being busy just by syncing,
- # so just update the last sync time.
- if prev_state.state != PresenceState.BUSY:
- # XXX: We set_state separately here and just update the last_active_ts above
- # This keeps the logic as similar as possible between the worker and single
- # process modes. Using set_state will actually cause last_active_ts to be
- # updated always, which is not what the spec calls for, but synapse has done
- # this for... forever, I think.
- await self.set_state(
- UserID.from_string(user_id), {"presence": presence_state}, True
- )
- # Retrieve the new state for the logic below. This should come from the
- # in-memory cache.
- prev_state = await self.current_state_for_user(user_id)
+ # If they're busy then they don't stop being busy just by syncing,
+ # so just update the last sync time.
+ if prev_state.state != PresenceState.BUSY:
+ # XXX: We set_state separately here and just update the last_active_ts above
+ # This keeps the logic as similar as possible between the worker and single
+ # process modes. Using set_state will actually cause last_active_ts to be
+ # updated always, which is not what the spec calls for, but synapse has done
+ # this for... forever, I think.
+ await self.set_state(
+ UserID.from_string(user_id),
+ {"presence": presence_state},
+ ignore_status_msg=True,
+ )
+ # Retrieve the new state for the logic below. This should come from the
+ # in-memory cache.
+ prev_state = await self.current_state_for_user(user_id)
- # To keep the single process behaviour consistent with worker mode, run the
- # same logic as `update_external_syncs_row`, even though it looks weird.
- if prev_state.state == PresenceState.OFFLINE:
- await self._update_states(
- [
- prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=self.clock.time_msec(),
- last_user_sync_ts=self.clock.time_msec(),
- )
- ]
- )
- # otherwise, set the new presence state & update the last sync time,
- # but don't update last_active_ts as this isn't an indication that
- # they've been active (even though it's probably been updated by
- # set_state above)
- else:
- await self._update_states(
- [
- prev_state.copy_and_replace(
- last_user_sync_ts=self.clock.time_msec()
- )
- ]
- )
+ # To keep the single process behaviour consistent with worker mode, run the
+ # same logic as `update_external_syncs_row`, even though it looks weird.
+ if prev_state.state == PresenceState.OFFLINE:
+ await self._update_states(
+ [
+ prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=self.clock.time_msec(),
+ last_user_sync_ts=self.clock.time_msec(),
+ )
+ ]
+ )
+ # otherwise, set the new presence state & update the last sync time,
+ # but don't update last_active_ts as this isn't an indication that
+ # they've been active (even though it's probably been updated by
+ # set_state above)
+ else:
+ await self._update_states(
+ [prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())]
+ )
async def _end() -> None:
try:
@@ -1061,8 +1052,7 @@ class PresenceHandler(BasePresenceHandler):
try:
yield
finally:
- if affect_presence:
- run_in_background(_end)
+ run_in_background(_end)
return _user_syncing()
@@ -1229,20 +1219,11 @@ class PresenceHandler(BasePresenceHandler):
status_msg = state.get("status_msg", None)
presence = state["presence"]
- valid_presence = (
- PresenceState.ONLINE,
- PresenceState.UNAVAILABLE,
- PresenceState.OFFLINE,
- PresenceState.BUSY,
- )
-
- if presence not in valid_presence or (
- presence == PresenceState.BUSY and not self._busy_presence_enabled
- ):
+ if presence not in self.VALID_PRESENCE:
raise SynapseError(400, "Invalid presence state")
# If presence is disabled, no-op
- if not self.hs.config.server.use_presence:
+ if not self._presence_enabled:
return
user_id = target_user.to_string()
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index bd8277e736..bb409f97b7 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -39,7 +39,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -629,7 +629,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async with self.member_linearizer.queue(key):
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
diff = self.clock.time_msec() - then
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 6083c9f4b5..d00035c332 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -74,12 +74,13 @@ class SamlHandler:
self.idp_id = "saml"
# user-facing name of this auth provider
- self.idp_name = "SAML"
+ self.idp_name = hs.config.saml2.idp_name
- # we do not currently support icons/brands for SAML auth, but this is required by
- # the SsoIdentityProvider protocol type.
- self.idp_icon = None
- self.idp_brand = None
+ # MXC URI for icon for this auth provider
+ self.idp_icon = hs.config.saml2.idp_icon
+
+ # optional brand identifier for this auth provider
+ self.idp_brand = hs.config.saml2.idp_brand
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 4d29328a74..e9a544e754 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -24,13 +24,14 @@ from typing import (
Iterable,
List,
Mapping,
+ NoReturn,
Optional,
Set,
)
from urllib.parse import urlencode
import attr
-from typing_extensions import NoReturn, Protocol
+from typing_extensions import Protocol
from twisted.web.iweb import IRequest
from twisted.web.server import Request
@@ -791,7 +792,7 @@ class SsoHandler:
if code != 200:
raise Exception(
- "GET request to download sso avatar image returned {}".format(code)
+ f"GET request to download sso avatar image returned {code}"
)
# upload name includes hash of the image file's content so that we can
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 7cabf7980a..3dde19fc81 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -14,9 +14,15 @@
# limitations under the License.
import logging
from collections import Counter
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
-
-from typing_extensions import Counter as CounterType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Counter as CounterType,
+ Dict,
+ Iterable,
+ Optional,
+ Tuple,
+)
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import event_processing_positions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c010405be6..60a9f341b5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -387,16 +387,16 @@ class SyncHandler:
from_token=since_token,
)
- # if nothing has happened in any of the users' rooms since /sync was called,
- # the resultant next_batch will be the same as since_token (since the result
- # is generated when wait_for_events is first called, and not regenerated
- # when wait_for_events times out).
- #
- # If that happens, we mustn't cache it, so that when the client comes back
- # with the same cache token, we don't immediately return the same empty
- # result, causing a tightloop. (#8518)
- if result.next_batch == since_token:
- cache_context.should_cache = False
+ # if nothing has happened in any of the users' rooms since /sync was called,
+ # the resultant next_batch will be the same as since_token (since the result
+ # is generated when wait_for_events is first called, and not regenerated
+ # when wait_for_events times out).
+ #
+ # If that happens, we mustn't cache it, so that when the client comes back
+ # with the same cache token, we don't immediately return the same empty
+ # result, causing a tightloop. (#8518)
+ if result.next_batch == since_token:
+ cache_context.should_cache = False
if result:
if sync_config.filter_collection.lazy_load_members():
@@ -1442,11 +1442,9 @@ class SyncHandler:
# Now we have our list of joined room IDs, exclude as configured and freeze
joined_room_ids = frozenset(
- (
- room_id
- for room_id in mutable_joined_room_ids
- if room_id not in mutable_rooms_to_exclude
- )
+ room_id
+ for room_id in mutable_joined_room_ids
+ if room_id not in mutable_rooms_to_exclude
)
logger.debug(
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 05197edc95..a0f5568000 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -94,6 +94,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.worker.should_update_user_directory
self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
+ self.show_locked_users = hs.config.userdirectory.show_locked_users
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self._hs = hs
@@ -144,7 +145,9 @@ class UserDirectoryHandler(StateDeltasHandler):
]
}
"""
- results = await self.store.search_user_dir(user_id, search_term, limit)
+ results = await self.store.search_user_dir(
+ user_id, search_term, limit, self.show_locked_users
+ )
# Remove any spammy users from the results.
non_spammy_users = []
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index 72df773a86..58efe7116b 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -42,7 +42,11 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+# This lock is used to avoid creating an event while we are purging the room.
+# We take a read lock when creating an event, and a write one when purging a room.
+# This is because it is fine to create several events concurrently, since referenced events
+# will not disappear under our feet as long as we don't delete the room.
+NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"
class WorkerLocksHandler:
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 5a61b21eaf..284fbac524 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -18,10 +18,9 @@ import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
-from typing import Callable, Optional
+from typing import Callable, Deque, Optional
import attr
-from typing_extensions import Deque
from zope.interface import implementer
from twisted.application.internet import ClientService
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index acee1dafd3..9ad8e038ae 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -31,7 +31,7 @@ from typing import (
import attr
import jinja2
-from typing_extensions import ParamSpec
+from typing_extensions import Concatenate, ParamSpec
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
@@ -885,7 +885,7 @@ class ModuleApi:
def run_db_interaction(
self,
desc: str,
- func: Callable[P, T],
+ func: Callable[Concatenate[LoggingTransaction, P], T],
*args: P.args,
**kwargs: P.kwargs,
) -> "defer.Deferred[T]":
diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py
index e191450323..32db7cce8d 100644
--- a/synapse/module_api/callbacks/spamchecker_callbacks.py
+++ b/synapse/module_api/callbacks/spamchecker_callbacks.py
@@ -426,9 +426,7 @@ class SpamCheckerModuleApiCallbacks:
generally discouraged as it doesn't support internationalization.
"""
for callback in self._check_event_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(event))
if res is False or res == self.NOT_SPAM:
# This spam-checker accepts the event.
@@ -481,9 +479,7 @@ class SpamCheckerModuleApiCallbacks:
True if the event should be silently dropped
"""
for callback in self._should_drop_federated_event_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res: Union[bool, str] = await delay_cancellation(callback(event))
if res:
return res
@@ -505,9 +501,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
"""
for callback in self._user_may_join_room_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(user_id, room_id, is_invited))
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is True or res is self.NOT_SPAM:
@@ -546,9 +540,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise.
"""
for callback in self._user_may_invite_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(inviter_userid, invitee_userid, room_id)
)
@@ -593,9 +585,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise.
"""
for callback in self._user_may_send_3pid_invite_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(inviter_userid, medium, address, room_id)
)
@@ -630,9 +620,7 @@ class SpamCheckerModuleApiCallbacks:
userid: The ID of the user attempting to create a room
"""
for callback in self._user_may_create_room_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid))
if res is True or res is self.NOT_SPAM:
continue
@@ -666,9 +654,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._user_may_create_room_alias_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid, room_alias))
if res is True or res is self.NOT_SPAM:
continue
@@ -701,9 +687,7 @@ class SpamCheckerModuleApiCallbacks:
room_id: The ID of the room that would be published
"""
for callback in self._user_may_publish_room_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid, room_id))
if res is True or res is self.NOT_SPAM:
continue
@@ -742,9 +726,7 @@ class SpamCheckerModuleApiCallbacks:
True if the user is spammy.
"""
for callback in self._check_username_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
# Make a copy of the user profile object to ensure the spam checker cannot
# modify it.
res = await delay_cancellation(callback(user_profile.copy()))
@@ -776,9 +758,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_registration_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
behaviour = await delay_cancellation(
callback(email_threepid, username, request_info, auth_provider_id)
)
@@ -820,9 +800,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_media_file_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(file_wrapper, file_info))
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is False or res is self.NOT_SPAM:
@@ -869,9 +847,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_login_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(
user_id,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a2cabba7b1..38adcbe1d0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Any,
Awaitable,
+ Deque,
Dict,
Iterable,
Iterator,
@@ -29,7 +30,6 @@ from typing import (
)
from prometheus_client import Counter
-from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index e0257daa75..04d9ef25b7 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -280,6 +280,17 @@ class UserRestServletV2(RestServlet):
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
)
+ lock = body.get("locked", False)
+ if not isinstance(lock, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "'locked' parameter is not of type boolean"
+ )
+
+ if deactivate and lock:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "An user can't be deactivated and locked"
+ )
+
approved: Optional[bool] = None
if "approved" in body and self._msc3866_enabled:
approved = body["approved"]
@@ -397,6 +408,12 @@ class UserRestServletV2(RestServlet):
target_user.to_string()
)
+ if "locked" in body:
+ if lock and not user["locked"]:
+ await self.store.set_user_locked_status(user_id, True)
+ elif not lock and user["locked"]:
+ await self.store.set_user_locked_status(user_id, False)
+
if "user_type" in body:
await self.store.set_user_type(target_user, user_type)
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 51f17f80da..925f037743 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -29,7 +29,6 @@ from synapse.http.servlet import (
parse_integer,
)
from synapse.http.site import SynapseRequest
-from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel
@@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = handler
- if hs.config.worker.worker_app is None:
- # if main process
- self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
- else:
- # then a worker
- self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
-
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet):
"Device key(s) not found, these must be provided.",
)
- # TODO: Those two operations, creating a device and storing the
- # device's keys should be atomic.
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
submission.device_id,
submission.device_data.dict(),
submission.initial_device_display_name,
- )
-
- # TODO: Do we need to do something with the result here?
- await self.key_uploader(
- user_id=user_id, device_id=submission.device_id, keys=submission.dict()
+ device_info,
)
return 200, {"device_id": device_id}
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 94ad90942f..2e104d4888 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -40,7 +40,9 @@ class LogoutRestServlet(RestServlet):
self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_expired=True)
+ requester = await self.auth.get_user_by_req(
+ request, allow_expired=True, allow_locked=True
+ )
if requester.device_id is None:
# The access token wasn't associated with a device.
@@ -67,7 +69,9 @@ class LogoutAllRestServlet(RestServlet):
self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_expired=True)
+ requester = await self.auth.get_user_by_req(
+ request, allow_expired=True, allow_locked=True
+ )
user_id = requester.user.to_string()
# first delete all of the user's devices
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 5c9fece3ba..5ed3b83a03 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -32,6 +32,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.types import JsonDict
+from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -53,26 +54,32 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
self._push_rules_handler = hs.get_push_rules_handler()
+ self._push_rule_linearizer = Linearizer(name="push_rules")
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ async with self._push_rule_linearizer.queue(user_id):
+ return await self.handle_put(request, path, user_id)
+
+ async def handle_put(
+ self, request: SynapseRequest, path: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
spec = _rule_spec_from_path(path.split("/"))
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
- requester = await self.auth.get_user_by_req(request)
-
if "/" in spec.rule_id or "\\" in spec.rule_id:
raise SynapseError(400, "rule_id may not contain slashes")
content = parse_json_value_from_request(request)
- user_id = requester.user.to_string()
-
if spec.attr:
try:
await self._push_rules_handler.set_rule_attr(user_id, spec, content)
@@ -126,11 +133,20 @@ class PushRuleRestServlet(RestServlet):
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
- spec = _rule_spec_from_path(path.split("/"))
-
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
+ async with self._push_rule_linearizer.queue(user_id):
+ return await self.handle_delete(request, path, user_id)
+
+ async def handle_delete(
+ self,
+ request: SynapseRequest,
+ path: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
+ spec = _rule_spec_from_path(path.split("/"))
+
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
try:
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 4a5d9e13e7..b1f6b5d1b7 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -81,7 +81,7 @@ class RoomUpgradeRestServlet(RestServlet):
try:
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 8f3865d412..981fd1f58a 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from signedjson.sign import sign_json
@@ -27,6 +27,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
)
+from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
) -> JsonDict:
logger.info("Handling query for keys %r", query)
- store_queries = []
+ server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items():
- if not key_ids:
- key_ids = (None,)
- for key_id in key_ids:
- store_queries.append((server_name, key_id, None))
+ if key_ids:
+ results: Mapping[
+ str, Optional[FetchKeyResultForRemote]
+ ] = await self.store.get_server_keys_json_for_remote(
+ server_name, key_ids
+ )
+ else:
+ results = await self.store.get_all_server_keys_json_for_remote(
+ server_name
+ )
- cached = await self.store.get_server_keys_json_for_remote(store_queries)
+ server_keys.update(
+ ((server_name, key_id), res) for key_id, res in results.items()
+ )
json_results: Set[bytes] = set()
@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
# Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
- for (server_name, key_id, _), key_results in cached.items():
- results = [(result["ts_added_ms"], result) for result in key_results]
-
- if key_id is None:
+ for (server_name, key_id), key_result in server_keys.items():
+ if not query[server_name]:
# all keys were requested. Just return what we have without worrying
# about validity
- for _, result in results:
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(result["key_json"]))
+ if key_result:
+ json_results.add(key_result.key_json)
continue
miss = False
- if not results:
+ if key_result is None:
miss = True
else:
- ts_added_ms, most_recent_result = max(results)
- ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+ ts_added_ms = key_result.added_ts
+ ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
ts_valid_until_ms,
time_now_ms,
)
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(most_recent_result["key_json"]))
+
+ json_results.add(key_result.key_json)
if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 2d5ddc3e7b..ddca0af1da 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -238,6 +238,7 @@ class BackgroundUpdater:
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
+ self.hs = hs
self._database_name = database.name()
@@ -758,6 +759,11 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
+ # override the global statement timeout to avoid accidentally squashing
+ # a long-running index creation process
+ timeout_sql = "SET SESSION statement_timeout = 0"
+ c.execute(timeout_sql)
+
sql = (
"CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
" ON %(table)s"
@@ -778,6 +784,12 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
finally:
+ # mypy ignore - `statement_timeout` is defined on PostgresEngine
+ # reset the global timeout to the default
+ default_timeout = self.db_pool.engine.statement_timeout # type: ignore[attr-defined]
+ undo_timeout_sql = f"SET statement_timeout = {default_timeout}"
+ conn.cursor().execute(undo_timeout_sql)
+
conn.set_session(autocommit=False) # type: ignore
def create_index_sqlite(conn: Connection) -> None:
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 35cd1089d6..abd1d149db 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,7 +45,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.opentracing import (
SynapseTags,
@@ -357,7 +357,7 @@ class EventsPersistenceStorageController:
# it. We might already have taken out the lock, but since this is just a
# "read" lock its inherently reentrant.
async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
if isinstance(task, _PersistEventsTask):
return await self._persist_event_batch(room_id, task)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
cast,
)
+from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
def _store_dehydrated_device_txn(
- self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ device_data: str,
+ time: int,
+ keys: Optional[JsonDict] = None,
) -> Optional[str]:
+ # TODO: make keys non-optional once support for msc2697 is dropped
+ if keys:
+ device_keys = keys.get("device_keys", None)
+ if device_keys:
+ # Type ignore - this function is defined on EndToEndKeyStore which we do
+ # have access to due to hs.get_datastore() "magic"
+ self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
+ txn, user_id, device_id, time, device_keys
+ )
+
+ one_time_keys = keys.get("one_time_keys", None)
+ if one_time_keys:
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append(
+ (
+ algorithm,
+ key_id,
+ encode_canonical_json(key_obj).decode("ascii"),
+ )
+ )
+ self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+ fallback_keys = keys.get("fallback_keys", None)
+ if fallback_keys:
+ self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)
+
return old_device_id
async def store_dehydrated_device(
- self, user_id: str, device_id: str, device_data: JsonDict
+ self,
+ user_id: str,
+ device_id: str,
+ device_data: JsonDict,
+ time_now: int,
+ keys: Optional[dict] = None,
) -> Optional[str]:
"""Store a dehydrated device for a user.
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device
device_data: the dehydrated device information
+ time_now: current time at the request in milliseconds
+ keys: keys for the dehydrated device
+
Returns:
device id of the user's previous dehydrated device, if any
"""
+
return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
json_encoder.encode(device_data),
+ time_now,
+ keys,
)
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 91ae9c457d..b49dea577c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
- def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("new_keys", str(new_keys))
- # We are protected from race between lookup and insertion due to
- # a unique constraint. If there is a race of two calls to
- # `add_e2e_one_time_keys` then they'll conflict and we will only
- # insert one set.
- self.db_pool.simple_insert_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- keys=(
- "user_id",
- "device_id",
- "algorithm",
- "key_id",
- "ts_added_ms",
- "key_json",
- ),
- values=[
- (user_id, device_id, algorithm, key_id, time_now, json_bytes)
- for algorithm, key_id, json_bytes in new_keys
- ],
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
await self.db_pool.runInteraction(
- "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+ "add_e2e_one_time_keys_insert",
+ self._add_e2e_one_time_keys_txn,
+ user_id,
+ device_id,
+ time_now,
+ new_keys,
+ )
+
+ def _add_e2e_one_time_keys_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ new_keys: Iterable[Tuple[str, str, str]],
+ ) -> None:
+ """Insert some new one time keys for a device. Errors if any of the keys already exist.
+
+ Args:
+ user_id: id of user to get keys for
+ device_id: id of device to get keys for
+ time_now: insertion time to record (ms since epoch)
+ new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
+ that the key JSON must be in canonical JSON form
+ """
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("new_keys", str(new_keys))
+ # We are protected from race between lookup and insertion due to
+ # a unique constraint. If there is a race of two calls to
+ # `add_e2e_one_time_keys` then they'll conflict and we will only
+ # insert one set.
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keys=(
+ "user_id",
+ "device_id",
+ "algorithm",
+ "key_id",
+ "ts_added_ms",
+ "key_json",
+ ),
+ values=[
+ (user_id, device_id, algorithm, key_id, time_now, json_bytes)
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
@cached(max_entries=10000)
@@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_id: str,
fallback_keys: JsonDict,
) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
@@ -1304,42 +1333,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
+
+ Args:
+ user_id: user_id of the user to store keys for
+ device_id: device_id of the device to store keys for
+ time_now: time at the request to store the keys
+ device_keys: the keys to store
"""
- def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("time_now", time_now)
- set_tag("device_keys", str(device_keys))
+ return await self.db_pool.runInteraction(
+ "set_e2e_device_keys",
+ self._set_e2e_device_keys_txn,
+ user_id,
+ device_id,
+ time_now,
+ device_keys,
+ )
- old_key_json = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcol="key_json",
- allow_none=True,
- )
+ def _set_e2e_device_keys_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ device_keys: JsonDict,
+ ) -> bool:
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+ Args:
+ user_id: user_id of the user to store keys for
+ device_id: device_id of the device to store keys for
+ time_now: time at the request to store the keys
+ device_keys: the keys to store
+ """
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("time_now", time_now)
+ set_tag("device_keys", str(device_keys))
+
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="key_json",
+ allow_none=True,
+ )
- if old_key_json == new_key_json:
- log_kv({"Message": "Device key already stored."})
- return False
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
- self.db_pool.simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- values={"ts_added_ms": time_now, "key_json": new_key_json},
- )
- log_kv({"message": "Device keys stored."})
- return True
+ if old_key_json == new_key_json:
+ log_kv({"Message": "Device key already stored."})
+ return False
- return await self.db_pool.runInteraction(
- "set_e2e_device_keys", _set_e2e_device_keys_txn
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"ts_added_ms": time_now, "key_json": new_key_json},
)
+ log_kv({"message": "Device keys stored."})
+ return True
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index fff417f9e3..047de6283a 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
-from typing_extensions import TYPE_CHECKING
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 1666e3c43b..a3b4744855 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
import itertools
import json
import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""
@cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- self._get_server_keys_json.invalidate((((server_name, key_id),)))
+ await self.invalidate_cache_and_stream(
+ "_get_server_keys_json", ((server_name, key_id),)
+ )
+ await self.invalidate_cache_and_stream(
+ "get_server_key_json_for_remote", (server_name, key_id)
+ )
@cached()
def _get_server_keys_json(
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
- async def get_server_keys_json_for_remote(
- self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- """Retrieve the key json for a list of server_keys and key ids.
- If no keys are found for a given server, key_id and source then
- that server, key_id, and source triplet entry will be an empty list.
- The JSON is returned as a byte array so that it can be efficiently
- used in an HTTP response.
+ @cached()
+ def get_server_key_json_for_remote(
+ self,
+ server_name: str,
+ key_id: str,
+ ) -> Optional[FetchKeyResultForRemote]:
+ raise NotImplementedError()
- Args:
- server_keys: List of (server_name, key_id, source) triplets.
+ @cachedList(
+ cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+ )
+ async def get_server_keys_json_for_remote(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+ """Fetch the cached keys for the given server/key IDs.
- Returns:
- A mapping from (server_name, key_id, source) triplets to a list of dicts
+ If we have multiple entries for a given key ID, returns the most recent.
"""
+ rows = await self.db_pool.simple_select_many_batch(
+ table="server_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
+ )
- def _get_server_keys_json_txn(
- txn: LoggingTransaction,
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- results = {}
- for server_name, key_id, from_server in server_keys:
- keyvalues = {"server_name": server_name}
- if key_id is not None:
- keyvalues["key_id"] = key_id
- if from_server is not None:
- keyvalues["from_server"] = from_server
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "server_keys_json",
- keyvalues=keyvalues,
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
- ),
- )
- results[(server_name, key_id, from_server)] = rows
- return results
+ if not rows:
+ return {}
+
+ # We sort the rows so that the most recently added entry is picked up.
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
- return await self.db_pool.runInteraction(
- "get_server_keys_json", _get_server_keys_json_txn
+ async def get_all_server_keys_json_for_remote(
+ self,
+ server_name: str,
+ ) -> Dict[str, FetchKeyResultForRemote]:
+ """Fetch the cached keys for the given server.
+
+ If we have multiple entries for a given key ID, returns the most recent.
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
)
+
+ if not rows:
+ return {}
+
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 1680bf6168..54d40e7a3a 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -26,7 +26,6 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.engines import PostgresEngine
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -96,6 +95,10 @@ class LockStore(SQLBaseStore):
self._acquiring_locks: Set[Tuple[str, str]] = set()
+ self._clock.looping_call(
+ self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0
+ )
+
@wrap_as_background_process("LockStore._on_shutdown")
async def _on_shutdown(self) -> None:
"""Called when the server is shutting down"""
@@ -216,6 +219,7 @@ class LockStore(SQLBaseStore):
lock_name,
lock_key,
write,
+ db_autocommit=True,
)
except self.database_engine.module.IntegrityError:
return None
@@ -233,61 +237,22 @@ class LockStore(SQLBaseStore):
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
- #
- # Before that though we clear the table of any stale locks.
now = self._clock.time_msec()
token = random_string(6)
- delete_sql = """
- DELETE FROM worker_read_write_locks
- WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
- """
-
- insert_sql = """
- INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
- VALUES (?, ?, ?, ?, ?, ?)
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # For Postgres we can send these queries at the same time.
- txn.execute(
- delete_sql + ";" + insert_sql,
- (
- # DELETE args
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- # UPSERT args
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
- else:
- # For SQLite these need to be two queries.
- txn.execute(
- delete_sql,
- (
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- ),
- )
- txn.execute(
- insert_sql,
- (
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="worker_read_write_locks",
+ values={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "write_lock": write,
+ "instance_name": self._instance_name,
+ "token": token,
+ "last_renewed_ts": now,
+ },
+ )
lock = Lock(
self._reactor,
@@ -351,6 +316,24 @@ class LockStore(SQLBaseStore):
return locks
+ @wrap_as_background_process("_reap_stale_read_write_locks")
+ async def _reap_stale_read_write_locks(self) -> None:
+ delete_sql = """
+ DELETE FROM worker_read_write_locks
+ WHERE last_renewed_ts < ?
+ """
+
+ def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None:
+ txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,))
+ if txn.rowcount:
+ logger.info("Reaped %d stale locks", txn.rowcount)
+
+ await self.db_pool.runInteraction(
+ "_reap_stale_read_write_locks",
+ reap_stale_read_write_locks_txn,
+ db_autocommit=True,
+ )
+
class Lock:
"""An async context manager that manages an acquired lock, ensuring it is
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c582cf0573..d3a01d526f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -205,7 +205,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
name, password_hash, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
- COALESCE(approved, TRUE) AS approved
+ COALESCE(approved, TRUE) AS approved,
+ COALESCE(locked, FALSE) AS locked
FROM users
WHERE name = ?
""",
@@ -230,10 +231,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
- boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
+ boolean_columns = [
+ "admin",
+ "deactivated",
+ "shadow_banned",
+ "approved",
+ "locked",
+ ]
for column in boolean_columns:
- if not isinstance(row[column], bool):
- row[column] = bool(row[column])
+ row[column] = bool(row[column])
return row
@@ -1116,6 +1122,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Convert the integer into a boolean.
return res == 1
+ @cached()
+ async def get_user_locked_status(self, user_id: str) -> bool:
+ """Retrieve the value for the `locked` property for the provided user.
+
+ Args:
+ user_id: The ID of the user to retrieve the status for.
+
+ Returns:
+ True if the user was locked, false if the user is still active.
+ """
+
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="locked",
+ desc="get_user_locked_status",
+ )
+
+ # Convert the potential integer into a boolean.
+ return bool(res)
+
async def get_threepid_validation_session(
self,
medium: Optional[str],
@@ -2111,6 +2138,33 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
+ async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
+ """Set the `locked` property for the provided user to the provided value.
+
+ Args:
+ user_id: The ID of the user to set the status for.
+ locked: The value to set for `locked`.
+ """
+
+ await self.db_pool.runInteraction(
+ "set_user_locked_status",
+ self.set_user_locked_status_txn,
+ user_id,
+ locked,
+ )
+
+ def set_user_locked_status_txn(
+ self, txn: LoggingTransaction, user_id: str, locked: bool
+ ) -> None:
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"locked": locked},
+ )
+ self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,))
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
def update_user_approval_status_txn(
self, txn: LoggingTransaction, user_id: str, approved: bool
) -> None:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index f34b7ce8f4..6298f0984d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -19,6 +19,7 @@ from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
+ Counter,
Dict,
Iterable,
List,
@@ -28,8 +29,6 @@ from typing import (
cast,
)
-from typing_extensions import Counter
-
from twisted.internet.defer import DeferredLock
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 2a136f2ff6..f0dc31fee6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -995,7 +995,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
async def search_user_dir(
- self, user_id: str, search_term: str, limit: int
+ self,
+ user_id: str,
+ search_term: str,
+ limit: int,
+ show_locked_users: bool = False,
) -> SearchResult:
"""Searches for users in directory
@@ -1029,6 +1033,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
"""
+ if not show_locked_users:
+ where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)"
+
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
@@ -1060,6 +1067,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SELECT d.user_id AS user_id, display_name, avatar_url
FROM matching_users as t
INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users AS u ON t.user_id = u.name
WHERE
%(where_clause)s
ORDER BY
@@ -1115,6 +1123,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users AS u ON t.user_id = u.name
WHERE
%(where_clause)s
AND value MATCH ?
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 0363cdc038..0b5b3bf03e 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -145,5 +145,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
This is not provided by DBAPI2, and so needs engine-specific support.
"""
- with open(filepath, "rt") as f:
+ with open(filepath) as f:
cls.executescript(cursor, f.read())
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 71584f3f74..e74b2269d2 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
class FetchKeyResult:
verify_key: VerifyKey # the key itself
valid_until_ts: int # how long we can use this key for
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FetchKeyResultForRemote:
+ key_json: bytes # the full key JSON
+ valid_until_ts: int # how long we can use this key for, in milliseconds.
+ added_ts: int # When we added this key, in milliseconds.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 38b7abd801..31501fd573 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -16,10 +16,18 @@ import logging
import os
import re
from collections import Counter
-from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple
+from typing import (
+ Collection,
+ Counter as CounterType,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ TextIO,
+ Tuple,
+)
import attr
-from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction
diff --git a/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql
new file mode 100644
index 0000000000..21c7971441
--- /dev/null
+++ b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql
@@ -0,0 +1,16 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE users ADD locked BOOLEAN DEFAULT FALSE NOT NULL;
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 39a1ae4ac3..073f682aca 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -21,6 +21,7 @@ from typing import (
Any,
ClassVar,
Dict,
+ Final,
List,
Mapping,
Match,
@@ -38,7 +39,7 @@ import attr
from immutabledict import immutabledict
from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
-from typing_extensions import Final, TypedDict
+from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 4041e49e71..943ad54456 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -22,6 +22,7 @@ import logging
from contextlib import asynccontextmanager
from typing import (
Any,
+ AsyncContextManager,
AsyncIterator,
Awaitable,
Callable,
@@ -42,7 +43,7 @@ from typing import (
)
import attr
-from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
+from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.internet import defer
from twisted.internet.defer import CancelledError
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index 644c341e8c..db6c40a3e1 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -218,7 +218,7 @@ class MacaroonGenerator:
# to avoid validating those as guest tokens, we explicitely verify if
# the macaroon includes the "guest = true" caveat.
is_guest = any(
- (caveat.caveat_id == "guest = true" for caveat in macaroon.caveats)
+ caveat.caveat_id == "guest = true" for caveat in macaroon.caveats
)
if not is_guest:
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 48b8195ca1..8cb766860e 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -98,7 +98,9 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
SynapseManhole, dict(globals, __name__="__console__")
)
- factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
+ # type-ignore: This is an error in Twisted's annotations. See
+ # https://github.com/twisted/twisted/issues/11812 and /11813 .
+ factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) # type: ignore[arg-type]
# conch has the wrong type on these dicts (says bytes to bytes,
# should be bytes to Keys judging by how it's used).
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 2ad55ac13e..cde4a0780f 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -20,6 +20,7 @@ import typing
from typing import (
Any,
Callable,
+ ContextManager,
DefaultDict,
Dict,
Iterator,
@@ -33,7 +34,6 @@ from typing import (
from weakref import WeakSet
from prometheus_client.core import Counter
-from typing_extensions import ContextManager
from twisted.internet import defer
diff --git a/synapse/visibility.py b/synapse/visibility.py
index fc71dc92a4..eac10f6438 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -17,6 +17,7 @@ from enum import Enum, auto
from typing import (
Collection,
Dict,
+ Final,
FrozenSet,
List,
Mapping,
@@ -27,7 +28,6 @@ from typing import (
)
import attr
-from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cdb0048122..ce96574915 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -69,6 +69,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.store.get_user_by_access_token = simple_async_mock(user_info)
self.store.mark_access_token_as_used = simple_async_mock(None)
+ self.store.get_user_locked_status = simple_async_mock(False)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@@ -293,6 +294,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
+ self.store.get_user_locked_status = simple_async_mock(False)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -311,6 +313,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
token_used=True,
)
)
+ self.store.get_user_locked_status = simple_async_mock(False)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index 9305b758d7..93af614def 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -26,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
- hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
+ hs = super().make_homeserver(reactor, clock)
# We don't want our tests to actually report statistics, so check
# that it's not enabled
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 7c63b2ea4c..2be341ac7b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -312,7 +312,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
[("server9", get_key_id(key1))]
)
result = self.get_success(d)
- self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)
+ self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], SERVER_NAME)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
# we expect it to be encoded as canonical json *before* it hits the db
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 647ee09279..e1e58fa6e6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")
- # Fetch the message of the dehydrated device again, which should return nothing
- # and delete the old messages
+ # Fetch the message of the dehydrated device again, which should return
+ # the same message as it has not been deleted
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
- since_token=res["next_batch"],
+ since_token=None,
limit=10,
)
)
self.assertTrue(len(res["next_batch"]) > 1)
- self.assertEqual(len(res["events"]), 0)
+ self.assertEqual(len(res["events"]), 1)
+ self.assertEqual(res["events"][0]["content"]["body"], "foo")
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 6309d7b36e..82c26e303f 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -491,6 +491,68 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
+ def test_introspection_token_cache(self) -> None:
+ access_token = "open_sesame"
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={"active": "true", "scope": "guest", "jti": access_token},
+ )
+ )
+
+ # first call should cache response
+ # Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code
+ # for regular auth code via the config
+ self.get_success(
+ self.auth._introspect_token(access_token) # type: ignore[attr-defined]
+ )
+ introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined]
+ self.assertEqual(introspection_token["jti"], access_token)
+ # there's been one http request
+ self.http_client.request.assert_called_once()
+
+ # second call should pull from cache, there should still be only one http request
+ token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
+ self.http_client.request.assert_called_once()
+ self.assertEqual(token["jti"], access_token)
+
+ # advance past five minutes and check that cache expired - there should be more than one http call now
+ self.reactor.advance(360)
+ token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
+ self.assertEqual(self.http_client.request.call_count, 2)
+ self.assertEqual(token_2["jti"], access_token)
+
+ # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
+ # token with a soon-to-expire `exp` field to the cache
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": "true",
+ "scope": "guest",
+ "jti": "stale",
+ "exp": self.clock.time() + 100,
+ },
+ )
+ )
+ self.get_success(
+ self.auth._introspect_token("stale") # type: ignore[attr-defined]
+ )
+ introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined]
+ self.assertEqual(introspection_token["jti"], "stale")
+ self.assertEqual(self.http_client.request.call_count, 1)
+
+ # advance the reactor past the token expiry but less than the cache expiry
+ self.reactor.advance(120)
+ self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined]
+
+ # check that the next call causes another http request (which will fail because the token is technically expired
+ # but the important thing is we discard the token from the cache and try the network)
+ self.get_failure(
+ self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined]
+ )
+ self.assertEqual(self.http_client.request.call_count, 2)
+
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index aed2a4c07a..6a0b5fc0bd 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -514,7 +514,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(response.code, 200)
# Send the body
- request.write('{ "a": 1 }'.encode("ascii"))
+ request.write(b'{ "a": 1 }')
request.finish()
self.reactor.pump((0.1,))
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index b3310abe1b..fe631d7ecb 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -757,7 +757,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(channel.json_body["creator"], user_id)
# Check room alias.
- self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}")
+ self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}")
# Let's try a room with no alias.
room_id, room_alias = self.get_success(
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 1527b4a82d..6e78daa830 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -116,7 +116,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(
request.path,
- f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"),
+ f"/_matrix/media/r0/download/{target}/{media_id}".encode(),
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9af9db6e3e..41a959b4d6 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -29,7 +29,16 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.media.filepath import MediaFilePaths
-from synapse.rest.client import devices, login, logout, profile, register, room, sync
+from synapse.rest.client import (
+ devices,
+ login,
+ logout,
+ profile,
+ register,
+ room,
+ sync,
+ user_directory,
+)
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -1477,6 +1486,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
sync.register_servlets,
register.register_servlets,
+ user_directory.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -2464,6 +2474,105 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# This key was removed intentionally. Ensure it is not accidentally re-included.
self.assertNotIn("password_hash", channel.json_body)
+ def test_locked_user(self) -> None:
+ # User can sync
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Lock user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"locked": True},
+ )
+
+ # User is not authorized to sync anymore
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.USER_LOCKED, channel.json_body["errcode"])
+ self.assertTrue(channel.json_body["soft_logout"])
+
+ @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+ def test_locked_user_not_in_user_dir(self) -> None:
+ # User is available in the user dir
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/user_directory/search",
+ {"search_term": self.other_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("results", channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Lock user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"locked": True},
+ )
+
+ # User is not available anymore in the user dir
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/user_directory/search",
+ {"search_term": self.other_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("results", channel.json_body)
+ self.assertEqual(0, len(channel.json_body["results"]))
+
+ @override_config(
+ {
+ "user_directory": {
+ "enabled": True,
+ "search_all_users": True,
+ "show_locked_users": True,
+ }
+ }
+ )
+ def test_locked_user_in_user_dir_with_show_locked_users_option(self) -> None:
+ # User is available in the user dir
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/user_directory/search",
+ {"search_term": self.other_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("results", channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Lock user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"locked": True},
+ )
+
+ # User is still available in the user dir
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/user_directory/search",
+ {"search_term": self.other_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("results", channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_change_name_deactivate_user_user_directory(self) -> None:
"""
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 3cf29c10ea..60099f8c59 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, keys, login, register
from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
"<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
},
},
+ "fallback_keys": {
+ "alg1:device1": "f4llb4ckk3y",
+ "signed_<algorithm>:<device_id>": {
+ "fallback": "true",
+ "key": "f4llb4ckk3y",
+ "signatures": {
+ "<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
+ },
+ },
+ },
+ "one_time_keys": {"alg1:k1": "0net1m3k3y"},
}
channel = self.make_request(
"PUT",
@@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
}
self.assertEqual(device_data, expected_device_data)
+ # test that the keys are correctly uploaded
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ user: ["device1"],
+ },
+ },
+ token,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["device_keys"][user][device_id]["keys"],
+ content["device_keys"]["keys"],
+ )
+ # first claim should return the onetime key we uploaded
+ res = self.get_success(
+ self.hs.get_e2e_keys_handler().claim_one_time_keys(
+ {user: {device_id: {"alg1": 1}}},
+ UserID.from_string(user),
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res,
+ {
+ "failures": {},
+ "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
+ },
+ )
+ # second claim should return fallback key
+ res2 = self.get_success(
+ self.hs.get_e2e_keys_handler().claim_one_time_keys(
+ {user: {device_id: {"alg1": 1}}},
+ UserID.from_string(user),
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res2,
+ {
+ "failures": {},
+ "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
+ },
+ )
+
# create another device for the user
(
new_device_id,
@@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
expected_content = {"body": "test_message"}
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+
+ # fetch messages again and make sure that the message was not deleted
+ channel = self.make_request(
+ "POST",
+ f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+ content={},
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
next_batch_token = channel.json_body.get("next_batch")
- # fetch messages again and make sure that the message was deleted and we are returned an
- # empty array
+ # make sure fetching messages with next batch token works - there are no unfetched
+ # messages so we should receive an empty array
content = {"next_batch": next_batch_token}
channel = self.make_request(
"POST",
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 180b635ea6..4e0a387bd3 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -627,8 +627,8 @@ class RedactionsTestCase(HomeserverTestCase):
redact_event = timeline[-1]
self.assertEqual(redact_event["type"], EventTypes.Redaction)
# The redacts key should be in the content and the redacts keys.
- self.assertEquals(redact_event["content"]["redacts"], event_id)
- self.assertEquals(redact_event["redacts"], event_id)
+ self.assertEqual(redact_event["content"]["redacts"], event_id)
+ self.assertEqual(redact_event["redacts"], event_id)
# But it isn't actually part of the event.
def get_event(txn: LoggingTransaction) -> JsonDict:
@@ -642,10 +642,10 @@ class RedactionsTestCase(HomeserverTestCase):
event_json = self.get_success(
main_datastore.db_pool.runInteraction("get_event", get_event)
)
- self.assertEquals(event_json["type"], EventTypes.Redaction)
+ self.assertEqual(event_json["type"], EventTypes.Redaction)
if expect_content:
self.assertNotIn("redacts", event_json)
- self.assertEquals(event_json["content"]["redacts"], event_id)
+ self.assertEqual(event_json["content"]["redacts"], event_id)
else:
- self.assertEquals(event_json["redacts"], event_id)
+ self.assertEqual(event_json["redacts"], event_id)
self.assertNotIn("redacts", event_json["content"])
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 75439416c1..9bfe913e45 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -129,7 +129,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
return [ev["event_id"] for ev in channel.json_body["chunk"]]
def _get_bundled_aggregations(self) -> JsonDict:
@@ -142,7 +142,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
return channel.json_body["unsigned"].get("m.relations", {})
def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
@@ -1602,7 +1602,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
threads = channel.json_body["chunk"]
return [
(
@@ -1634,7 +1634,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
##################################################
# Check the test data is configured as expected. #
##################################################
- self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
+ self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations()
self.assertDictContainsSubset(
{"count": 3, "current_user_participated": True},
@@ -1655,7 +1655,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(thread_replies.pop())
# The thread should still exist, but the latest event should be updated.
- self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
+ self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations()
self.assertDictContainsSubset(
{"count": 2, "current_user_participated": True},
@@ -1674,7 +1674,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(thread_replies.pop(0))
# Nothing should have changed (except the thread count).
- self.assertEquals(self._get_related_events(), thread_replies)
+ self.assertEqual(self._get_related_events(), thread_replies)
relations = self._get_bundled_aggregations()
self.assertDictContainsSubset(
{"count": 1, "current_user_participated": True},
@@ -1691,11 +1691,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# Redact the last remaining event. #
####################################
self._redact(thread_replies.pop(0))
- self.assertEquals(thread_replies, [])
+ self.assertEqual(thread_replies, [])
# The event should no longer be considered a thread.
- self.assertEquals(self._get_related_events(), [])
- self.assertEquals(self._get_bundled_aggregations(), {})
+ self.assertEqual(self._get_related_events(), [])
+ self.assertEqual(self._get_bundled_aggregations(), {})
self.assertEqual(self._get_threads(), [])
def test_redact_parent_edit(self) -> None:
@@ -1749,8 +1749,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# The relations are returned.
event_ids = self._get_related_events()
relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [related_event_id])
- self.assertEquals(
+ self.assertEqual(event_ids, [related_event_id])
+ self.assertEqual(
relations[RelationTypes.REFERENCE],
{"chunk": [{"event_id": related_event_id}]},
)
@@ -1772,7 +1772,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# The unredacted relation should still exist.
event_ids = self._get_related_events()
relations = self._get_bundled_aggregations()
- self.assertEquals(len(event_ids), 1)
+ self.assertEqual(len(event_ids), 1)
self.assertDictContainsSubset(
{
"count": 1,
@@ -1816,7 +1816,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
threads = self._get_threads(channel.json_body)
self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)])
@@ -1829,7 +1829,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
# Tuple of (thread ID, latest event ID) for each thread.
threads = self._get_threads(channel.json_body)
self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)])
@@ -1850,7 +1850,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2])
@@ -1864,7 +1864,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body)
@@ -1899,7 +1899,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(
thread_roots, [thread_3, thread_2, thread_1], channel.json_body
@@ -1911,7 +1911,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
@@ -1943,6 +1943,6 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 4f6347be15..88e579dc39 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -1362,7 +1362,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id))
- self.assertEquals(ts, res.origin_server_ts)
+ self.assertEqual(ts, res.origin_server_ts)
def test_send_state_event_ts(self) -> None:
"""Test sending a state event with a custom timestamp."""
@@ -1384,7 +1384,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id))
- self.assertEquals(ts, res.origin_server_ts)
+ self.assertEqual(ts, res.origin_server_ts)
def test_send_membership_event_ts(self) -> None:
"""Test sending a membership event with a custom timestamp."""
@@ -1406,7 +1406,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id))
- self.assertEquals(ts, res.origin_server_ts)
+ self.assertEqual(ts, res.origin_server_ts)
class RoomJoinRatelimitTestCase(RoomBase):
diff --git a/tests/server.py b/tests/server.py
index c84a524e8c..481fe34c5c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -26,6 +26,7 @@ from typing import (
Any,
Awaitable,
Callable,
+ Deque,
Dict,
Iterable,
List,
@@ -41,7 +42,7 @@ from typing import (
from unittest.mock import Mock
import attr
-from typing_extensions import Deque, ParamSpec
+from typing_extensions import ParamSpec
from zope.interface import implementer
from twisted.internet import address, threads, udp
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 383da83dfb..f541f1d6be 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
-from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
+from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS, _RENEWAL_INTERVAL_MS
from synapse.util import Clock
from tests import unittest
@@ -380,8 +381,8 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
self.get_success(lock.__aenter__())
# Wait for ages with the lock, we should not be able to get the lock.
- self.reactor.advance(5 * _LOCK_TIMEOUT_MS / 1000)
- self.pump()
+ for _ in range(0, 10):
+ self.reactor.advance((_RENEWAL_INTERVAL_MS / 1000))
lock2 = self.get_success(
self.store.try_acquire_read_write_lock("name", "key", write=True)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 5e1324a169..71302facd1 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -40,7 +40,7 @@ from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
- super(ApplicationServiceStoreTestCase, self).setUp()
+ super().setUp()
self.as_yaml_files: List[str] = []
@@ -71,7 +71,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
except Exception:
pass
- super(ApplicationServiceStoreTestCase, self).tearDown()
+ super().tearDown()
def _add_appservice(
self, as_token: str, id: str, url: str, hs_token: str, sender: str
@@ -110,7 +110,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
- super(ApplicationServiceTransactionStoreTestCase, self).setUp()
+ super().setUp()
self.as_yaml_files: List[str] = []
self.hs.config.appservice.app_service_config_files = self.as_yaml_files
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 27f450e22d..b8823d6993 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -20,7 +20,7 @@ from tests import unittest
class DataStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
- super(DataStoreTestCase, self).setUp()
+ super().setUp()
self.store = self.hs.get_datastores().main
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 05ea802008..ba41459d08 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -48,6 +48,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
"creation_ts": 0,
"user_type": None,
"deactivated": 0,
+ "locked": 0,
"shadow_banned": 0,
"approved": 1,
},
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index f183c38477..52ffa91c81 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -318,14 +318,14 @@ class MessageSearchTest(HomeserverTestCase):
result = self.get_success(
store.search_msgs([self.room_id], query, ["content.body"])
)
- self.assertEquals(
+ self.assertEqual(
result["count"],
1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'"
if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'",
)
- self.assertEquals(
+ self.assertEqual(
len(result["results"]),
1 if expect_to_contain else 0,
"results array length should match count",
@@ -336,14 +336,14 @@ class MessageSearchTest(HomeserverTestCase):
result = self.get_success(
store.search_rooms([self.room_id], query, ["content.body"], 10)
)
- self.assertEquals(
+ self.assertEqual(
result["count"],
1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'"
if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'",
)
- self.assertEquals(
+ self.assertEqual(
len(result["results"]),
1 if expect_to_contain else 0,
"results array length should match count",
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 9ed330f554..a46c29ddf4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM"
class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
- super(FilterEventsForServerTestCase, self).setUp()
+ super().setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self._storage_controllers = self.hs.get_storage_controllers()
|