diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml
index 1a61d179d9..c537a5a60f 100644
--- a/.github/workflows/latest_deps.yml
+++ b/.github/workflows/latest_deps.yml
@@ -32,12 +32,15 @@ jobs:
with:
python-version: "3.x"
poetry-version: "1.2.0b1"
+ extras: "all"
# Dump installed versions for debugging.
- run: poetry run pip list > before.txt
# Upgrade all runtime dependencies only. This is intended to mimic a fresh
# `pip install matrix-synapse[all]` as closely as possible.
- run: poetry update --no-dev
- run: poetry run pip list > after.txt && (diff -u before.txt after.txt || true)
+ - name: Remove warn_unused_ignores from mypy config
+ run: sed '/warn_unused_ignores = True/d' -i mypy.ini
- run: poetry run mypy
trial:
runs-on: ubuntu-latest
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index cad4cb6d77..efa35b71df 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -20,13 +20,9 @@ jobs:
- run: scripts-dev/config-lint.sh
lint:
- # This does a vanilla `poetry install` - no extras. I'm slightly anxious
- # that we might skip some typechecks on code that uses extras. However,
- # I think the right way to fix this is to mark any extras needed for
- # typechecking as development dependencies. To detect this, we ought to
- # turn up mypy's strictness: disallow unknown imports and be accept fewer
- # uses of `Any`.
uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v1"
+ with:
+ typechecking-extras: "all"
lint-crlf:
runs-on: ubuntu-latest
diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml
index 8fc1affb77..5f0671f350 100644
--- a/.github/workflows/twisted_trunk.yml
+++ b/.github/workflows/twisted_trunk.yml
@@ -24,6 +24,8 @@ jobs:
poetry remove twisted
poetry add --extras tls git+https://github.com/twisted/twisted.git#trunk
poetry install --no-interaction --extras "all test"
+ - name: Remove warn_unused_ignores from mypy config
+ run: sed '/warn_unused_ignores = True/d' -i mypy.ini
- run: poetry run mypy
trial:
diff --git a/README.rst b/README.rst
index d71d733679..80201d4eb1 100644
--- a/README.rst
+++ b/README.rst
@@ -296,8 +296,8 @@ directory of your choice::
Synapse has a number of external dependencies. We maintain a fixed development
environment using [poetry](https://python-poetry.org/). First, install poetry. We recommend
- pip install --user pipx
- pipx install poetry
+ | pip install --user pipx
+ | pipx install poetry
as described `here <https://python-poetry.org/docs/#installing-with-pipx>`_.
(See `poetry's installation docs <https://python-poetry.org/docs/#installation>`
diff --git a/changelog.d/12356.misc b/changelog.d/12356.misc
new file mode 100644
index 0000000000..43e1929106
--- /dev/null
+++ b/changelog.d/12356.misc
@@ -0,0 +1 @@
+Fix scripts-dev to pass typechecking.
\ No newline at end of file
diff --git a/changelog.d/12406.feature b/changelog.d/12406.feature
new file mode 100644
index 0000000000..e345afdee7
--- /dev/null
+++ b/changelog.d/12406.feature
@@ -0,0 +1 @@
+Add a module API to allow modules to change actions for existing push rules of local users.
diff --git a/changelog.d/12480.misc b/changelog.d/12480.misc
new file mode 100644
index 0000000000..18a85e7b15
--- /dev/null
+++ b/changelog.d/12480.misc
@@ -0,0 +1 @@
+Use supervisord to supervise Postgres and Caddy in the Complement image to reduce restart time.
\ No newline at end of file
diff --git a/changelog.d/12485.misc b/changelog.d/12485.misc
new file mode 100644
index 0000000000..e793d08e5e
--- /dev/null
+++ b/changelog.d/12485.misc
@@ -0,0 +1 @@
+Add some type hints to datastore.
\ No newline at end of file
diff --git a/changelog.d/12505.misc b/changelog.d/12505.misc
new file mode 100644
index 0000000000..a691d7962f
--- /dev/null
+++ b/changelog.d/12505.misc
@@ -0,0 +1 @@
+Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.
diff --git a/changelog.d/12526.feature b/changelog.d/12526.feature
new file mode 100644
index 0000000000..c01596282c
--- /dev/null
+++ b/changelog.d/12526.feature
@@ -0,0 +1 @@
+Add new `enable_registration_token_3pid_bypass` configuration option to allow registrations via token as an alternative to verifying a 3pid.
\ No newline at end of file
diff --git a/changelog.d/12531.misc b/changelog.d/12531.misc
new file mode 100644
index 0000000000..412fc9b6dc
--- /dev/null
+++ b/changelog.d/12531.misc
@@ -0,0 +1 @@
+Remove unused `# type: ignore`s.
diff --git a/changelog.d/12541.docker b/changelog.d/12541.docker
new file mode 100644
index 0000000000..c3b9c31657
--- /dev/null
+++ b/changelog.d/12541.docker
@@ -0,0 +1 @@
+Explicitly opt-in to using [BuildKit-specific features](https://github.com/moby/buildkit/blob/master/frontend/dockerfile/docs/syntax.md) in the Dockerfile. This fixes issues with building images in some GitLab CI environments.
diff --git a/changelog.d/12544.bugfix b/changelog.d/12544.bugfix
new file mode 100644
index 0000000000..b5169cd831
--- /dev/null
+++ b/changelog.d/12544.bugfix
@@ -0,0 +1 @@
+Fix a bug where attempting to send a large amount of read receipts to an application service all at once would result in duplicate content and abnormally high memory usage. Contributed by Brad & Nick @ Beeper.
diff --git a/changelog.d/12556.misc b/changelog.d/12556.misc
new file mode 100644
index 0000000000..dc245397fb
--- /dev/null
+++ b/changelog.d/12556.misc
@@ -0,0 +1 @@
+Release script: confirm the commit to be tagged before tagging.
diff --git a/changelog.d/12564.misc b/changelog.d/12564.misc
new file mode 100644
index 0000000000..207c322464
--- /dev/null
+++ b/changelog.d/12564.misc
@@ -0,0 +1 @@
+Consistently check if an object is a `frozendict`.
diff --git a/changelog.d/12576.misc b/changelog.d/12576.misc
new file mode 100644
index 0000000000..71022c8633
--- /dev/null
+++ b/changelog.d/12576.misc
@@ -0,0 +1 @@
+Allow unused `#type: ignore` comments in bleeding edge CI jobs.
diff --git a/changelog.d/12579.doc b/changelog.d/12579.doc
new file mode 100644
index 0000000000..bcec5fe1af
--- /dev/null
+++ b/changelog.d/12579.doc
@@ -0,0 +1 @@
+Add missing linebreak to pipx install instructions.
diff --git a/changelog.d/12580.bugfix b/changelog.d/12580.bugfix
new file mode 100644
index 0000000000..bedce405e2
--- /dev/null
+++ b/changelog.d/12580.bugfix
@@ -0,0 +1 @@
+Fix a long standing bug where status codes would almost always get logged as 200!, irrespective of the actual status code, when clients disconnect before a request has finished processing.
diff --git a/changelog.d/12581.misc b/changelog.d/12581.misc
new file mode 100644
index 0000000000..38d40b262b
--- /dev/null
+++ b/changelog.d/12581.misc
@@ -0,0 +1 @@
+Improve docstrings for the receipts store.
diff --git a/changelog.d/12582.misc b/changelog.d/12582.misc
new file mode 100644
index 0000000000..5fa9c9afe8
--- /dev/null
+++ b/changelog.d/12582.misc
@@ -0,0 +1 @@
+Use constants for read-receipts in tests.
diff --git a/changelog.d/12589.misc b/changelog.d/12589.misc
new file mode 100644
index 0000000000..d362828d2e
--- /dev/null
+++ b/changelog.d/12589.misc
@@ -0,0 +1 @@
+Remove special-case for `twisted` logger from default log config.
diff --git a/changelog.d/12594.bugfix b/changelog.d/12594.bugfix
new file mode 100644
index 0000000000..7411d9c079
--- /dev/null
+++ b/changelog.d/12594.bugfix
@@ -0,0 +1 @@
+Fix race when persisting an event and deleting a room that could lead to outbound federation breaking.
diff --git a/changelog.d/12608.misc b/changelog.d/12608.misc
new file mode 100644
index 0000000000..38272118fb
--- /dev/null
+++ b/changelog.d/12608.misc
@@ -0,0 +1 @@
+Remove redundant lines of config from `mypy.ini`.
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 4523c60645..ccc6a9f778 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,3 +1,4 @@
+# syntax=docker/dockerfile:1
# Dockerfile to build the matrixdotorg/synapse docker images.
#
# Note that it uses features which are only available in BuildKit - see
diff --git a/docker/Dockerfile-workers b/docker/Dockerfile-workers
index 9ccb2b22a7..24b03585f9 100644
--- a/docker/Dockerfile-workers
+++ b/docker/Dockerfile-workers
@@ -20,6 +20,9 @@ RUN rm /etc/nginx/sites-enabled/default
# Copy Synapse worker, nginx and supervisord configuration template files
COPY ./docker/conf-workers/* /conf/
+# Copy a script to prefix log lines with the supervisor program name
+COPY ./docker/prefix-log /usr/local/bin/
+
# Expose nginx listener port
EXPOSE 8080/tcp
diff --git a/docker/complement/SynapseWorkers.Dockerfile b/docker/complement/SynapseWorkers.Dockerfile
index 65df2d114d..9a4438e730 100644
--- a/docker/complement/SynapseWorkers.Dockerfile
+++ b/docker/complement/SynapseWorkers.Dockerfile
@@ -34,13 +34,16 @@ WORKDIR /data
# Copy the caddy config
COPY conf-workers/caddy.complement.json /root/caddy.json
+COPY conf-workers/postgres.supervisord.conf /etc/supervisor/conf.d/postgres.conf
+COPY conf-workers/caddy.supervisord.conf /etc/supervisor/conf.d/caddy.conf
+
# Copy the entrypoint
COPY conf-workers/start-complement-synapse-workers.sh /
# Expose caddy's listener ports
EXPOSE 8008 8448
-ENTRYPOINT /start-complement-synapse-workers.sh
+ENTRYPOINT ["/start-complement-synapse-workers.sh"]
# Update the healthcheck to have a shorter check interval
HEALTHCHECK --start-period=5s --interval=1s --timeout=1s \
diff --git a/docker/complement/conf-workers/caddy.supervisord.conf b/docker/complement/conf-workers/caddy.supervisord.conf
new file mode 100644
index 0000000000..d9ddb51dac
--- /dev/null
+++ b/docker/complement/conf-workers/caddy.supervisord.conf
@@ -0,0 +1,7 @@
+[program:caddy]
+command=/usr/local/bin/prefix-log /root/caddy run --config /root/caddy.json
+autorestart=unexpected
+stdout_logfile=/dev/stdout
+stdout_logfile_maxbytes=0
+stderr_logfile=/dev/stderr
+stderr_logfile_maxbytes=0
diff --git a/docker/complement/conf-workers/postgres.supervisord.conf b/docker/complement/conf-workers/postgres.supervisord.conf
new file mode 100644
index 0000000000..5608342d1a
--- /dev/null
+++ b/docker/complement/conf-workers/postgres.supervisord.conf
@@ -0,0 +1,16 @@
+[program:postgres]
+command=/usr/local/bin/prefix-log /usr/bin/pg_ctlcluster 13 main start --foreground
+
+# Lower priority number = starts first
+priority=1
+
+autorestart=unexpected
+stdout_logfile=/dev/stdout
+stdout_logfile_maxbytes=0
+stderr_logfile=/dev/stderr
+stderr_logfile_maxbytes=0
+
+# Use 'Fast Shutdown' mode which aborts current transactions and closes connections quickly.
+# (Default (TERM) is 'Smart Shutdown' which stops accepting new connections but
+# lets existing connections close gracefully.)
+stopsignal=INT
diff --git a/docker/complement/conf-workers/start-complement-synapse-workers.sh b/docker/complement/conf-workers/start-complement-synapse-workers.sh
index 2c1e05bd62..b9a6b55bbe 100755
--- a/docker/complement/conf-workers/start-complement-synapse-workers.sh
+++ b/docker/complement/conf-workers/start-complement-synapse-workers.sh
@@ -12,12 +12,6 @@ function log {
# Replace the server name in the caddy config
sed -i "s/{{ server_name }}/${SERVER_NAME}/g" /root/caddy.json
-log "starting postgres"
-pg_ctlcluster 13 main start
-
-log "starting caddy"
-/root/caddy start --config /root/caddy.json
-
# Set the server name of the homeserver
export SYNAPSE_SERVER_NAME=${SERVER_NAME}
diff --git a/docker/conf/log.config b/docker/conf/log.config
index 7a216a36a0..dc8c70befd 100644
--- a/docker/conf/log.config
+++ b/docker/conf/log.config
@@ -2,11 +2,7 @@ version: 1
formatters:
precise:
-{% if worker_name %}
- format: '%(asctime)s - worker:{{ worker_name }} - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
-{% else %}
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
-{% endif %}
handlers:
{% if LOG_FILE_PATH %}
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index 3bda6c300b..33fc20d218 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -171,7 +171,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
# Templates for sections that may be inserted multiple times in config files
SUPERVISORD_PROCESS_CONFIG_BLOCK = """
[program:synapse_{name}]
-command=/usr/local/bin/python -m {app} \
+command=/usr/local/bin/prefix-log /usr/local/bin/python -m {app} \
--config-path="{config_path}" \
--config-path=/conf/workers/shared.yaml \
--config-path=/conf/workers/{name}.yaml
diff --git a/docker/prefix-log b/docker/prefix-log
new file mode 100755
index 0000000000..0e26a4f19d
--- /dev/null
+++ b/docker/prefix-log
@@ -0,0 +1,12 @@
+#!/bin/bash
+#
+# Prefixes all lines on stdout and stderr with the process name (as determined by
+# the SUPERVISOR_PROCESS_NAME env var, which is automatically set by Supervisor).
+#
+# Usage:
+# prefix-log command [args...]
+#
+
+exec 1> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&1)
+exec 2> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&2)
+exec "$@"
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index b8d8c0dbf0..67184c6b1a 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1323,6 +1323,12 @@ oembed:
#
#registration_requires_token: true
+# Allow users to submit a token during registration to bypass any required 3pid
+# steps configured in `registrations_require_3pid`.
+# Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow.
+#
+#enable_registration_token_3pid_bypass: false
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml
index 2485ad25ed..3065a0e2d9 100644
--- a/docs/sample_log_config.yaml
+++ b/docs/sample_log_config.yaml
@@ -62,13 +62,6 @@ loggers:
# information such as access tokens.
level: INFO
- twisted:
- # We send the twisted logging directly to the file handler,
- # to work around https://github.com/matrix-org/synapse/issues/3471
- # when using "buffer" logger. Use "console" to log to stderr instead.
- handlers: [file]
- propagate: false
-
root:
level: INFO
diff --git a/mypy.ini b/mypy.ini
index a663bf6975..78699e3187 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -7,6 +7,7 @@ show_error_codes = True
show_traceback = True
mypy_path = stubs
warn_unreachable = True
+warn_unused_ignores = True
local_partial_types = True
no_implicit_optional = True
@@ -23,10 +24,6 @@ files =
# https://docs.python.org/3/library/re.html#re.X
exclude = (?x)
^(
- |scripts-dev/build_debian_packages.py
- |scripts-dev/federation_client.py
- |scripts-dev/release.py
-
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
@@ -134,6 +131,11 @@ disallow_untyped_defs = True
[mypy-synapse.metrics.*]
disallow_untyped_defs = True
+[mypy-synapse.metrics._reactor_metrics]
+# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
+# See https://github.com/matrix-org/synapse/pull/11771.
+warn_unused_ignores = False
+
[mypy-synapse.module_api.*]
disallow_untyped_defs = True
@@ -239,63 +241,29 @@ disallow_untyped_defs = True
[mypy-authlib.*]
ignore_missing_imports = True
-[mypy-bcrypt]
-ignore_missing_imports = True
-
[mypy-canonicaljson]
ignore_missing_imports = True
[mypy-constantly]
ignore_missing_imports = True
-[mypy-daemonize]
-ignore_missing_imports = True
-
-[mypy-h11]
-ignore_missing_imports = True
-
-[mypy-hiredis]
-ignore_missing_imports = True
-
-[mypy-hyperlink]
-ignore_missing_imports = True
-
[mypy-ijson.*]
ignore_missing_imports = True
-[mypy-importlib_metadata.*]
-ignore_missing_imports = True
-
-[mypy-jaeger_client.*]
-ignore_missing_imports = True
-
-[mypy-josepy.*]
-ignore_missing_imports = True
-
-[mypy-jwt.*]
-ignore_missing_imports = True
-
[mypy-lxml]
ignore_missing_imports = True
[mypy-msgpack]
ignore_missing_imports = True
-[mypy-nacl.*]
-ignore_missing_imports = True
-
+# Note: WIP stubs available at
+# https://github.com/microsoft/python-type-stubs/tree/64934207f523ad6b611e6cfe039d85d7175d7d0d/netaddr
[mypy-netaddr]
ignore_missing_imports = True
[mypy-parameterized.*]
ignore_missing_imports = True
-[mypy-phonenumbers.*]
-ignore_missing_imports = True
-
-[mypy-prometheus_client.*]
-ignore_missing_imports = True
-
[mypy-pymacaroons.*]
ignore_missing_imports = True
@@ -308,23 +276,14 @@ ignore_missing_imports = True
[mypy-saml2.*]
ignore_missing_imports = True
-[mypy-sentry_sdk]
-ignore_missing_imports = True
-
[mypy-service_identity.*]
ignore_missing_imports = True
-[mypy-signedjson.*]
+[mypy-srvlookup.*]
ignore_missing_imports = True
[mypy-treq.*]
ignore_missing_imports = True
-[mypy-twisted.*]
-ignore_missing_imports = True
-
-[mypy-zope]
-ignore_missing_imports = True
-
[mypy-incremental.*]
ignore_missing_imports = True
diff --git a/poetry.lock b/poetry.lock
index 8c7af1fa1e..e27a44989c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -309,14 +309,15 @@ smmap = ">=3.0.1,<6"
[[package]]
name = "gitpython"
-version = "3.1.14"
-description = "Python Git Library"
+version = "3.1.27"
+description = "GitPython is a python library used to interact with Git repositories"
category = "dev"
optional = false
-python-versions = ">=3.4"
+python-versions = ">=3.7"
[package.dependencies]
gitdb = ">=4.0.1,<5"
+typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\""}
[[package]]
name = "hiredis"
@@ -1316,6 +1317,14 @@ optional = false
python-versions = "*"
[[package]]
+name = "types-commonmark"
+version = "0.9.2"
+description = "Typing stubs for commonmark"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[[package]]
name = "types-cryptography"
version = "3.3.15"
description = "Typing stubs for cryptography"
@@ -1553,7 +1562,7 @@ url_preview = ["lxml"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7"
-content-hash = "f482a4f594a165dfe01ce253a22510d5faf38647ab0dcebc35789350cafd9bf0"
+content-hash = "3825cef058b8c9f520ef4b7acb92519be95db9a663a61c2e89a5fe431ed55655"
[metadata.files]
attrs = [
@@ -1766,8 +1775,8 @@ gitdb = [
{file = "gitdb-4.0.9.tar.gz", hash = "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa"},
]
gitpython = [
- {file = "GitPython-3.1.14-py3-none-any.whl", hash = "sha256:3283ae2fba31c913d857e12e5ba5f9a7772bbc064ae2bb09efafa71b0dd4939b"},
- {file = "GitPython-3.1.14.tar.gz", hash = "sha256:be27633e7509e58391f10207cd32b2a6cf5b908f92d9cd30da2e514e1137af61"},
+ {file = "GitPython-3.1.27-py3-none-any.whl", hash = "sha256:5b68b000463593e05ff2b261acff0ff0972df8ab1b70d3cdbd41b546c8b8fc3d"},
+ {file = "GitPython-3.1.27.tar.gz", hash = "sha256:1c885ce809e8ba2d88a29befeb385fcea06338d3640712b59ca623c220bb5704"},
]
hiredis = [
{file = "hiredis-2.0.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b4c8b0bc5841e578d5fb32a16e0c305359b987b850a06964bd5a62739d688048"},
@@ -2588,6 +2597,10 @@ types-bleach = [
{file = "types-bleach-4.1.4.tar.gz", hash = "sha256:2d30c2c4fb6854088ac636471352c9a51bf6c089289800d2a8060820a01cd43a"},
{file = "types_bleach-4.1.4-py3-none-any.whl", hash = "sha256:edffe173ed6d7b6f3543036a96204a9319c3bf6c3645917b14274e43f000cc9b"},
]
+types-commonmark = [
+ {file = "types-commonmark-0.9.2.tar.gz", hash = "sha256:b894b67750c52fd5abc9a40a9ceb9da4652a391d75c1b480bba9cef90f19fc86"},
+ {file = "types_commonmark-0.9.2-py3-none-any.whl", hash = "sha256:56f20199a1f9a2924443211a0ef97f8b15a8a956a7f4e9186be6950bf38d6d02"},
+]
types-cryptography = [
{file = "types-cryptography-3.3.15.tar.gz", hash = "sha256:a7983a75a7b88a18f88832008f0ef140b8d1097888ec1a0824ec8fb7e105273b"},
{file = "types_cryptography-3.3.15-py3-none-any.whl", hash = "sha256:d9b0dd5465d7898d400850e7f35e5518aa93a7e23d3e11757cd81b4777089046"},
diff --git a/pyproject.toml b/pyproject.toml
index f0f029f016..62e26fd95b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -251,6 +251,7 @@ flake8 = "*"
mypy = "==0.931"
mypy-zope = "==0.3.5"
types-bleach = ">=4.1.0"
+types-commonmark = ">=0.9.2"
types-jsonschema = ">=3.2.0"
types-opentracing = ">=2.4.2"
types-Pillow = ">=8.3.4"
@@ -270,7 +271,8 @@ idna = ">=2.5"
# The following are used by the release script
click = "==8.1.0"
-GitPython = "==3.1.14"
+# GitPython was == 3.1.14; bumped to 3.1.20, the first release with type hints.
+GitPython = ">=3.1.20"
commonmark = "==0.9.1"
pygithub = "==1.55"
# The following are executed as commands by the release script.
diff --git a/scripts-dev/build_debian_packages.py b/scripts-dev/build_debian_packages.py
index e3e6878686..38564893e9 100755
--- a/scripts-dev/build_debian_packages.py
+++ b/scripts-dev/build_debian_packages.py
@@ -17,7 +17,8 @@ import subprocess
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Sequence
+from types import FrameType
+from typing import Collection, Optional, Sequence, Set
DISTS = (
"debian:buster", # oldstable: EOL 2022-08
@@ -41,15 +42,17 @@ projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
class Builder(object):
def __init__(
- self, redirect_stdout=False, docker_build_args: Optional[Sequence[str]] = None
+ self,
+ redirect_stdout: bool = False,
+ docker_build_args: Optional[Sequence[str]] = None,
):
self.redirect_stdout = redirect_stdout
self._docker_build_args = tuple(docker_build_args or ())
- self.active_containers = set()
+ self.active_containers: Set[str] = set()
self._lock = threading.Lock()
self._failed = False
- def run_build(self, dist, skip_tests=False):
+ def run_build(self, dist: str, skip_tests: bool = False) -> None:
"""Build deb for a single distribution"""
if self._failed:
@@ -63,7 +66,7 @@ class Builder(object):
self._failed = True
raise
- def _inner_build(self, dist, skip_tests=False):
+ def _inner_build(self, dist: str, skip_tests: bool = False) -> None:
tag = dist.split(":", 1)[1]
# Make the dir where the debs will live.
@@ -138,7 +141,7 @@ class Builder(object):
stdout.close()
print("Completed build of %s" % (dist,))
- def kill_containers(self):
+ def kill_containers(self) -> None:
with self._lock:
active = list(self.active_containers)
@@ -156,8 +159,10 @@ class Builder(object):
self.active_containers.remove(c)
-def run_builds(builder, dists, jobs=1, skip_tests=False):
- def sig(signum, _frame):
+def run_builds(
+ builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False
+) -> None:
+ def sig(signum: int, _frame: Optional[FrameType]) -> None:
print("Caught SIGINT")
builder.kill_containers()
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 079d2f5ed0..763dd02c47 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -38,7 +38,7 @@ import argparse
import base64
import json
import sys
-from typing import Any, Optional
+from typing import Any, Dict, Optional, Tuple
from urllib import parse as urlparse
import requests
@@ -47,13 +47,14 @@ import signedjson.types
import srvlookup
import yaml
from requests.adapters import HTTPAdapter
+from urllib3 import HTTPConnectionPool
# uncomment the following to enable debug logging of http requests
# from httplib import HTTPConnection
# HTTPConnection.debuglevel = 1
-def encode_base64(input_bytes):
+def encode_base64(input_bytes: bytes) -> str:
"""Encode bytes as a base64 string without any padding."""
input_len = len(input_bytes)
@@ -63,7 +64,7 @@ def encode_base64(input_bytes):
return output_string
-def encode_canonical_json(value):
+def encode_canonical_json(value: object) -> bytes:
return json.dumps(
value,
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes
@@ -130,7 +131,7 @@ def request(
sig,
destination,
)
- authorization_headers.append(header.encode("ascii"))
+ authorization_headers.append(header)
print("Authorization: %s" % header, file=sys.stderr)
dest = "matrix://%s%s" % (destination, path)
@@ -139,7 +140,10 @@ def request(
s = requests.Session()
s.mount("matrix://", MatrixConnectionAdapter())
- headers = {"Host": destination, "Authorization": authorization_headers[0]}
+ headers: Dict[str, str] = {
+ "Host": destination,
+ "Authorization": authorization_headers[0],
+ }
if method == "POST":
headers["Content-Type"] = "application/json"
@@ -154,7 +158,7 @@ def request(
)
-def main():
+def main() -> None:
parser = argparse.ArgumentParser(
description="Signs and sends a federation request to a matrix homeserver"
)
@@ -212,6 +216,7 @@ def main():
if not args.server_name or not args.signing_key:
read_args_from_config(args)
+ assert isinstance(args.signing_key, str)
algorithm, version, key_base64 = args.signing_key.split()
key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
@@ -233,7 +238,7 @@ def main():
print("")
-def read_args_from_config(args):
+def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh:
config = yaml.safe_load(fh)
@@ -250,7 +255,7 @@ def read_args_from_config(args):
class MatrixConnectionAdapter(HTTPAdapter):
@staticmethod
- def lookup(s, skip_well_known=False):
+ def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
if s[-1] == "]":
# ipv6 literal (with no port)
return s, 8448
@@ -276,7 +281,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
return s, 8448
@staticmethod
- def get_well_known(server_name):
+ def get_well_known(server_name: str) -> Optional[str]:
uri = "https://%s/.well-known/matrix/server" % (server_name,)
print("fetching %s" % (uri,), file=sys.stderr)
@@ -299,7 +304,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
return None
- def get_connection(self, url, proxies=None):
+ def get_connection(
+ self, url: str, proxies: Optional[Dict[str, str]] = None
+ ) -> HTTPConnectionPool:
parsed = urlparse.urlparse(url)
(host, port) = self.lookup(parsed.netloc)
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index 1217e14874..c775865212 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -16,7 +16,7 @@
can crop up, e.g the cache descriptors.
"""
-from typing import Callable, Optional
+from typing import Callable, Optional, Type
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
@@ -94,7 +94,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
return signature
-def plugin(version: str):
+def plugin(version: str) -> Type[SynapsePlugin]:
# This is the entry point of the plugin, and let's us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version
# string.
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index 9d7c7c445f..f4269e09bb 100755
--- a/scripts-dev/release.py
+++ b/scripts-dev/release.py
@@ -25,7 +25,7 @@ import sys
import urllib.request
from os import path
from tempfile import TemporaryDirectory
-from typing import List, Optional
+from typing import Any, List, Optional, cast
import attr
import click
@@ -36,7 +36,9 @@ from github import Github
from packaging import version
-def run_until_successful(command, *args, **kwargs):
+def run_until_successful(
+ command: str, *args: Any, **kwargs: Any
+) -> subprocess.CompletedProcess:
while True:
completed_process = subprocess.run(command, *args, **kwargs)
exit_code = completed_process.returncode
@@ -50,7 +52,7 @@ def run_until_successful(command, *args, **kwargs):
@click.group()
-def cli():
+def cli() -> None:
"""An interactive script to walk through the parts of creating a release.
Requires the dev dependencies be installed, which can be done via:
@@ -81,19 +83,13 @@ def cli():
@cli.command()
-def prepare():
+def prepare() -> None:
"""Do the initial stages of creating a release, including creating release
branch, updating changelog and pushing to GitHub.
"""
# Make sure we're in a git repo.
- try:
- repo = git.Repo()
- except git.InvalidGitRepositoryError:
- raise click.ClickException("Not in Synapse repo.")
-
- if repo.is_dirty():
- raise click.ClickException("Uncommitted changes exist.")
+ repo = get_repo_and_check_clean_checkout()
click.secho("Updating git repo...")
repo.remote().fetch()
@@ -161,22 +157,21 @@ def prepare():
click.get_current_context().abort()
# Switch to the release branch.
- parsed_new_version: version.Version = version.parse(new_version)
+ # Cast safety: parse() won't return a version.LegacyVersion from our
+ # version string format.
+ parsed_new_version = cast(version.Version, version.parse(new_version))
# We assume for debian changelogs that we only do RCs or full releases.
assert not parsed_new_version.is_devrelease
assert not parsed_new_version.is_postrelease
- release_branch_name = (
- f"release-v{parsed_new_version.major}.{parsed_new_version.minor}"
- )
+ release_branch_name = get_release_branch_name(parsed_new_version)
release_branch = find_ref(repo, release_branch_name)
if release_branch:
if release_branch.is_remote():
# If the release branch only exists on the remote we check it out
# locally.
repo.git.checkout(release_branch_name)
- release_branch = repo.active_branch
else:
# If a branch doesn't exist we create one. We ask which one branch it
# should be based off, defaulting to sensible values depending on the
@@ -198,13 +193,15 @@ def prepare():
click.get_current_context().abort()
# Check out the base branch and ensure it's up to date
- repo.head.reference = base_branch
+ repo.head.set_reference(base_branch, "check out the base branch")
repo.head.reset(index=True, working_tree=True)
if not base_branch.is_remote():
update_branch(repo)
# Create the new release branch
- release_branch = repo.create_head(release_branch_name, commit=base_branch)
+ # Type ignore will no longer be needed after GitPython 3.1.28.
+ # See https://github.com/gitpython-developers/GitPython/pull/1419
+ repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type]
# Switch to the release branch and ensure it's up to date.
repo.git.checkout(release_branch_name)
@@ -265,17 +262,11 @@ def prepare():
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"])
-def tag(gh_token: Optional[str]):
+def tag(gh_token: Optional[str]) -> None:
"""Tags the release and generates a draft GitHub release"""
# Make sure we're in a git repo.
- try:
- repo = git.Repo()
- except git.InvalidGitRepositoryError:
- raise click.ClickException("Not in Synapse repo.")
-
- if repo.is_dirty():
- raise click.ClickException("Uncommitted changes exist.")
+ repo = get_repo_and_check_clean_checkout()
click.secho("Updating git repo...")
repo.remote().fetch()
@@ -288,12 +279,26 @@ def tag(gh_token: Optional[str]):
if tag_name in repo.tags:
raise click.ClickException(f"Tag {tag_name} already exists!\n")
+ # Check we're on the right release branch
+ release_branch = get_release_branch_name(current_version)
+ if repo.active_branch.name != release_branch:
+ click.echo(
+ f"Need to be on the release branch ({release_branch}) before tagging. "
+ f"Currently on ({repo.active_branch.name})."
+ )
+ click.get_current_context().abort()
+
# Get the appropriate changelogs and tag.
changes = get_changes_for_version(current_version)
click.echo_via_pager(changes)
if click.confirm("Edit text?", default=False):
- changes = click.edit(changes, require_save=False)
+ edited_changes = click.edit(changes, require_save=False)
+ # This assert is for mypy's benefit. click's docs are a little unclear, but
+ # when `require_save=False`, not saving the temp file in the editor returns
+ # the original string.
+ assert edited_changes is not None
+ changes = edited_changes
repo.create_tag(tag_name, message=changes, sign=True)
@@ -347,22 +352,16 @@ def tag(gh_token: Optional[str]):
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=True)
-def publish(gh_token: str):
- """Publish release."""
+def publish(gh_token: str) -> None:
+ """Publish release on GitHub."""
# Make sure we're in a git repo.
- try:
- repo = git.Repo()
- except git.InvalidGitRepositoryError:
- raise click.ClickException("Not in Synapse repo.")
-
- if repo.is_dirty():
- raise click.ClickException("Uncommitted changes exist.")
+ get_repo_and_check_clean_checkout()
current_version = get_package_version()
tag_name = f"v{current_version}"
- if not click.confirm(f"Publish {tag_name}?", default=True):
+ if not click.confirm(f"Publish release {tag_name} on GitHub?", default=True):
return
# Publish the draft release
@@ -390,12 +389,19 @@ def publish(gh_token: str):
@cli.command()
-def upload():
+def upload() -> None:
"""Upload release to pypi."""
current_version = get_package_version()
tag_name = f"v{current_version}"
+ # Check we have the right tag checked out.
+ repo = get_repo_and_check_clean_checkout()
+ tag = repo.tag(f"refs/tags/{tag_name}")
+ if repo.head.commit != tag.commit:
+ click.echo("Tag {tag_name} (tag.commit) is not currently checked out!")
+ click.get_current_context().abort()
+
pypi_asset_names = [
f"matrix_synapse-{current_version}-py3-none-any.whl",
f"matrix-synapse-{current_version}.tar.gz",
@@ -418,7 +424,7 @@ def upload():
@cli.command()
-def announce():
+def announce() -> None:
"""Generate markdown to announce the release."""
current_version = get_package_version()
@@ -459,20 +465,36 @@ def get_package_version() -> version.Version:
return version.Version(version_string)
+def get_release_branch_name(version_number: version.Version) -> str:
+ return f"release-v{version_number.major}.{version_number.minor}"
+
+
+def get_repo_and_check_clean_checkout() -> git.Repo:
+ """Get the project repo and check it's not got any uncommitted changes."""
+ try:
+ repo = git.Repo()
+ except git.InvalidGitRepositoryError:
+ raise click.ClickException("Not in Synapse repo.")
+ if repo.is_dirty():
+ raise click.ClickException("Uncommitted changes exist.")
+ return repo
+
+
def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
"""Find the branch/ref, looking first locally then in the remote."""
- if ref_name in repo.refs:
- return repo.refs[ref_name]
+ if ref_name in repo.references:
+ return repo.references[ref_name]
elif ref_name in repo.remote().refs:
return repo.remote().refs[ref_name]
else:
return None
-def update_branch(repo: git.Repo):
+def update_branch(repo: git.Repo) -> None:
"""Ensure branch is up to date if it has a remote"""
- if repo.active_branch.tracking_branch():
- repo.git.merge(repo.active_branch.tracking_branch().name)
+ tracking_branch = repo.active_branch.tracking_branch()
+ if tracking_branch:
+ repo.git.merge(tracking_branch.name)
def get_changes_for_version(wanted_version: version.Version) -> str:
@@ -536,7 +558,9 @@ def get_changes_for_version(wanted_version: version.Version) -> str:
return "\n".join(version_changelog)
-def generate_and_write_changelog(current_version: version.Version, new_version: str):
+def generate_and_write_changelog(
+ current_version: version.Version, new_version: str
+) -> None:
# We do this by getting a draft so that we can edit it before writing to the
# changelog.
result = run_until_successful(
@@ -558,8 +582,8 @@ def generate_and_write_changelog(current_version: version.Version, new_version:
f.write(existing_content)
# Remove all the news fragments
- for f in glob.iglob("changelog.d/*.*"):
- os.remove(f)
+ for filename in glob.iglob("changelog.d/*.*"):
+ os.remove(filename)
if __name__ == "__main__":
diff --git a/scripts-dev/sign_json.py b/scripts-dev/sign_json.py
index 9459543106..bb217799fb 100755
--- a/scripts-dev/sign_json.py
+++ b/scripts-dev/sign_json.py
@@ -27,7 +27,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.util import json_encoder
-def main():
+def main() -> None:
parser = argparse.ArgumentParser(
description="""Adds a signature to a JSON object.
diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi
index e18d617281..3a4f9c3076 100644
--- a/stubs/sortedcontainers/sorteddict.pyi
+++ b/stubs/sortedcontainers/sorteddict.pyi
@@ -115,9 +115,7 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]):
def __getitem__(self, index: slice) -> List[_KT_co]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
-class SortedItemsView( # type: ignore
- ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]
-):
+class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]):
def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ...
@overload
def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ...
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 37321f9133..d28b87a3f4 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -48,7 +48,6 @@ from twisted.logger import LoggingFile, LogLevel
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.python.threadpool import ThreadPool
-import synapse
from synapse.api.constants import MAX_PDU_SIZE
from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
@@ -60,6 +59,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import init_tracer
from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -431,7 +431,7 @@ async def start(hs: "HomeServer") -> None:
refresh_certificate(hs)
# Start the tracer
- synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
+ init_tracer(hs) # noqa
# Instantiate the modules so they can register their web resources to the module API
# before we start the listeners.
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 99db9e1e39..470b8b4492 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -110,13 +110,6 @@ loggers:
# information such as access tokens.
level: INFO
- twisted:
- # We send the twisted logging directly to the file handler,
- # to work around https://github.com/matrix-org/synapse/issues/3471
- # when using "buffer" logger. Use "console" to log to stderr instead.
- handlers: [file]
- propagate: false
-
root:
level: INFO
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 39e9acb62a..70eb7e6a97 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -43,6 +43,9 @@ class RegistrationConfig(Config):
self.registration_requires_token = config.get(
"registration_requires_token", False
)
+ self.enable_registration_token_3pid_bypasss = config.get(
+ "enable_registration_token_3pid_bypasss", False
+ )
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -309,6 +312,12 @@ class RegistrationConfig(Config):
#
#registration_requires_token: true
+ # Allow users to submit a token during registration to bypass any required 3pid
+ # steps configured in `registrations_require_3pid`.
+ # Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow.
+ #
+ #enable_registration_token_3pid_bypass: false
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
diff --git a/synapse/config/server.py b/synapse/config/server.py
index d771045b52..b6cd326416 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -186,7 +186,7 @@ KNOWN_RESOURCES = {
class HttpResourceConfig:
names: List[str] = attr.ib(
factory=list,
- validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore
+ validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),
)
compress: bool = attr.ib(
default=False,
@@ -231,9 +231,7 @@ class ManholeConfig:
class LimitRemoteRoomsConfig:
enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
complexity: Union[float, int] = attr.ib(
- validator=attr.validators.instance_of(
- (float, int) # type: ignore[arg-type] # noqa
- ),
+ validator=attr.validators.instance_of((float, int)), # noqa
default=1.0,
)
complexity_error: str = attr.ib(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f8d3ba5456..a6c48308b3 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -27,7 +27,6 @@ from typing import (
)
import attr
-from frozendict import frozendict
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
@@ -204,7 +203,9 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
key_to_move = field.pop(-1)
sub_dict = src
for sub_field in field: # e.g. sub_field => "content"
- if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
+ if sub_field in sub_dict and isinstance(
+ sub_dict[sub_field], collections.abc.Mapping
+ ):
sub_dict = sub_dict[sub_field]
else:
return
@@ -622,7 +623,7 @@ def validate_canonicaljson(value: Any) -> None:
# Note that Infinity, -Infinity, and NaN are also considered floats.
raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON)
- elif isinstance(value, (dict, frozendict)):
+ elif isinstance(value, collections.abc.Mapping):
for v in value.values():
validate_canonicaljson(v)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index beab1227b8..884b5d60b4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -268,8 +268,8 @@ class FederationServer(FederationBase):
transaction_id=transaction_id,
destination=destination,
origin=origin,
- origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore
- pdus=transaction_data.get("pdus"), # type: ignore
+ origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore[arg-type]
+ pdus=transaction_data.get("pdus"),
edus=transaction_data.get("edus"),
)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1421050b9a..9ce06dfa28 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -229,21 +229,21 @@ class TransportLayerClient:
"""
logger.debug(
"send_data dest=%s, txid=%s",
- transaction.destination, # type: ignore
- transaction.transaction_id, # type: ignore
+ transaction.destination,
+ transaction.transaction_id,
)
- if transaction.destination == self.server_name: # type: ignore
+ if transaction.destination == self.server_name:
raise RuntimeError("Transport layer cannot send to itself!")
# FIXME: This is only used by the tests. The actual json sent is
# generated by the json_data_callback.
json_data = transaction.get_dict()
- path = _create_v1_path("/send/%s", transaction.transaction_id) # type: ignore
+ path = _create_v1_path("/send/%s", transaction.transaction_id)
return await self.client.put_json(
- transaction.destination, # type: ignore
+ transaction.destination,
path=path,
data=json_data,
json_data_callback=json_data_callback,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 1b57840506..b3894666cc 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -416,7 +416,7 @@ class ApplicationServicesHandler:
return typing
async def _handle_receipts(
- self, service: ApplicationService, new_token: Optional[int]
+ self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
"""
Return the latest read receipts that the given application service should receive.
@@ -447,7 +447,7 @@ class ApplicationServicesHandler:
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
- service=service, from_key=from_key
+ service=service, from_key=from_key, to_key=new_token
)
return receipts
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 86991d26ce..22678d486d 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -481,7 +481,7 @@ class AuthHandler:
sid = authdict["session"]
# Convert the URI and method to strings.
- uri = request.uri.decode("utf-8") # type: ignore
+ uri = request.uri.decode("utf-8")
method = request.method.decode("utf-8")
# If there's no session ID, create a new session.
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 724b9cfcb4..f6ffb7d18d 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -966,7 +966,7 @@ class OidcProvider:
"Mapping provider does not support de-duplicating Matrix IDs"
)
- attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
)
diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py
new file mode 100644
index 0000000000..2599160bcc
--- /dev/null
+++ b/synapse/handlers/push_rules.py
@@ -0,0 +1,138 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import attr
+
+from synapse.api.errors import SynapseError, UnrecognizedRequestError
+from synapse.push.baserules import BASE_RULE_IDS
+from synapse.storage.push_rule import RuleNotFoundException
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RuleSpec:
+ scope: str
+ template: str
+ rule_id: str
+ attr: Optional[str]
+
+
+class PushRulesHandler:
+ """A class to handle changes in push rules for users."""
+
+ def __init__(self, hs: "HomeServer"):
+ self._notifier = hs.get_notifier()
+ self._main_store = hs.get_datastores().main
+
+ async def set_rule_attr(
+ self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
+ ) -> None:
+ """Set an attribute (enabled or actions) on an existing push rule.
+
+ Notifies listeners (e.g. sync handler) of the change.
+
+ Args:
+ user_id: the user for which to modify the push rule.
+ spec: the spec of the push rule to modify.
+ val: the value to change the attribute to.
+
+ Raises:
+ RuleNotFoundException if the rule being modified doesn't exist.
+ SynapseError(400) if the value is malformed.
+ UnrecognizedRequestError if the attribute to change is unknown.
+ InvalidRuleException if we're trying to change the actions on a rule but
+ the provided actions aren't compliant with the spec.
+ """
+ if spec.attr not in ("enabled", "actions"):
+ # for the sake of potential future expansion, shouldn't report
+ # 404 in the case of an unknown request so check it corresponds to
+ # a known attribute first.
+ raise UnrecognizedRequestError()
+
+ namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
+ rule_id = spec.rule_id
+ is_default_rule = rule_id.startswith(".")
+ if is_default_rule:
+ if namespaced_rule_id not in BASE_RULE_IDS:
+ raise RuleNotFoundException("Unknown rule %r" % (namespaced_rule_id,))
+ if spec.attr == "enabled":
+ if isinstance(val, dict) and "enabled" in val:
+ val = val["enabled"]
+ if not isinstance(val, bool):
+ # Legacy fallback
+ # This should *actually* take a dict, but many clients pass
+ # bools directly, so let's not break them.
+ raise SynapseError(400, "Value for 'enabled' must be boolean")
+ await self._main_store.set_push_rule_enabled(
+ user_id, namespaced_rule_id, val, is_default_rule
+ )
+ elif spec.attr == "actions":
+ if not isinstance(val, dict):
+ raise SynapseError(400, "Value must be a dict")
+ actions = val.get("actions")
+ if not isinstance(actions, list):
+ raise SynapseError(400, "Value for 'actions' must be dict")
+ check_actions(actions)
+ rule_id = spec.rule_id
+ is_default_rule = rule_id.startswith(".")
+ if is_default_rule:
+ if namespaced_rule_id not in BASE_RULE_IDS:
+ raise RuleNotFoundException(
+ "Unknown rule %r" % (namespaced_rule_id,)
+ )
+ await self._main_store.set_push_rule_actions(
+ user_id, namespaced_rule_id, actions, is_default_rule
+ )
+ else:
+ raise UnrecognizedRequestError()
+
+ self.notify_user(user_id)
+
+ def notify_user(self, user_id: str) -> None:
+ """Notify listeners about a push rule change.
+
+ Args:
+ user_id: the user ID the change is for.
+ """
+ stream_id = self._main_store.get_max_push_rules_stream_id()
+ self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
+
+
+def check_actions(actions: List[Union[str, JsonDict]]) -> None:
+ """Check if the given actions are spec compliant.
+
+ Args:
+ actions: the actions to check.
+
+ Raises:
+ InvalidRuleException if the rules aren't compliant with the spec.
+ """
+ if not isinstance(actions, list):
+ raise InvalidRuleException("No actions found")
+
+ for a in actions:
+ if a in ["notify", "dont_notify", "coalesce"]:
+ pass
+ elif isinstance(a, dict) and "set_tweak" in a:
+ pass
+ else:
+ raise InvalidRuleException("Unrecognised action %s" % a)
+
+
+class InvalidRuleException(Exception):
+ pass
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6250bb3bdf..cfe860decc 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -239,13 +239,14 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
return events, to_key
async def get_new_events_as(
- self, from_key: int, service: ApplicationService
+ self, from_key: int, to_key: int, service: ApplicationService
) -> Tuple[List[JsonDict], int]:
"""Returns a set of new read receipt events that an appservice
may be interested in.
Args:
from_key: the stream position at which events should be fetched from
+ to_key: the stream position up to which events should be fetched to
service: The appservice which may be interested
Returns:
@@ -255,7 +256,6 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
* The current read receipt stream token.
"""
from_key = int(from_key)
- to_key = self.get_current_key()
if from_key == to_key:
return [], to_key
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 5efb561273..b5dc9f74b3 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections.abc
import logging
from typing import (
TYPE_CHECKING,
@@ -24,7 +25,6 @@ from typing import (
)
import attr
-from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
@@ -380,7 +380,7 @@ class RelationsHandler:
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
+ if isinstance(relates_to, collections.abc.Mapping):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
continue
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 102dd4b57d..5619f8f50e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -357,7 +357,7 @@ class SearchHandler:
itertools.chain(
# The events_before and events_after for each context.
itertools.chain.from_iterable(
- itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
+ itertools.chain(context["events_before"], context["events_after"])
for context in contexts.values()
),
# The returned events.
@@ -373,10 +373,10 @@ class SearchHandler:
for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
- context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
+ context["events_before"], time_now, bundle_aggregations=aggregations
)
context["events_after"] = self._event_serializer.serialize_events(
- context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
+ context["events_after"], time_now, bundle_aggregations=aggregations
)
results = [
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 472b029af3..e2a441066d 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -256,7 +256,9 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
- self._enabled = bool(hs.config.registration.registration_requires_token)
+ self._enabled = bool(
+ hs.config.registration.registration_requires_token
+ ) or bool(hs.config.registration.enable_registration_token_3pid_bypasss)
self.store = hs.get_datastores().main
def is_enabled(self) -> bool:
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 31ca841889..1cf49830e8 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -295,7 +295,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
if isawaitable(raw_callback_return):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return # type: ignore
+ callback_return = raw_callback_return
return callback_return
@@ -469,7 +469,7 @@ class JsonResource(DirectServeJsonResource):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return # type: ignore
+ callback_return = raw_callback_return
return callback_return
@@ -683,6 +683,9 @@ def respond_with_json(
Returns:
twisted.web.server.NOT_DONE_YET if the request is still active.
"""
+ # The response code must always be set, for logging purposes.
+ request.setResponseCode(code)
+
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
@@ -697,7 +700,6 @@ def respond_with_json(
else:
encoder = _encode_json_bytes
- request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
@@ -728,13 +730,15 @@ def respond_with_json_bytes(
Returns:
twisted.web.server.NOT_DONE_YET if the request is still active.
"""
+ # The response code must always be set, for logging purposes.
+ request.setResponseCode(code)
+
if request._disconnected:
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return None
- request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
@@ -840,6 +844,9 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> N
code: The HTTP response code.
html_bytes: The HTML bytes to use as the response body.
"""
+ # The response code must always be set, for logging purposes.
+ request.setResponseCode(code)
+
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
@@ -849,7 +856,6 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> N
)
return None
- request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 88cd8a9e1c..fd9cb97920 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -722,6 +722,11 @@ P = ParamSpec("P")
R = TypeVar("R")
+async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
+ """Unwraps an arbitrary awaitable by awaiting it."""
+ return await awaitable
+
+
@overload
def preserve_fn( # type: ignore[misc]
f: Callable[P, Awaitable[R]],
@@ -802,17 +807,20 @@ def run_in_background( # type: ignore[misc]
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()
+ # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
+ # value. Convert it to a `Deferred`.
if isinstance(res, typing.Coroutine):
+ # Wrap the coroutine in a `Deferred`.
res = defer.ensureDeferred(res)
-
- # At this point we should have a Deferred, if not then f was a synchronous
- # function, wrap it in a Deferred for consistency.
- if not isinstance(res, defer.Deferred):
- # `res` is not a `Deferred` and not a `Coroutine`.
- # There are no other types of `Awaitable`s we expect to encounter in Synapse.
- assert not isinstance(res, Awaitable)
-
- return defer.succeed(res)
+ elif isinstance(res, defer.Deferred):
+ pass
+ elif isinstance(res, Awaitable):
+ # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
+ # or `Future` from `make_awaitable`.
+ res = defer.ensureDeferred(_unwrap_awaitable(res))
+ else:
+ # `res` is a plain value. Wrap it in a `Deferred`.
+ res = defer.succeed(res)
if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 8f9e629274..834fe1b62c 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -82,6 +82,7 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
+from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,
@@ -109,6 +110,7 @@ from synapse.storage.state import StateFilter
from synapse.types import (
DomainSpecificString,
JsonDict,
+ JsonMapping,
Requester,
StateMap,
UserID,
@@ -151,6 +153,7 @@ __all__ = [
"PRESENCE_ALL_USERS",
"LoginResponse",
"JsonDict",
+ "JsonMapping",
"EventBase",
"StateMap",
"ProfileInfo",
@@ -193,6 +196,7 @@ class ModuleApi:
self._clock: Clock = hs.get_clock()
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
+ self._push_rules_handler = hs.get_push_rules_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
try:
@@ -1350,6 +1354,68 @@ class ModuleApi:
"""
await self._store.add_user_bound_threepid(user_id, medium, address, id_server)
+ def check_push_rule_actions(
+ self, actions: List[Union[str, Dict[str, str]]]
+ ) -> None:
+ """Checks if the given push rule actions are valid according to the Matrix
+ specification.
+
+ See https://spec.matrix.org/v1.2/client-server-api/#actions for the list of valid
+ actions.
+
+ Added in Synapse v1.58.0.
+
+ Args:
+ actions: the actions to check.
+
+ Raises:
+ synapse.module_api.errors.InvalidRuleException if the actions are invalid.
+ """
+ check_actions(actions)
+
+ async def set_push_rule_action(
+ self,
+ user_id: str,
+ scope: str,
+ kind: str,
+ rule_id: str,
+ actions: List[Union[str, Dict[str, str]]],
+ ) -> None:
+ """Changes the actions of an existing push rule for the given user.
+
+ See https://spec.matrix.org/v1.2/client-server-api/#push-rules for more
+ information about push rules and their syntax.
+
+ Can only be called on the main process.
+
+ Added in Synapse v1.58.0.
+
+ Args:
+ user_id: the user for which to change the push rule's actions.
+ scope: the push rule's scope, currently only "global" is allowed.
+ kind: the push rule's kind.
+ rule_id: the push rule's identifier.
+ actions: the actions to run when the rule's conditions match.
+
+ Raises:
+ RuntimeError if this method is called on a worker or `scope` is invalid.
+ synapse.module_api.errors.RuleNotFoundException if the rule being modified
+ can't be found.
+ synapse.module_api.errors.InvalidRuleException if the actions are invalid.
+ """
+ if self.worker_app is not None:
+ raise RuntimeError("module tried to change push rule actions on a worker")
+
+ if scope != "global":
+ raise RuntimeError(
+ "invalid scope %s, only 'global' is currently allowed" % scope
+ )
+
+ spec = RuleSpec(scope, kind, rule_id, "actions")
+ await self._push_rules_handler.set_rule_attr(
+ user_id, spec, {"actions": actions}
+ )
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
@@ -1419,7 +1485,7 @@ class AccountDataManager:
f"{user_id} is not local to this homeserver; can't access account data for remote users."
)
- async def get_global(self, user_id: str, data_type: str) -> Optional[JsonDict]:
+ async def get_global(self, user_id: str, data_type: str) -> Optional[JsonMapping]:
"""
Gets some global account data, of a specified type, for the specified user.
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index 1db900e41f..e58e0e60fe 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -20,10 +20,14 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config._base import ConfigError
+from synapse.handlers.push_rules import InvalidRuleException
+from synapse.storage.push_rule import RuleNotFoundException
__all__ = [
"InvalidClientCredentialsError",
"RedirectException",
"SynapseError",
"ConfigError",
+ "InvalidRuleException",
+ "RuleNotFoundException",
]
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index a93f6fd5e0..b98640b14a 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
-
-import attr
+from typing import TYPE_CHECKING, List, Sequence, Tuple, Union
from synapse.api.errors import (
NotFoundError,
@@ -22,6 +20,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
+from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -29,7 +28,6 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
@@ -40,14 +38,6 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class RuleSpec:
- scope: str
- template: str
- rule_id: str
- attr: Optional[str]
-
-
class PushRuleRestServlet(RestServlet):
PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
@@ -60,6 +50,7 @@ class PushRuleRestServlet(RestServlet):
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
+ self._push_rules_handler = hs.get_push_rules_handler()
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
@@ -81,8 +72,13 @@ class PushRuleRestServlet(RestServlet):
user_id = requester.user.to_string()
if spec.attr:
- await self.set_rule_attr(user_id, spec, content)
- self.notify_user(user_id)
+ try:
+ await self._push_rules_handler.set_rule_attr(user_id, spec, content)
+ except InvalidRuleException as e:
+ raise SynapseError(400, "Invalid actions: %s" % e)
+ except RuleNotFoundException:
+ raise NotFoundError("Unknown rule")
+
return 200, {}
if spec.rule_id.startswith("."):
@@ -98,23 +94,23 @@ class PushRuleRestServlet(RestServlet):
before = parse_string(request, "before")
if before:
- before = _namespaced_rule_id(spec, before)
+ before = f"global/{spec.template}/{before}"
after = parse_string(request, "after")
if after:
- after = _namespaced_rule_id(spec, after)
+ after = f"global/{spec.template}/{after}"
try:
await self.store.add_push_rule(
user_id=user_id,
- rule_id=_namespaced_rule_id_from_spec(spec),
+ rule_id=f"global/{spec.template}/{spec.rule_id}",
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
after=after,
)
- self.notify_user(user_id)
+ self._push_rules_handler.notify_user(user_id)
except InconsistentRuleException as e:
raise SynapseError(400, str(e))
except RuleNotFoundException as e:
@@ -133,11 +129,11 @@ class PushRuleRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+ namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
try:
await self.store.delete_push_rule(user_id, namespaced_rule_id)
- self.notify_user(user_id)
+ self._push_rules_handler.notify_user(user_id)
return 200, {}
except StoreError as e:
if e.code == 404:
@@ -172,55 +168,6 @@ class PushRuleRestServlet(RestServlet):
else:
raise UnrecognizedRequestError()
- def notify_user(self, user_id: str) -> None:
- stream_id = self.store.get_max_push_rules_stream_id()
- self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
-
- async def set_rule_attr(
- self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
- ) -> None:
- if spec.attr not in ("enabled", "actions"):
- # for the sake of potential future expansion, shouldn't report
- # 404 in the case of an unknown request so check it corresponds to
- # a known attribute first.
- raise UnrecognizedRequestError()
-
- namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- rule_id = spec.rule_id
- is_default_rule = rule_id.startswith(".")
- if is_default_rule:
- if namespaced_rule_id not in BASE_RULE_IDS:
- raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
- if spec.attr == "enabled":
- if isinstance(val, dict) and "enabled" in val:
- val = val["enabled"]
- if not isinstance(val, bool):
- # Legacy fallback
- # This should *actually* take a dict, but many clients pass
- # bools directly, so let's not break them.
- raise SynapseError(400, "Value for 'enabled' must be boolean")
- await self.store.set_push_rule_enabled(
- user_id, namespaced_rule_id, val, is_default_rule
- )
- elif spec.attr == "actions":
- if not isinstance(val, dict):
- raise SynapseError(400, "Value must be a dict")
- actions = val.get("actions")
- if not isinstance(actions, list):
- raise SynapseError(400, "Value for 'actions' must be dict")
- _check_actions(actions)
- namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- rule_id = spec.rule_id
- is_default_rule = rule_id.startswith(".")
- if is_default_rule:
- if namespaced_rule_id not in BASE_RULE_IDS:
- raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
- await self.store.set_push_rule_actions(
- user_id, namespaced_rule_id, actions, is_default_rule
- )
- else:
- raise UnrecognizedRequestError()
-
def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
@@ -291,24 +238,11 @@ def _rule_tuple_from_request_object(
raise InvalidRuleException("No actions found")
actions = req_obj["actions"]
- _check_actions(actions)
+ check_actions(actions)
return conditions, actions
-def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
- if not isinstance(actions, list):
- raise InvalidRuleException("No actions found")
-
- for a in actions:
- if a in ["notify", "dont_notify", "coalesce"]:
- pass
- elif isinstance(a, dict) and "set_tweak" in a:
- pass
- else:
- raise InvalidRuleException("Unrecognised action")
-
-
def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
if path == []:
raise UnrecognizedRequestError(
@@ -357,17 +291,5 @@ def _priority_class_from_spec(spec: RuleSpec) -> int:
return pc
-def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
- return _namespaced_rule_id(spec, spec.rule_id)
-
-
-def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
- return "global/%s/%s" % (spec.template, rule_id)
-
-
-class InvalidRuleException(Exception):
- pass
-
-
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushRuleRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 70baf50fa4..13ef6b35a0 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -929,6 +929,10 @@ def _calculate_registration_flows(
# always let users provide both MSISDN & email
flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
+ # Add a flow that doesn't require any 3pids, if the config requests it.
+ if config.registration.enable_registration_token_3pid_bypasss:
+ flows.append([LoginType.REGISTRATION_TOKEN])
+
# Prepend m.login.terms to all flows if we're requiring consent
if config.consent.user_consent_at_registration:
for flow in flows:
@@ -942,7 +946,8 @@ def _calculate_registration_flows(
# Prepend registration token to all flows if we're requiring a token
if config.registration.registration_requires_token:
for flow in flows:
- flow.insert(0, LoginType.REGISTRATION_TOKEN)
+ if LoginType.REGISTRATION_TOKEN not in flow:
+ flow.insert(0, LoginType.REGISTRATION_TOKEN)
return flows
diff --git a/synapse/server.py b/synapse/server.py
index 37c72bd83a..d49c76518a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -91,6 +91,7 @@ from synapse.handlers.presence import (
WorkerPresenceHandler,
)
from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.push_rules import PushRulesHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
@@ -811,6 +812,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return AccountHandler(self)
@cache_in_self
+ def get_push_rules_handler(self) -> PushRulesHandler:
+ return PushRulesHandler(self)
+
+ @cache_in_self
def get_outbound_redis_connection(self) -> "ConnectionHandler":
"""
The Redis connection used for replication.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 951031af50..5895b89202 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,12 +15,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
@@ -266,7 +271,9 @@ class DataStore(
A tuple of a list of mappings from user to information and a count of total users.
"""
- def get_users_paginate_txn(txn):
+ def get_users_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
filters = []
args = [self.hs.config.server.server_name]
@@ -301,7 +308,7 @@ class DataStore(
"""
sql = "SELECT COUNT(*) as total_users " + sql_base
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
@@ -338,7 +345,9 @@ class DataStore(
)
-def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+def check_database_before_upgrade(
+ cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
"""Called before upgrading an existing database to check that it is broadly sane
compared with the configuration.
"""
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index fa732edcca..945707b0ec 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast
from synapse.appservice import (
ApplicationService,
@@ -83,7 +83,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
txn.execute(
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
)
- return txn.fetchone()[0] # type: ignore
+ return cast(Tuple[int], txn.fetchone())[0]
self._as_txn_seq_gen = build_sequence_generator(
db_conn,
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b4a1b041b1..599b418383 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
prefilled_cache=device_outbox_prefill,
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
+ ) -> None:
if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
@@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
- def get_to_device_stream_token(self):
+ def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()
async def get_messages_for_user_devices(
@@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if not user_ids_to_query:
return {}, to_stream_id
- def get_device_messages_txn(txn: LoggingTransaction):
+ def get_device_messages_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.
@@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"})
return 0
- def delete_messages_for_device_txn(txn):
+ def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
@@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace
async def get_new_device_msgs_for_remote(
- self, destination, last_stream_id, current_stream_id, limit
- ) -> Tuple[List[dict], int]:
+ self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
+ ) -> Tuple[List[JsonDict], int]:
"""
Args:
- destination(str): The name of the remote server.
- last_stream_id(int|long): The last position of the device message stream
+ destination: The name of the remote server.
+ last_stream_id: The last position of the device message stream
that the server sent up to.
- current_stream_id(int|long): The current position of the device
- message stream.
+ current_stream_id: The current position of the device message stream.
Returns:
A list of messages for the device and where in the stream the messages got to.
"""
@@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return [], last_stream_id
@trace
- def get_new_messages_for_remote_destination_txn(txn):
+ def get_new_messages_for_remote_destination_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
@@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
up_to_stream_id: Where to delete messages up to.
"""
- def delete_messages_for_remote_destination_txn(txn):
+ def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
@@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_device_messages_txn(txn):
+ def get_all_new_device_messages_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
@@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace
async def add_messages_to_device_inbox(
self,
- local_messages_by_user_then_device: dict,
- remote_messages_by_destination: dict,
+ local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+ remote_messages_by_destination: Dict[str, JsonDict],
) -> int:
"""Used to send messages from this server.
@@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
assert self._can_write_to_device
- def add_messages_txn(txn, now_ms, stream_id):
+ def add_messages_txn(
+ txn: LoggingTransaction, now_ms: int, stream_id: int
+ ) -> None:
# Add the local messages directly to the local inbox.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
@@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return self._device_inbox_id_gen.get_current_token()
async def add_messages_from_remote_to_device_inbox(
- self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+ self,
+ origin: str,
+ message_id: str,
+ local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> int:
assert self._can_write_to_device
- def add_messages_txn(txn, now_ms, stream_id):
+ def add_messages_txn(
+ txn: LoggingTransaction, now_ms: int, stream_id: int
+ ) -> None:
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
@@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return stream_id
def _add_messages_to_local_device_inbox_txn(
- self, txn, stream_id, messages_by_user_then_device
- ):
+ self,
+ txn: LoggingTransaction,
+ stream_id: int,
+ messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+ ) -> None:
assert self._can_write_to_device
local_by_user_then_device = {}
@@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self._remove_dead_devices_from_device_inbox,
)
- async def _background_drop_index_device_inbox(self, progress, batch_size):
- def reindex_txn(conn):
+ async def _background_drop_index_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def reindex_txn(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 483dd80406..2df4dd4ed4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -25,6 +25,7 @@ from typing import (
Optional,
Set,
Tuple,
+ cast,
)
from synapse.api.errors import Codes, StoreError
@@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore):
Number of devices of this users.
"""
- def count_devices_by_users_txn(txn, user_ids):
+ def count_devices_by_users_txn(
+ txn: LoggingTransaction, user_ids: List[str]
+ ) -> int:
sql = """
SELECT count(*)
FROM devices
@@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
txn.execute(sql + clause, args)
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
if not user_ids:
return 0
@@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
- return list(txn)
+ return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
async def _get_device_update_edus_by_remote(
self,
@@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
) -> int:
- def f(txn):
+ def f(txn: LoggingTransaction) -> int:
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
@@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not user_ids_to_check:
return set()
- def _get_users_whose_devices_changed_txn(txn):
+ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
changes = set()
stream_id_where_clause = "stream_id > ?"
@@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore):
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user."""
- def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+ def _mark_remote_user_device_list_as_unsubscribed_txn(
+ txn: LoggingTransaction,
+ ) -> None:
self.db_pool.simple_delete_txn(
txn,
table="device_lists_remote_extremeties",
@@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
def _store_dehydrated_device_txn(
- self, txn, user_id: str, device_id: str, device_data: str
+ self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
) -> Optional[str]:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
@@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
yesterday = self._clock.time_msec() - prune_age
- def _prune_txn(txn):
+ def _prune_txn(txn: LoggingTransaction) -> None:
# look for (user, destination) pairs which have an update older than
# the cutoff.
#
@@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"drop_device_lists_outbound_last_success_non_unique_idx",
)
- async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
- def f(conn):
+ async def _drop_device_list_streams_non_unique_indexes(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def f(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
@@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
return 1
- async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+ async def _remove_duplicate_outbound_pokes(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
# for some reason, we have accumulated duplicate entries in
# device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
# efficient.
@@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
)
- def _txn(txn):
+ def _txn(txn: LoggingTransaction) -> int:
clause, args = make_tuple_comparison_clause(
[(x, last_row[x]) for x in KEY_COLS]
)
@@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map()
- def add_device_changes_txn(txn, stream_ids):
+ def add_device_changes_txn(
+ txn: LoggingTransaction, stream_ids: List[int]
+ ) -> None:
self._add_device_change_to_stream_txn(
txn,
user_id,
@@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn: LoggingTransaction,
user_id: str,
device_ids: Collection[str],
- stream_ids: List[str],
- ):
+ stream_ids: List[int],
+ ) -> None:
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id,
@@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
- stream_ids: List[str],
+ stream_ids: List[int],
context: Dict[str, str],
) -> None:
"""Record the user in the room has updated their device."""
@@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
LIMIT ?
"""
- def get_uncoverted_outbound_room_pokes_txn(txn):
+ def get_uncoverted_outbound_room_pokes_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
txn.execute(sql, (limit,))
return [
@@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Marks the associated row in `device_lists_changes_in_room` as handled.
"""
- def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+ def add_device_list_outbound_pokes_txn(
+ txn: LoggingTransaction, stream_ids: List[int]
+ ) -> None:
if hosts:
self._add_device_outbound_poke_to_stream_txn(
txn,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2a1e567ce0..9a6c2fd47a 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -47,6 +47,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
+from synapse.storage.engines.postgres import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
@@ -364,6 +365,20 @@ class PersistEventsStore:
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+ # We check that the room still exists for events we're trying to
+ # persist. This is to protect against races with deleting a room.
+ #
+ # Annoyingly SQLite doesn't support row level locking.
+ if isinstance(self.database_engine, PostgresEngine):
+ for room_id in {e.room_id for e, _ in events_and_contexts}:
+ txn.execute(
+ "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+ (room_id,),
+ )
+ row = txn.fetchone()
+ if row is None:
+ raise Exception(f"Room does not exist {room_id}")
+
# stream orderings should have been assigned by now
assert min_stream_order
assert max_stream_order
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 0aef121d83..04efad9e9a 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -522,7 +522,9 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
- async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+ async def get_all_groups_for_user(
+ self, user_id: str, now_token: int
+ ) -> List[JsonDict]:
def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, type, membership, u.content
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 6990f3ed1d..0a19f607bd 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -15,11 +15,12 @@
import itertools
import logging
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
from synapse.storage.keys import FetchKeyResult
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
@@ -35,7 +36,9 @@ class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys"""
@cached()
- def _get_server_verify_key(self, server_name_and_key_id):
+ def _get_server_verify_key(
+ self, server_name_and_key_id: Tuple[str, str]
+ ) -> FetchKeyResult:
raise NotImplementedError()
@cachedList(
@@ -179,19 +182,21 @@ class KeyStore(SQLBaseStore):
async def get_server_keys_json(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
+ ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
"""Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
Args:
- server_keys (list): List of (server_name, key_id, source) triplets.
+ server_keys: List of (server_name, key_id, source) triplets.
Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts
"""
- def _get_server_keys_json_txn(txn):
+ def _get_server_keys_json_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
results = {}
for server_name, key_id, from_server in server_keys:
keyvalues = {"server_name": server_name}
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 322ed05390..40ac377ca9 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -388,7 +388,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
async def store_url_cache(
- self, url, response_code, etag, expires_ts, og, media_id, download_ts
+ self,
+ url: str,
+ response_code: int,
+ etag: Optional[str],
+ expires_ts: int,
+ og: Optional[str],
+ media_id: str,
+ download_ts: int,
) -> None:
await self.db_pool.simple_insert(
"local_media_repository_url_cache",
@@ -441,7 +448,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_cached_remote_media(
- self, origin, media_id: str
+ self, origin: str, media_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
"remote_media_cache",
@@ -608,7 +615,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
- def delete_remote_media_txn(txn):
+ def delete_remote_media_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
"remote_media_cache",
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 4f1c22c71b..5beb8f1d4b 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -232,10 +232,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
- self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
+ self._invalidate_all_cache_and_stream(
txn, self.user_last_seen_monthly_active
)
- self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction(
@@ -363,7 +363,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
+ is_guest = await self.is_guest(user_id)
if is_guest:
return
is_trial = await self.is_trial_user(user_id)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index d3c4611686..b47c511450 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -103,7 +103,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
prefilled_cache=presence_cache_prefill,
)
- async def update_presence(self, presence_states) -> Tuple[int, int]:
+ async def update_presence(
+ self, presence_states: List[UserPresenceState]
+ ) -> Tuple[int, int]:
assert self._can_persist_presence
stream_ordering_manager = self._presence_id_gen.get_next_mult(
@@ -121,7 +123,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return stream_orderings[-1], self._presence_id_gen.get_current_token()
def _update_presence_txn(
- self, txn: LoggingTransaction, stream_orderings, presence_states
+ self,
+ txn: LoggingTransaction,
+ stream_orderings: List[int],
+ presence_states: List[UserPresenceState],
) -> None:
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
@@ -405,7 +410,13 @@ class PresenceStore(PresenceBackgroundUpdateStore):
self._presence_on_startup = []
return active_on_startup
- def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 2e3818e432..bfc85b3add 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -324,7 +324,12 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
)
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
- # First we fetch all the state groups that should be deleted, before
+ # We *immediately* delete the room from the rooms table. This ensures
+ # that we don't race when persisting events (as that transaction checks
+ # that the room exists).
+ txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,))
+
+ # Next, we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
"""
@@ -403,7 +408,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"room_stats_state",
"room_stats_current",
"room_stats_earliest_token",
- "rooms",
"stream_ordering_to_exterm",
"users_in_public_rooms",
"users_who_share_private_rooms",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 92539f5d41..eb85bbd392 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -16,7 +16,7 @@ import abc
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
-from synapse.api.errors import NotFoundError, StoreError
+from synapse.api.errors import StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -618,7 +618,7 @@ class PushRuleStore(PushRulesWorkerStore):
are always stored in the database `push_rules` table).
Raises:
- NotFoundError if the rule does not exist.
+ RuleNotFoundException if the rule does not exist.
"""
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
@@ -668,8 +668,7 @@ class PushRuleStore(PushRulesWorkerStore):
)
txn.execute(sql, (user_id, rule_id))
if txn.fetchone() is None:
- # needed to set NOT_FOUND code.
- raise NotFoundError("Push rule does not exist.")
+ raise RuleNotFoundException("Push rule does not exist.")
self.db_pool.simple_upsert_txn(
txn,
@@ -698,9 +697,6 @@ class PushRuleStore(PushRulesWorkerStore):
"""
Sets the `actions` state of a push rule.
- Will throw NotFoundError if the rule does not exist; the Code for this
- is NOT_FOUND.
-
Args:
user_id: the user ID of the user who wishes to enable/disable the rule
e.g. '@tina:example.org'
@@ -712,6 +708,9 @@ class PushRuleStore(PushRulesWorkerStore):
is_default_rule: True if and only if this is a server-default rule.
This skips the check for existence (as only user-created rules
are always stored in the database `push_rules` table).
+
+ Raises:
+ RuleNotFoundException if the rule does not exist.
"""
actions_json = json_encoder.encode(actions)
@@ -744,7 +743,7 @@ class PushRuleStore(PushRulesWorkerStore):
except StoreError as serr:
if serr.code == 404:
# this sets the NOT_FOUND error Code
- raise NotFoundError("Push rule does not exist")
+ raise RuleNotFoundException("Push rule does not exist")
else:
raise
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index cf64cd63a4..91286c9b65 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -14,11 +14,25 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_all_pushers(self) -> Iterator[PusherConfig]:
- def get_pushers(txn):
+ def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
@@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_pushers_rows_txn(txn):
+ def get_all_updated_pushers_rows_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT id, user_name, app_id, pushkey
FROM pushers
@@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore):
ORDER BY id ASC LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [
- (stream_id, (user_name, app_id, pushkey, False))
- for stream_id, user_name, app_id, pushkey in txn
- ]
+ updates = cast(
+ List[Tuple[int, tuple]],
+ [
+ (stream_id, (user_name, app_id, pushkey, False))
+ for stream_id, user_name, app_id, pushkey in txn
+ ],
+ )
sql = """
SELECT stream_id, user_id, app_id, pushkey
@@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore):
)
@cached(num_args=1, max_entries=15000)
- async def get_if_user_has_pusher(self, user_id: str):
+ async def get_if_user_has_pusher(self, user_id: str) -> None:
# This only exists for the cachedList decorator
raise NotImplementedError()
async def update_pusher_last_stream_ordering(
- self, app_id, pushkey, user_id, last_stream_ordering
+ self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int
) -> None:
await self.db_pool.simple_update_one(
"pushers",
@@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore):
last_user = progress.get("last_user", "")
- def _delete_pushers(txn) -> int:
+ def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT name FROM users
@@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
- def _delete_pushers(txn) -> int:
+ def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT p.id, access_token FROM pushers AS p
@@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
- def _delete_pushers(txn) -> int:
+ def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey
@@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore):
async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id: str, pushkey: str, user_id: str
) -> None:
- def delete_pusher_txn(txn, stream_id):
+ def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
@@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore):
# account.
pushers = list(await self.get_pushers_by_user_id(user_id))
- def delete_pushers_txn(txn, stream_ids):
+ def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 332e901dda..7d96f4feda 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -122,10 +122,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
return {r["user_id"] for r in receipts}
- @cached(num_args=2)
+ @cached()
async def get_receipts_for_room(
self, room_id: str, receipt_type: str
) -> List[Dict[str, Any]]:
+ """
+ Fetch the event IDs for the latest receipt for all users in a room with the given receipt type.
+
+ Args:
+ room_id: The room ID to fetch the receipt for.
+ receipt_type: The receipt type to fetch.
+
+ Returns:
+ A list of dictionaries, one for each user ID. Each dictionary
+ contains a user ID and the event ID of that user's latest receipt.
+ """
return await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
@@ -133,10 +144,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
desc="get_receipts_for_room",
)
- @cached(num_args=3)
+ @cached()
async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_type: str
) -> Optional[str]:
+ """
+ Fetch the event ID for the latest receipt in a room with the given receipt type.
+
+ Args:
+ user_id: The user to fetch receipts for.
+ room_id: The room ID to fetch the receipt for.
+ receipt_type: The receipt type to fetch.
+
+ Returns:
+ The event ID of the latest receipt, if one exists; otherwise `None`.
+ """
return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
@@ -149,10 +171,23 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
- @cached(num_args=2)
+ @cached()
async def get_receipts_for_user(
self, user_id: str, receipt_type: str
) -> Dict[str, str]:
+ """
+ Fetch the event IDs for the latest receipts sent by the given user.
+
+ Args:
+ user_id: The user to fetch receipts for.
+ receipt_type: The receipt type to fetch.
+
+ Returns:
+ A map of room ID to the event ID of the latest receipt for that room.
+
+ If the user has not sent a receipt to a room then it will not appear
+ in the returned dictionary.
+ """
rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -165,6 +200,17 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_receipts_for_user_with_orderings(
self, user_id: str, receipt_type: str
) -> JsonDict:
+ """
+ Fetch receipts for all rooms that the given user is joined to.
+
+ Args:
+ user_id: The user to fetch receipts for.
+ receipt_type: The receipt type to fetch.
+
+ Returns:
+ A map of room ID to the latest receipt information.
+ """
+
def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -241,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
- @cached(num_args=3, tree=True)
+ @cached(tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[JsonDict]:
@@ -541,7 +587,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data: JsonDict,
stream_id: int,
) -> Optional[int]:
- """Inserts a read-receipt into the database if it's newer than the current RR
+ """Inserts a receipt into the database if it's newer than the current one.
Returns:
None if the RR is older than the current RR
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index e653841fe5..18ae8aee29 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
-from frozendict import frozendict
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -160,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
predecessor = create_event.content.get("predecessor", None)
# Ensure the key is a dictionary
- if not isinstance(predecessor, (dict, frozendict)):
+ if not isinstance(predecessor, collections.abc.Mapping):
return None
# The keys must be strings since the data is JSON.
@@ -370,10 +369,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _update_state_for_partial_state_event_txn(
self,
- txn,
+ txn: LoggingTransaction,
event: EventBase,
context: EventContext,
- ):
+ ) -> None:
# we shouldn't have any outliers here
assert not event.internal_metadata.is_outlier()
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 2d339b6008..f38bedbbcd 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -131,7 +131,7 @@ class UIAuthWorkerStore(SQLBaseStore):
session_id: str,
stage_type: str,
result: Union[str, bool, JsonDict],
- ):
+ ) -> None:
"""
Mark a session stage as completed.
@@ -200,7 +200,9 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="set_ui_auth_client_dict",
)
- async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
+ async def set_ui_auth_session_data(
+ self, session_id: str, key: str, value: Any
+ ) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
@@ -223,7 +225,7 @@ class UIAuthWorkerStore(SQLBaseStore):
def _set_ui_auth_session_data_txn(
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
- ):
+ ) -> None:
# Get the current value.
result = cast(
Dict[str, Any],
@@ -275,7 +277,7 @@ class UIAuthWorkerStore(SQLBaseStore):
session_id: str,
user_agent: str,
ip: str,
- ):
+ ) -> None:
"""Add the given user agent / IP to the tracking table"""
await self.db_pool.simple_upsert(
table="ui_auth_sessions_ips",
@@ -318,7 +320,7 @@ class UIAuthWorkerStore(SQLBaseStore):
def _delete_old_ui_auth_sessions_txn(
self, txn: LoggingTransaction, expiration_time: int
- ):
+ ) -> None:
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e3153d1a4a..546d6bae6e 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -501,11 +501,11 @@ def _upgrade_existing_database(
if hasattr(module, "run_create"):
logger.info("Running %s:run_create", relative_path)
- module.run_create(cur, database_engine) # type: ignore
+ module.run_create(cur, database_engine)
if not is_empty and hasattr(module, "run_upgrade"):
logger.info("Running %s:run_upgrade", relative_path)
- module.run_upgrade(cur, database_engine, config=config) # type: ignore
+ module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 0b9ac26b69..f6b3ee31e4 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -107,7 +107,7 @@ class TTLCache(Generic[KT, VT]):
self._metrics.inc_hits()
return e.value, e.expiry_time, e.ttl
- def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore
+ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Remove a value from the cache
If key is in the cache, remove it and return its value, else return default.
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 9c405eb4d7..7223af1a36 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections.abc
from typing import Any
from frozendict import frozendict
@@ -35,7 +36,7 @@ def freeze(o: Any) -> Any:
def unfreeze(o: Any) -> Any:
- if isinstance(o, (dict, frozendict)):
+ if isinstance(o, collections.abc.Mapping):
return {k: unfreeze(v) for k, v in o.items()}
if isinstance(o, (bytes, str)):
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index ec8864dafe..268a48d7ba 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
)
# mock up the response, and have the agent return it
- self._mock_agent.request.return_value = defer.succeed(
+ self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
_mock_response(
{
"pdus": [
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 91f982518e..6b26353d5e 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = (
- defer.succeed(
+ make_awaitable(
{
"stream_id": "1",
"user_id": "@user2:host2",
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 8c72cf6b30..5b0cd1ab86 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -411,6 +411,88 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
+ def test_sending_read_receipt_batches_to_application_services(self):
+ """Tests that a large batch of read receipts are sent correctly to
+ interested application services.
+ """
+ # Register an application service that's interested in a certain user
+ # and room prefix
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@exclusive_as_user:.+",
+ "exclusive": True,
+ }
+ ],
+ ApplicationService.NS_ROOMS: [
+ {
+ "regex": "!fakeroom_.*",
+ "exclusive": True,
+ }
+ ],
+ },
+ )
+
+ # "Complete" a transaction.
+ # All this really does for us is make an entry in the application_services_state
+ # database table, which tracks the current stream_token per stream ID per AS.
+ self.get_success(
+ self.hs.get_datastores().main.complete_appservice_txn(
+ 0,
+ interested_appservice,
+ )
+ )
+
+ # Now, pretend that we receive a large burst of read receipts (300 total) that
+ # all come in at once.
+ for i in range(300):
+ self.get_success(
+ # Insert a fake read receipt into the database
+ self.hs.get_datastores().main.insert_receipt(
+ # We have to use unique room ID + user ID combinations here, as the db query
+ # is an upsert.
+ room_id=f"!fakeroom_{i}:test",
+ receipt_type="m.read",
+ user_id=self.local_user,
+ event_ids=[f"$eventid_{i}"],
+ data={},
+ )
+ )
+
+ # Now notify the appservice handler that 300 read receipts have all arrived
+ # at once. What will it do!
+ # note: stream tokens start at 2
+ for stream_token in range(2, 303):
+ self.get_success(
+ self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
+ services=[interested_appservice],
+ stream_key="receipt_key",
+ new_token=stream_token,
+ users=[self.exclusive_as_user],
+ )
+ )
+
+ # Using our txn send mock, we can see what the AS received. After iterating over every
+ # transaction, we'd like to see all 300 read receipts accounted for.
+ # No more, no less.
+ all_ephemeral_events = []
+ for call in self.send_mock.call_args_list:
+ ephemeral_events = call[0][2]
+ all_ephemeral_events += ephemeral_events
+
+ # Ensure that no duplicate events were sent
+ self.assertEqual(len(all_ephemeral_events), 300)
+
+ # Check that the ephemeral event is a read receipt with the expected structure
+ latest_read_receipt = all_ephemeral_events[-1]
+ self.assertEqual(latest_read_receipt["type"], "m.receipt")
+
+ event_id = list(latest_read_receipt["content"].keys())[0]
+ self.assertEqual(
+ latest_read_receipt["content"][event_id]["m.read"], {self.local_user: {}}
+ )
+
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index a54aa29cf1..751025c5da 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -201,4 +201,16 @@ class CasHandlerTestCase(HomeserverTestCase):
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
- return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
+ mock = Mock(
+ spec=[
+ "finish",
+ "getClientIP",
+ "getHeader",
+ "setHeader",
+ "setResponseCode",
+ "write",
+ ]
+ )
+ # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
+ mock._disconnected = False
+ return mock
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8c74ed1fcf..1e6ad4b663 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -19,7 +19,6 @@ from unittest import mock
from parameterized import parameterized
from signedjson import key as key, sign as sign
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
@@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_client_keys = mock.Mock(
- return_value=defer.succeed(
+ return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
"master_keys": {
@@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
- return_value=defer.succeed({"some_room_id"})
+ return_value=make_awaitable({"some_room_id"})
)
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_user_devices = mock.Mock(
- return_value=defer.succeed(
+ return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index d401fda938..addf14fa2b 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -17,8 +17,6 @@
from typing import Any, Type, Union
from unittest.mock import Mock
-from twisted.internet import defer
-
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -190,7 +188,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = defer.succeed(True)
+ mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -226,13 +224,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.get_success(module_api.register_user("u"))
# log in twice, to get two devices
- mock_password_provider.check_password.return_value = defer.succeed(True)
+ mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()
# have the auth provider deny the request to start with
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
@@ -246,7 +244,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it
- mock_password_provider.check_password.return_value = defer.succeed(True)
+ mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -260,7 +258,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)
@@ -277,7 +275,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# have the auth provider deny the request
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
@@ -320,7 +318,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -342,7 +340,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# allow login via the auth provider
- mock_password_provider.check_password.return_value = defer.succeed(True)
+ mock_password_provider.check_password.return_value = make_awaitable(True)
# log in twice, to get two devices
tok1 = self.login("localuser", "p")
@@ -359,7 +357,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
@@ -413,7 +411,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
@@ -427,7 +425,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
@@ -477,7 +475,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
body["auth"]["test_field"] = "foo"
@@ -490,7 +488,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# and finally, succeed
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
@@ -508,9 +506,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self):
- callback = Mock(return_value=defer.succeed(None))
+ callback = Mock(return_value=make_awaitable(None))
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
@@ -646,7 +644,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 5081b97573..65ab7db0c8 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -15,7 +15,7 @@
from typing import List
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.types import JsonDict
from tests import unittest
@@ -35,7 +35,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": {
"ts": 1436451550453,
"hidden": True,
@@ -56,7 +56,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916hfgh4394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@me:server.org": {
"ts": 1436451550453,
"hidden": True,
@@ -72,7 +72,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916hfgh4394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@me:server.org": {
"ts": 1436451550453,
ReadReceiptEventFields.MSC2285_HIDDEN: True,
@@ -92,7 +92,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": {
"ts": 1436451550453,
"hidden": True,
@@ -111,7 +111,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -130,7 +130,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$14356419edgd14394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": {
"ts": 1436451550453,
"hidden": True,
@@ -138,7 +138,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -153,7 +153,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -171,9 +171,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[
{
"content": {
- "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}},
+ "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -187,9 +187,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[
{
"content": {
- "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}},
+ "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -209,7 +209,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
"content": {
"$143564gdfg6114394fHBLK:matrix.org": {},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -225,7 +225,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
"content": {
"$143564gdfg6114394fHBLK:matrix.org": {},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -244,7 +244,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$14356419edgd14394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": {
"ts": 1436451550453,
"hidden": True,
@@ -258,7 +258,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -273,7 +273,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -297,7 +297,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$14356419edgd14394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": "string",
}
},
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 45fd30cf43..b6ba19c739 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -193,8 +193,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
- # Type ignore: mypy doesn't like us assigning to methods.
- self.store.count_monthly_users = Mock( # type: ignore[assignment]
+ self.store.count_monthly_users = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -202,8 +201,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
- # Type ignore: mypy doesn't like us assigning to methods.
- self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
+ self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
@@ -211,8 +209,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- # Type ignore: mypy doesn't like us assigning to methods.
- self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
+ self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 8d4404eda1..e2f0f90ef1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -349,4 +349,16 @@ class SamlHandlerTestCase(HomeserverTestCase):
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
- return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
+ mock = Mock(
+ spec=[
+ "finish",
+ "getClientIP",
+ "getHeader",
+ "setHeader",
+ "setResponseCode",
+ "write",
+ ]
+ )
+ # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
+ mock._disconnected = False
+ return mock
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index ffd5c4cb93..5f2e26a5fc 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
- mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
+ mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"])
- mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
+ mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock(
- return_value=defer.succeed(None)
+ return_value=make_awaitable(None)
)
self.datastore.get_device_updates_by_remote = Mock(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c6e501c7be..96e2e3039b 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -15,7 +15,6 @@ from typing import Tuple
from unittest.mock import Mock, patch
from urllib.parse import quote
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -30,6 +29,7 @@ from synapse.util import Clock
from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
+from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
@@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
- mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+ mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
@@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.store.register_user(user_id=r_user_id, password_hash=None)
)
- mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+ mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py
index bec018d9e7..89009bea8c 100644
--- a/tests/module_api/test_account_data_manager.py
+++ b/tests/module_api/test_account_data_manager.py
@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import SynapseError
from synapse.rest import admin
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -22,7 +26,9 @@ class ModuleApiTestCase(HomeserverTestCase):
admin.register_servlets,
]
- def prepare(self, reactor, clock, homeserver) -> None:
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self._store = homeserver.get_datastores().main
self._module_api = homeserver.get_module_api()
self._account_data_mgr = self._module_api.account_data_manager
@@ -91,7 +97,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
with self.assertRaises(TypeError):
# This throws an exception because it's a frozen dict.
- the_data["wombat"] = False
+ the_data["wombat"] = False # type: ignore[index]
def test_put_global(self) -> None:
"""
@@ -143,15 +149,14 @@ class ModuleApiTestCase(HomeserverTestCase):
with self.assertRaises(TypeError):
# The account data type must be a string.
self.get_success_or_raise(
- self._module_api.account_data_manager.put_global(
- self.user_id, 42, {} # type: ignore
- )
+ self._module_api.account_data_manager.put_global(self.user_id, 42, {}) # type: ignore[arg-type]
)
with self.assertRaises(TypeError):
# The account data dict must be a dict.
+ # noinspection PyTypeChecker
self.get_success_or_raise(
self._module_api.account_data_manager.put_global(
- self.user_id, "test.data", 42 # type: ignore
+ self.user_id, "test.data", 42 # type: ignore[arg-type]
)
)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9fd5d59c55..8bc84aaaca 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
+from synapse.handlers.push_rules import InvalidRuleException
from synapse.rest import admin
-from synapse.rest.client import login, presence, profile, room
+from synapse.rest.client import login, notifications, presence, profile, room
from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
@@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room.register_servlets,
presence.register_servlets,
profile.register_servlets,
+ notifications.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
@@ -553,6 +555,86 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(state[("org.matrix.test", "")].state_key, "")
self.assertEqual(state[("org.matrix.test", "")].content, {})
+ def test_set_push_rules_action(self) -> None:
+ """Test that a module can change the actions of an existing push rule for a user."""
+
+ # Create a room with 2 users in it. Push rules must not match if the user is the
+ # event's sender, so we need one user to send messages and one user to receive
+ # notifications.
+ user_id = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok)
+
+ user_id2 = self.register_user("user2", "password")
+ tok2 = self.login("user2", "password")
+ self.helper.join(room_id, user_id2, tok=tok2)
+
+ # Register a 3rd user and join them to the room, so that we don't accidentally
+ # trigger 1:1 push rules.
+ user_id3 = self.register_user("user3", "password")
+ tok3 = self.login("user3", "password")
+ self.helper.join(room_id, user_id3, tok=tok3)
+
+ # Send a message as the second user and check that it notifies.
+ res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2)
+ event_id = res["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/notifications",
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
+ self.assertEqual(
+ channel.json_body["notifications"][0]["event"]["event_id"],
+ event_id,
+ channel.json_body,
+ )
+
+ # Change the .m.rule.message actions to not notify on new messages.
+ self.get_success(
+ defer.ensureDeferred(
+ self.module_api.set_push_rule_action(
+ user_id=user_id,
+ scope="global",
+ kind="underride",
+ rule_id=".m.rule.message",
+ actions=["dont_notify"],
+ )
+ )
+ )
+
+ # Send another message as the second user and check that the number of
+ # notifications didn't change.
+ self.helper.send(room_id=room_id, body="here's another message", tok=tok2)
+
+ channel = self.make_request(
+ "GET",
+ "/notifications?from=",
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
+
+ def test_check_push_rules_actions(self) -> None:
+ """Test that modules can check whether a list of push rules actions are spec
+ compliant.
+ """
+ with self.assertRaises(InvalidRuleException):
+ self.module_api.check_push_rule_actions(["foo"])
+
+ with self.assertRaises(InvalidRuleException):
+ self.module_api.check_push_rule_actions({"foo": "bar"})
+
+ self.module_api.check_push_rule_actions(["notify"])
+
+ self.module_api.check_push_rule_actions(
+ [{"set_tweak": "sound", "value": "default"}]
+ )
+
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index f47d94f690..de19e75b9d 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from ._base import BaseSlavedStoreTestCase
@@ -26,9 +27,13 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedReceiptsStore
def test_receipt(self):
- self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+ self.check("get_receipts_for_user", [USER_ID, ReceiptTypes.READ], {})
self.get_success(
- self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
+ self.master_store.insert_receipt(
+ ROOM_ID, ReceiptTypes.READ, USER_ID, [EVENT_ID], {}
+ )
)
self.replicate()
- self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})
+ self.check(
+ "get_receipts_for_user", [USER_ID, ReceiptTypes.READ], {ROOM_ID: EVENT_ID}
+ )
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index ba1a63c0d6..6104a55aa1 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -102,8 +102,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
- mock_client1.reset_mock() # type: ignore[attr-defined]
- mock_client2.reset_mock() # type: ignore[attr-defined]
+ mock_client1.reset_mock()
+ mock_client2.reset_mock()
self.create_and_send_event(room, UserID.from_string(user))
self.replicate()
@@ -167,8 +167,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
- mock_client1.reset_mock() # type: ignore[attr-defined]
- mock_client2.reset_mock() # type: ignore[attr-defined]
+ mock_client1.reset_mock()
+ mock_client2.reset_mock()
self.get_success(
typing_handler.started_typing(
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 0abe378fe4..b3738a0304 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -14,7 +14,6 @@
from http import HTTPStatus
from unittest.mock import Mock
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.presence import PresenceHandler
@@ -24,6 +23,7 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
+from tests.test_utils import make_awaitable
class PresenceTestCase(unittest.HomeserverTestCase):
@@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
presence_handler = Mock(spec=PresenceHandler)
- presence_handler.set_state.return_value = defer.succeed(None)
+ presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver(
"red",
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 6ff79b9e2e..9443daa056 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
def test_simple(self) -> None:
"Simple test for searching rooms over federation"
- self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
- {}
- )
+ self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
search_filter = {"generic_search_term": "foobar"}
@@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""),
- defer.succeed({}),
+ make_awaitable({}),
)
search_filter = {"generic_search_term": "foobar"}
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 773c16a54c..cb765455c1 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
ReadReceiptEventFields,
+ ReceiptTypes,
RelationTypes,
)
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
@@ -560,7 +561,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
- body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8")
channel = self.make_request(
"POST",
"/rooms/%s/read_markers" % self.room_id,
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 8d8251b2ac..21a1ca2a68 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionC
from synapse.util import Clock
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import MockClock
@@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_executes_given_function(self):
- cb = Mock(return_value=defer.succeed(self.mock_http_response))
+ cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
)
@@ -47,7 +48,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
- cb = Mock(return_value=defer.succeed(self.mock_http_response))
+ cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
@@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cleans_up(self):
- cb = Mock(return_value=defer.succeed(self.mock_http_response))
+ cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 02b96c9e6e..9ee9509d3a 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -14,8 +14,6 @@
from unittest.mock import Mock
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
from synapse.rest import admin
@@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock(
- return_value=defer.succeed(Mock())
+ return_value=make_awaitable(Mock())
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
- return_value=defer.succeed("!something:localhost")
+ return_value=make_awaitable("!something:localhost")
)
- self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
+ self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@override_config({"hs_disabled": True})
@@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
@@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
+ return_value=make_awaitable(None),
+ side_effect=ResourceLimitError(403, "foo"),
)
mock_event = Mock(
@@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
+ return_value=make_awaitable(None),
+ side_effect=ResourceLimitError(403, "foo"),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None)
)
@@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None),
+ return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
@@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None),
+ return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
),
@@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
is suppressed that the room is returned to an unblocked state.
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None),
+ return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
)
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
- return_value=defer.succeed((True, []))
+ return_value=make_awaitable((True, []))
)
mock_event = Mock(
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 60c8d37594..0fbf465670 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -14,7 +14,6 @@
from typing import Any, Dict, List
from unittest.mock import Mock
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
@@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_update(self):
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
- self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment]
+ self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(None)
+ return_value=make_awaitable(None)
)
d = self.store.populate_monthly_active_users("user_id")
self.get_success(d)
@@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_not_update(self):
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
- self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment]
+ self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
d = self.store.populate_monthly_active_users("user_id")
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c39816de85..0cbef70bfa 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock(
- return_value=succeed(
+ return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index f05a373aa0..0d0d6faf0d 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -52,7 +52,7 @@ def make_awaitable(result: TV) -> Awaitable[TV]:
This uses Futures as they can be awaited multiple times so can be returned
to multiple callers.
"""
- future = Future() # type: ignore
+ future: Future[TV] = Future()
future.set_result(result)
return future
@@ -69,7 +69,7 @@ def setup_awaitable_errors() -> Callable[[], None]:
# State shared between unraisablehook and check_for_unraisable_exceptions.
unraisable_exceptions = []
- orig_unraisablehook = sys.unraisablehook # type: ignore
+ orig_unraisablehook = sys.unraisablehook
def unraisablehook(unraisable):
unraisable_exceptions.append(unraisable.exc_value)
@@ -78,11 +78,11 @@ def setup_awaitable_errors() -> Callable[[], None]:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
- sys.unraisablehook = orig_unraisablehook # type: ignore
+ sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
raise unraisable_exceptions.pop()
- sys.unraisablehook = unraisablehook # type: ignore
+ sys.unraisablehook = unraisablehook
return cleanup
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 51a197a8c6..9228454c9e 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
- self.tx_log.emit( # type: ignore
+ self.tx_log.emit(
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
|