diff options
108 files changed, 1090 insertions, 517 deletions
diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml index 1a61d179d9..c537a5a60f 100644 --- a/.github/workflows/latest_deps.yml +++ b/.github/workflows/latest_deps.yml @@ -32,12 +32,15 @@ jobs: with: python-version: "3.x" poetry-version: "1.2.0b1" + extras: "all" # Dump installed versions for debugging. - run: poetry run pip list > before.txt # Upgrade all runtime dependencies only. This is intended to mimic a fresh # `pip install matrix-synapse[all]` as closely as possible. - run: poetry update --no-dev - run: poetry run pip list > after.txt && (diff -u before.txt after.txt || true) + - name: Remove warn_unused_ignores from mypy config + run: sed '/warn_unused_ignores = True/d' -i mypy.ini - run: poetry run mypy trial: runs-on: ubuntu-latest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cad4cb6d77..efa35b71df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,13 +20,9 @@ jobs: - run: scripts-dev/config-lint.sh lint: - # This does a vanilla `poetry install` - no extras. I'm slightly anxious - # that we might skip some typechecks on code that uses extras. However, - # I think the right way to fix this is to mark any extras needed for - # typechecking as development dependencies. To detect this, we ought to - # turn up mypy's strictness: disallow unknown imports and be accept fewer - # uses of `Any`. uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v1" + with: + typechecking-extras: "all" lint-crlf: runs-on: ubuntu-latest diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 8fc1affb77..5f0671f350 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -24,6 +24,8 @@ jobs: poetry remove twisted poetry add --extras tls git+https://github.com/twisted/twisted.git#trunk poetry install --no-interaction --extras "all test" + - name: Remove warn_unused_ignores from mypy config + run: sed '/warn_unused_ignores = True/d' -i mypy.ini - run: poetry run mypy trial: diff --git a/README.rst b/README.rst index d71d733679..80201d4eb1 100644 --- a/README.rst +++ b/README.rst @@ -296,8 +296,8 @@ directory of your choice:: Synapse has a number of external dependencies. We maintain a fixed development environment using [poetry](https://python-poetry.org/). First, install poetry. We recommend - pip install --user pipx - pipx install poetry + | pip install --user pipx + | pipx install poetry as described `here <https://python-poetry.org/docs/#installing-with-pipx>`_. (See `poetry's installation docs <https://python-poetry.org/docs/#installation>` diff --git a/changelog.d/12356.misc b/changelog.d/12356.misc new file mode 100644 index 0000000000..43e1929106 --- /dev/null +++ b/changelog.d/12356.misc @@ -0,0 +1 @@ +Fix scripts-dev to pass typechecking. \ No newline at end of file diff --git a/changelog.d/12406.feature b/changelog.d/12406.feature new file mode 100644 index 0000000000..e345afdee7 --- /dev/null +++ b/changelog.d/12406.feature @@ -0,0 +1 @@ +Add a module API to allow modules to change actions for existing push rules of local users. diff --git a/changelog.d/12480.misc b/changelog.d/12480.misc new file mode 100644 index 0000000000..18a85e7b15 --- /dev/null +++ b/changelog.d/12480.misc @@ -0,0 +1 @@ +Use supervisord to supervise Postgres and Caddy in the Complement image to reduce restart time. \ No newline at end of file diff --git a/changelog.d/12485.misc b/changelog.d/12485.misc new file mode 100644 index 0000000000..e793d08e5e --- /dev/null +++ b/changelog.d/12485.misc @@ -0,0 +1 @@ +Add some type hints to datastore. \ No newline at end of file diff --git a/changelog.d/12505.misc b/changelog.d/12505.misc new file mode 100644 index 0000000000..a691d7962f --- /dev/null +++ b/changelog.d/12505.misc @@ -0,0 +1 @@ +Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests. diff --git a/changelog.d/12526.feature b/changelog.d/12526.feature new file mode 100644 index 0000000000..c01596282c --- /dev/null +++ b/changelog.d/12526.feature @@ -0,0 +1 @@ +Add new `enable_registration_token_3pid_bypass` configuration option to allow registrations via token as an alternative to verifying a 3pid. \ No newline at end of file diff --git a/changelog.d/12531.misc b/changelog.d/12531.misc new file mode 100644 index 0000000000..412fc9b6dc --- /dev/null +++ b/changelog.d/12531.misc @@ -0,0 +1 @@ +Remove unused `# type: ignore`s. diff --git a/changelog.d/12541.docker b/changelog.d/12541.docker new file mode 100644 index 0000000000..c3b9c31657 --- /dev/null +++ b/changelog.d/12541.docker @@ -0,0 +1 @@ +Explicitly opt-in to using [BuildKit-specific features](https://github.com/moby/buildkit/blob/master/frontend/dockerfile/docs/syntax.md) in the Dockerfile. This fixes issues with building images in some GitLab CI environments. diff --git a/changelog.d/12544.bugfix b/changelog.d/12544.bugfix new file mode 100644 index 0000000000..b5169cd831 --- /dev/null +++ b/changelog.d/12544.bugfix @@ -0,0 +1 @@ +Fix a bug where attempting to send a large amount of read receipts to an application service all at once would result in duplicate content and abnormally high memory usage. Contributed by Brad & Nick @ Beeper. diff --git a/changelog.d/12556.misc b/changelog.d/12556.misc new file mode 100644 index 0000000000..dc245397fb --- /dev/null +++ b/changelog.d/12556.misc @@ -0,0 +1 @@ +Release script: confirm the commit to be tagged before tagging. diff --git a/changelog.d/12564.misc b/changelog.d/12564.misc new file mode 100644 index 0000000000..207c322464 --- /dev/null +++ b/changelog.d/12564.misc @@ -0,0 +1 @@ +Consistently check if an object is a `frozendict`. diff --git a/changelog.d/12576.misc b/changelog.d/12576.misc new file mode 100644 index 0000000000..71022c8633 --- /dev/null +++ b/changelog.d/12576.misc @@ -0,0 +1 @@ +Allow unused `#type: ignore` comments in bleeding edge CI jobs. diff --git a/changelog.d/12579.doc b/changelog.d/12579.doc new file mode 100644 index 0000000000..bcec5fe1af --- /dev/null +++ b/changelog.d/12579.doc @@ -0,0 +1 @@ +Add missing linebreak to pipx install instructions. diff --git a/changelog.d/12580.bugfix b/changelog.d/12580.bugfix new file mode 100644 index 0000000000..bedce405e2 --- /dev/null +++ b/changelog.d/12580.bugfix @@ -0,0 +1 @@ +Fix a long standing bug where status codes would almost always get logged as 200!, irrespective of the actual status code, when clients disconnect before a request has finished processing. diff --git a/changelog.d/12581.misc b/changelog.d/12581.misc new file mode 100644 index 0000000000..38d40b262b --- /dev/null +++ b/changelog.d/12581.misc @@ -0,0 +1 @@ +Improve docstrings for the receipts store. diff --git a/changelog.d/12582.misc b/changelog.d/12582.misc new file mode 100644 index 0000000000..5fa9c9afe8 --- /dev/null +++ b/changelog.d/12582.misc @@ -0,0 +1 @@ +Use constants for read-receipts in tests. diff --git a/changelog.d/12589.misc b/changelog.d/12589.misc new file mode 100644 index 0000000000..d362828d2e --- /dev/null +++ b/changelog.d/12589.misc @@ -0,0 +1 @@ +Remove special-case for `twisted` logger from default log config. diff --git a/changelog.d/12594.bugfix b/changelog.d/12594.bugfix new file mode 100644 index 0000000000..7411d9c079 --- /dev/null +++ b/changelog.d/12594.bugfix @@ -0,0 +1 @@ +Fix race when persisting an event and deleting a room that could lead to outbound federation breaking. diff --git a/changelog.d/12608.misc b/changelog.d/12608.misc new file mode 100644 index 0000000000..38272118fb --- /dev/null +++ b/changelog.d/12608.misc @@ -0,0 +1 @@ +Remove redundant lines of config from `mypy.ini`. \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index 4523c60645..ccc6a9f778 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,3 +1,4 @@ +# syntax=docker/dockerfile:1 # Dockerfile to build the matrixdotorg/synapse docker images. # # Note that it uses features which are only available in BuildKit - see diff --git a/docker/Dockerfile-workers b/docker/Dockerfile-workers index 9ccb2b22a7..24b03585f9 100644 --- a/docker/Dockerfile-workers +++ b/docker/Dockerfile-workers @@ -20,6 +20,9 @@ RUN rm /etc/nginx/sites-enabled/default # Copy Synapse worker, nginx and supervisord configuration template files COPY ./docker/conf-workers/* /conf/ +# Copy a script to prefix log lines with the supervisor program name +COPY ./docker/prefix-log /usr/local/bin/ + # Expose nginx listener port EXPOSE 8080/tcp diff --git a/docker/complement/SynapseWorkers.Dockerfile b/docker/complement/SynapseWorkers.Dockerfile index 65df2d114d..9a4438e730 100644 --- a/docker/complement/SynapseWorkers.Dockerfile +++ b/docker/complement/SynapseWorkers.Dockerfile @@ -34,13 +34,16 @@ WORKDIR /data # Copy the caddy config COPY conf-workers/caddy.complement.json /root/caddy.json +COPY conf-workers/postgres.supervisord.conf /etc/supervisor/conf.d/postgres.conf +COPY conf-workers/caddy.supervisord.conf /etc/supervisor/conf.d/caddy.conf + # Copy the entrypoint COPY conf-workers/start-complement-synapse-workers.sh / # Expose caddy's listener ports EXPOSE 8008 8448 -ENTRYPOINT /start-complement-synapse-workers.sh +ENTRYPOINT ["/start-complement-synapse-workers.sh"] # Update the healthcheck to have a shorter check interval HEALTHCHECK --start-period=5s --interval=1s --timeout=1s \ diff --git a/docker/complement/conf-workers/caddy.supervisord.conf b/docker/complement/conf-workers/caddy.supervisord.conf new file mode 100644 index 0000000000..d9ddb51dac --- /dev/null +++ b/docker/complement/conf-workers/caddy.supervisord.conf @@ -0,0 +1,7 @@ +[program:caddy] +command=/usr/local/bin/prefix-log /root/caddy run --config /root/caddy.json +autorestart=unexpected +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 diff --git a/docker/complement/conf-workers/postgres.supervisord.conf b/docker/complement/conf-workers/postgres.supervisord.conf new file mode 100644 index 0000000000..5608342d1a --- /dev/null +++ b/docker/complement/conf-workers/postgres.supervisord.conf @@ -0,0 +1,16 @@ +[program:postgres] +command=/usr/local/bin/prefix-log /usr/bin/pg_ctlcluster 13 main start --foreground + +# Lower priority number = starts first +priority=1 + +autorestart=unexpected +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 + +# Use 'Fast Shutdown' mode which aborts current transactions and closes connections quickly. +# (Default (TERM) is 'Smart Shutdown' which stops accepting new connections but +# lets existing connections close gracefully.) +stopsignal=INT diff --git a/docker/complement/conf-workers/start-complement-synapse-workers.sh b/docker/complement/conf-workers/start-complement-synapse-workers.sh index 2c1e05bd62..b9a6b55bbe 100755 --- a/docker/complement/conf-workers/start-complement-synapse-workers.sh +++ b/docker/complement/conf-workers/start-complement-synapse-workers.sh @@ -12,12 +12,6 @@ function log { # Replace the server name in the caddy config sed -i "s/{{ server_name }}/${SERVER_NAME}/g" /root/caddy.json -log "starting postgres" -pg_ctlcluster 13 main start - -log "starting caddy" -/root/caddy start --config /root/caddy.json - # Set the server name of the homeserver export SYNAPSE_SERVER_NAME=${SERVER_NAME} diff --git a/docker/conf/log.config b/docker/conf/log.config index 7a216a36a0..dc8c70befd 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -2,11 +2,7 @@ version: 1 formatters: precise: -{% if worker_name %} - format: '%(asctime)s - worker:{{ worker_name }} - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' -{% else %} format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' -{% endif %} handlers: {% if LOG_FILE_PATH %} diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 3bda6c300b..33fc20d218 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -171,7 +171,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { # Templates for sections that may be inserted multiple times in config files SUPERVISORD_PROCESS_CONFIG_BLOCK = """ [program:synapse_{name}] -command=/usr/local/bin/python -m {app} \ +command=/usr/local/bin/prefix-log /usr/local/bin/python -m {app} \ --config-path="{config_path}" \ --config-path=/conf/workers/shared.yaml \ --config-path=/conf/workers/{name}.yaml diff --git a/docker/prefix-log b/docker/prefix-log new file mode 100755 index 0000000000..0e26a4f19d --- /dev/null +++ b/docker/prefix-log @@ -0,0 +1,12 @@ +#!/bin/bash +# +# Prefixes all lines on stdout and stderr with the process name (as determined by +# the SUPERVISOR_PROCESS_NAME env var, which is automatically set by Supervisor). +# +# Usage: +# prefix-log command [args...] +# + +exec 1> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&1) +exec 2> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&2) +exec "$@" diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b8d8c0dbf0..67184c6b1a 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1323,6 +1323,12 @@ oembed: # #registration_requires_token: true +# Allow users to submit a token during registration to bypass any required 3pid +# steps configured in `registrations_require_3pid`. +# Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow. +# +#enable_registration_token_3pid_bypass: false + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 2485ad25ed..3065a0e2d9 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -62,13 +62,6 @@ loggers: # information such as access tokens. level: INFO - twisted: - # We send the twisted logging directly to the file handler, - # to work around https://github.com/matrix-org/synapse/issues/3471 - # when using "buffer" logger. Use "console" to log to stderr instead. - handlers: [file] - propagate: false - root: level: INFO diff --git a/mypy.ini b/mypy.ini index a663bf6975..78699e3187 100644 --- a/mypy.ini +++ b/mypy.ini @@ -7,6 +7,7 @@ show_error_codes = True show_traceback = True mypy_path = stubs warn_unreachable = True +warn_unused_ignores = True local_partial_types = True no_implicit_optional = True @@ -23,10 +24,6 @@ files = # https://docs.python.org/3/library/re.html#re.X exclude = (?x) ^( - |scripts-dev/build_debian_packages.py - |scripts-dev/federation_client.py - |scripts-dev/release.py - |synapse/storage/databases/__init__.py |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py @@ -134,6 +131,11 @@ disallow_untyped_defs = True [mypy-synapse.metrics.*] disallow_untyped_defs = True +[mypy-synapse.metrics._reactor_metrics] +# This module imports select.epoll. That exists on Linux, but doesn't on macOS. +# See https://github.com/matrix-org/synapse/pull/11771. +warn_unused_ignores = False + [mypy-synapse.module_api.*] disallow_untyped_defs = True @@ -239,63 +241,29 @@ disallow_untyped_defs = True [mypy-authlib.*] ignore_missing_imports = True -[mypy-bcrypt] -ignore_missing_imports = True - [mypy-canonicaljson] ignore_missing_imports = True [mypy-constantly] ignore_missing_imports = True -[mypy-daemonize] -ignore_missing_imports = True - -[mypy-h11] -ignore_missing_imports = True - -[mypy-hiredis] -ignore_missing_imports = True - -[mypy-hyperlink] -ignore_missing_imports = True - [mypy-ijson.*] ignore_missing_imports = True -[mypy-importlib_metadata.*] -ignore_missing_imports = True - -[mypy-jaeger_client.*] -ignore_missing_imports = True - -[mypy-josepy.*] -ignore_missing_imports = True - -[mypy-jwt.*] -ignore_missing_imports = True - [mypy-lxml] ignore_missing_imports = True [mypy-msgpack] ignore_missing_imports = True -[mypy-nacl.*] -ignore_missing_imports = True - +# Note: WIP stubs available at +# https://github.com/microsoft/python-type-stubs/tree/64934207f523ad6b611e6cfe039d85d7175d7d0d/netaddr [mypy-netaddr] ignore_missing_imports = True [mypy-parameterized.*] ignore_missing_imports = True -[mypy-phonenumbers.*] -ignore_missing_imports = True - -[mypy-prometheus_client.*] -ignore_missing_imports = True - [mypy-pymacaroons.*] ignore_missing_imports = True @@ -308,23 +276,14 @@ ignore_missing_imports = True [mypy-saml2.*] ignore_missing_imports = True -[mypy-sentry_sdk] -ignore_missing_imports = True - [mypy-service_identity.*] ignore_missing_imports = True -[mypy-signedjson.*] +[mypy-srvlookup.*] ignore_missing_imports = True [mypy-treq.*] ignore_missing_imports = True -[mypy-twisted.*] -ignore_missing_imports = True - -[mypy-zope] -ignore_missing_imports = True - [mypy-incremental.*] ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index 8c7af1fa1e..e27a44989c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -309,14 +309,15 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.14" -description = "Python Git Library" +version = "3.1.27" +description = "GitPython is a python library used to interact with Git repositories" category = "dev" optional = false -python-versions = ">=3.4" +python-versions = ">=3.7" [package.dependencies] gitdb = ">=4.0.1,<5" +typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\""} [[package]] name = "hiredis" @@ -1316,6 +1317,14 @@ optional = false python-versions = "*" [[package]] +name = "types-commonmark" +version = "0.9.2" +description = "Typing stubs for commonmark" +category = "dev" +optional = false +python-versions = "*" + +[[package]] name = "types-cryptography" version = "3.3.15" description = "Typing stubs for cryptography" @@ -1553,7 +1562,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "f482a4f594a165dfe01ce253a22510d5faf38647ab0dcebc35789350cafd9bf0" +content-hash = "3825cef058b8c9f520ef4b7acb92519be95db9a663a61c2e89a5fe431ed55655" [metadata.files] attrs = [ @@ -1766,8 +1775,8 @@ gitdb = [ {file = "gitdb-4.0.9.tar.gz", hash = "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa"}, ] gitpython = [ - {file = "GitPython-3.1.14-py3-none-any.whl", hash = "sha256:3283ae2fba31c913d857e12e5ba5f9a7772bbc064ae2bb09efafa71b0dd4939b"}, - {file = "GitPython-3.1.14.tar.gz", hash = "sha256:be27633e7509e58391f10207cd32b2a6cf5b908f92d9cd30da2e514e1137af61"}, + {file = "GitPython-3.1.27-py3-none-any.whl", hash = "sha256:5b68b000463593e05ff2b261acff0ff0972df8ab1b70d3cdbd41b546c8b8fc3d"}, + {file = "GitPython-3.1.27.tar.gz", hash = "sha256:1c885ce809e8ba2d88a29befeb385fcea06338d3640712b59ca623c220bb5704"}, ] hiredis = [ {file = "hiredis-2.0.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b4c8b0bc5841e578d5fb32a16e0c305359b987b850a06964bd5a62739d688048"}, @@ -2588,6 +2597,10 @@ types-bleach = [ {file = "types-bleach-4.1.4.tar.gz", hash = "sha256:2d30c2c4fb6854088ac636471352c9a51bf6c089289800d2a8060820a01cd43a"}, {file = "types_bleach-4.1.4-py3-none-any.whl", hash = "sha256:edffe173ed6d7b6f3543036a96204a9319c3bf6c3645917b14274e43f000cc9b"}, ] +types-commonmark = [ + {file = "types-commonmark-0.9.2.tar.gz", hash = "sha256:b894b67750c52fd5abc9a40a9ceb9da4652a391d75c1b480bba9cef90f19fc86"}, + {file = "types_commonmark-0.9.2-py3-none-any.whl", hash = "sha256:56f20199a1f9a2924443211a0ef97f8b15a8a956a7f4e9186be6950bf38d6d02"}, +] types-cryptography = [ {file = "types-cryptography-3.3.15.tar.gz", hash = "sha256:a7983a75a7b88a18f88832008f0ef140b8d1097888ec1a0824ec8fb7e105273b"}, {file = "types_cryptography-3.3.15-py3-none-any.whl", hash = "sha256:d9b0dd5465d7898d400850e7f35e5518aa93a7e23d3e11757cd81b4777089046"}, diff --git a/pyproject.toml b/pyproject.toml index f0f029f016..62e26fd95b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -251,6 +251,7 @@ flake8 = "*" mypy = "==0.931" mypy-zope = "==0.3.5" types-bleach = ">=4.1.0" +types-commonmark = ">=0.9.2" types-jsonschema = ">=3.2.0" types-opentracing = ">=2.4.2" types-Pillow = ">=8.3.4" @@ -270,7 +271,8 @@ idna = ">=2.5" # The following are used by the release script click = "==8.1.0" -GitPython = "==3.1.14" +# GitPython was == 3.1.14; bumped to 3.1.20, the first release with type hints. +GitPython = ">=3.1.20" commonmark = "==0.9.1" pygithub = "==1.55" # The following are executed as commands by the release script. diff --git a/scripts-dev/build_debian_packages.py b/scripts-dev/build_debian_packages.py index e3e6878686..38564893e9 100755 --- a/scripts-dev/build_debian_packages.py +++ b/scripts-dev/build_debian_packages.py @@ -17,7 +17,8 @@ import subprocess import sys import threading from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Sequence +from types import FrameType +from typing import Collection, Optional, Sequence, Set DISTS = ( "debian:buster", # oldstable: EOL 2022-08 @@ -41,15 +42,17 @@ projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) class Builder(object): def __init__( - self, redirect_stdout=False, docker_build_args: Optional[Sequence[str]] = None + self, + redirect_stdout: bool = False, + docker_build_args: Optional[Sequence[str]] = None, ): self.redirect_stdout = redirect_stdout self._docker_build_args = tuple(docker_build_args or ()) - self.active_containers = set() + self.active_containers: Set[str] = set() self._lock = threading.Lock() self._failed = False - def run_build(self, dist, skip_tests=False): + def run_build(self, dist: str, skip_tests: bool = False) -> None: """Build deb for a single distribution""" if self._failed: @@ -63,7 +66,7 @@ class Builder(object): self._failed = True raise - def _inner_build(self, dist, skip_tests=False): + def _inner_build(self, dist: str, skip_tests: bool = False) -> None: tag = dist.split(":", 1)[1] # Make the dir where the debs will live. @@ -138,7 +141,7 @@ class Builder(object): stdout.close() print("Completed build of %s" % (dist,)) - def kill_containers(self): + def kill_containers(self) -> None: with self._lock: active = list(self.active_containers) @@ -156,8 +159,10 @@ class Builder(object): self.active_containers.remove(c) -def run_builds(builder, dists, jobs=1, skip_tests=False): - def sig(signum, _frame): +def run_builds( + builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False +) -> None: + def sig(signum: int, _frame: Optional[FrameType]) -> None: print("Caught SIGINT") builder.kill_containers() diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 079d2f5ed0..763dd02c47 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -38,7 +38,7 @@ import argparse import base64 import json import sys -from typing import Any, Optional +from typing import Any, Dict, Optional, Tuple from urllib import parse as urlparse import requests @@ -47,13 +47,14 @@ import signedjson.types import srvlookup import yaml from requests.adapters import HTTPAdapter +from urllib3 import HTTPConnectionPool # uncomment the following to enable debug logging of http requests # from httplib import HTTPConnection # HTTPConnection.debuglevel = 1 -def encode_base64(input_bytes): +def encode_base64(input_bytes: bytes) -> str: """Encode bytes as a base64 string without any padding.""" input_len = len(input_bytes) @@ -63,7 +64,7 @@ def encode_base64(input_bytes): return output_string -def encode_canonical_json(value): +def encode_canonical_json(value: object) -> bytes: return json.dumps( value, # Encode code-points outside of ASCII as UTF-8 rather than \u escapes @@ -130,7 +131,7 @@ def request( sig, destination, ) - authorization_headers.append(header.encode("ascii")) + authorization_headers.append(header) print("Authorization: %s" % header, file=sys.stderr) dest = "matrix://%s%s" % (destination, path) @@ -139,7 +140,10 @@ def request( s = requests.Session() s.mount("matrix://", MatrixConnectionAdapter()) - headers = {"Host": destination, "Authorization": authorization_headers[0]} + headers: Dict[str, str] = { + "Host": destination, + "Authorization": authorization_headers[0], + } if method == "POST": headers["Content-Type"] = "application/json" @@ -154,7 +158,7 @@ def request( ) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Signs and sends a federation request to a matrix homeserver" ) @@ -212,6 +216,7 @@ def main(): if not args.server_name or not args.signing_key: read_args_from_config(args) + assert isinstance(args.signing_key, str) algorithm, version, key_base64 = args.signing_key.split() key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64) @@ -233,7 +238,7 @@ def main(): print("") -def read_args_from_config(args): +def read_args_from_config(args: argparse.Namespace) -> None: with open(args.config, "r") as fh: config = yaml.safe_load(fh) @@ -250,7 +255,7 @@ def read_args_from_config(args): class MatrixConnectionAdapter(HTTPAdapter): @staticmethod - def lookup(s, skip_well_known=False): + def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]: if s[-1] == "]": # ipv6 literal (with no port) return s, 8448 @@ -276,7 +281,7 @@ class MatrixConnectionAdapter(HTTPAdapter): return s, 8448 @staticmethod - def get_well_known(server_name): + def get_well_known(server_name: str) -> Optional[str]: uri = "https://%s/.well-known/matrix/server" % (server_name,) print("fetching %s" % (uri,), file=sys.stderr) @@ -299,7 +304,9 @@ class MatrixConnectionAdapter(HTTPAdapter): print("Invalid response from %s: %s" % (uri, e), file=sys.stderr) return None - def get_connection(self, url, proxies=None): + def get_connection( + self, url: str, proxies: Optional[Dict[str, str]] = None + ) -> HTTPConnectionPool: parsed = urlparse.urlparse(url) (host, port) = self.lookup(parsed.netloc) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 1217e14874..c775865212 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -16,7 +16,7 @@ can crop up, e.g the cache descriptors. """ -from typing import Callable, Optional +from typing import Callable, Optional, Type from mypy.nodes import ARG_NAMED_OPT from mypy.plugin import MethodSigContext, Plugin @@ -94,7 +94,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: return signature -def plugin(version: str): +def plugin(version: str) -> Type[SynapsePlugin]: # This is the entry point of the plugin, and let's us deal with the fact # that the mypy plugin interface is *not* stable by looking at the version # string. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 9d7c7c445f..f4269e09bb 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -25,7 +25,7 @@ import sys import urllib.request from os import path from tempfile import TemporaryDirectory -from typing import List, Optional +from typing import Any, List, Optional, cast import attr import click @@ -36,7 +36,9 @@ from github import Github from packaging import version -def run_until_successful(command, *args, **kwargs): +def run_until_successful( + command: str, *args: Any, **kwargs: Any +) -> subprocess.CompletedProcess: while True: completed_process = subprocess.run(command, *args, **kwargs) exit_code = completed_process.returncode @@ -50,7 +52,7 @@ def run_until_successful(command, *args, **kwargs): @click.group() -def cli(): +def cli() -> None: """An interactive script to walk through the parts of creating a release. Requires the dev dependencies be installed, which can be done via: @@ -81,19 +83,13 @@ def cli(): @cli.command() -def prepare(): +def prepare() -> None: """Do the initial stages of creating a release, including creating release branch, updating changelog and pushing to GitHub. """ # Make sure we're in a git repo. - try: - repo = git.Repo() - except git.InvalidGitRepositoryError: - raise click.ClickException("Not in Synapse repo.") - - if repo.is_dirty(): - raise click.ClickException("Uncommitted changes exist.") + repo = get_repo_and_check_clean_checkout() click.secho("Updating git repo...") repo.remote().fetch() @@ -161,22 +157,21 @@ def prepare(): click.get_current_context().abort() # Switch to the release branch. - parsed_new_version: version.Version = version.parse(new_version) + # Cast safety: parse() won't return a version.LegacyVersion from our + # version string format. + parsed_new_version = cast(version.Version, version.parse(new_version)) # We assume for debian changelogs that we only do RCs or full releases. assert not parsed_new_version.is_devrelease assert not parsed_new_version.is_postrelease - release_branch_name = ( - f"release-v{parsed_new_version.major}.{parsed_new_version.minor}" - ) + release_branch_name = get_release_branch_name(parsed_new_version) release_branch = find_ref(repo, release_branch_name) if release_branch: if release_branch.is_remote(): # If the release branch only exists on the remote we check it out # locally. repo.git.checkout(release_branch_name) - release_branch = repo.active_branch else: # If a branch doesn't exist we create one. We ask which one branch it # should be based off, defaulting to sensible values depending on the @@ -198,13 +193,15 @@ def prepare(): click.get_current_context().abort() # Check out the base branch and ensure it's up to date - repo.head.reference = base_branch + repo.head.set_reference(base_branch, "check out the base branch") repo.head.reset(index=True, working_tree=True) if not base_branch.is_remote(): update_branch(repo) # Create the new release branch - release_branch = repo.create_head(release_branch_name, commit=base_branch) + # Type ignore will no longer be needed after GitPython 3.1.28. + # See https://github.com/gitpython-developers/GitPython/pull/1419 + repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type] # Switch to the release branch and ensure it's up to date. repo.git.checkout(release_branch_name) @@ -265,17 +262,11 @@ def prepare(): @cli.command() @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"]) -def tag(gh_token: Optional[str]): +def tag(gh_token: Optional[str]) -> None: """Tags the release and generates a draft GitHub release""" # Make sure we're in a git repo. - try: - repo = git.Repo() - except git.InvalidGitRepositoryError: - raise click.ClickException("Not in Synapse repo.") - - if repo.is_dirty(): - raise click.ClickException("Uncommitted changes exist.") + repo = get_repo_and_check_clean_checkout() click.secho("Updating git repo...") repo.remote().fetch() @@ -288,12 +279,26 @@ def tag(gh_token: Optional[str]): if tag_name in repo.tags: raise click.ClickException(f"Tag {tag_name} already exists!\n") + # Check we're on the right release branch + release_branch = get_release_branch_name(current_version) + if repo.active_branch.name != release_branch: + click.echo( + f"Need to be on the release branch ({release_branch}) before tagging. " + f"Currently on ({repo.active_branch.name})." + ) + click.get_current_context().abort() + # Get the appropriate changelogs and tag. changes = get_changes_for_version(current_version) click.echo_via_pager(changes) if click.confirm("Edit text?", default=False): - changes = click.edit(changes, require_save=False) + edited_changes = click.edit(changes, require_save=False) + # This assert is for mypy's benefit. click's docs are a little unclear, but + # when `require_save=False`, not saving the temp file in the editor returns + # the original string. + assert edited_changes is not None + changes = edited_changes repo.create_tag(tag_name, message=changes, sign=True) @@ -347,22 +352,16 @@ def tag(gh_token: Optional[str]): @cli.command() @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=True) -def publish(gh_token: str): - """Publish release.""" +def publish(gh_token: str) -> None: + """Publish release on GitHub.""" # Make sure we're in a git repo. - try: - repo = git.Repo() - except git.InvalidGitRepositoryError: - raise click.ClickException("Not in Synapse repo.") - - if repo.is_dirty(): - raise click.ClickException("Uncommitted changes exist.") + get_repo_and_check_clean_checkout() current_version = get_package_version() tag_name = f"v{current_version}" - if not click.confirm(f"Publish {tag_name}?", default=True): + if not click.confirm(f"Publish release {tag_name} on GitHub?", default=True): return # Publish the draft release @@ -390,12 +389,19 @@ def publish(gh_token: str): @cli.command() -def upload(): +def upload() -> None: """Upload release to pypi.""" current_version = get_package_version() tag_name = f"v{current_version}" + # Check we have the right tag checked out. + repo = get_repo_and_check_clean_checkout() + tag = repo.tag(f"refs/tags/{tag_name}") + if repo.head.commit != tag.commit: + click.echo("Tag {tag_name} (tag.commit) is not currently checked out!") + click.get_current_context().abort() + pypi_asset_names = [ f"matrix_synapse-{current_version}-py3-none-any.whl", f"matrix-synapse-{current_version}.tar.gz", @@ -418,7 +424,7 @@ def upload(): @cli.command() -def announce(): +def announce() -> None: """Generate markdown to announce the release.""" current_version = get_package_version() @@ -459,20 +465,36 @@ def get_package_version() -> version.Version: return version.Version(version_string) +def get_release_branch_name(version_number: version.Version) -> str: + return f"release-v{version_number.major}.{version_number.minor}" + + +def get_repo_and_check_clean_checkout() -> git.Repo: + """Get the project repo and check it's not got any uncommitted changes.""" + try: + repo = git.Repo() + except git.InvalidGitRepositoryError: + raise click.ClickException("Not in Synapse repo.") + if repo.is_dirty(): + raise click.ClickException("Uncommitted changes exist.") + return repo + + def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]: """Find the branch/ref, looking first locally then in the remote.""" - if ref_name in repo.refs: - return repo.refs[ref_name] + if ref_name in repo.references: + return repo.references[ref_name] elif ref_name in repo.remote().refs: return repo.remote().refs[ref_name] else: return None -def update_branch(repo: git.Repo): +def update_branch(repo: git.Repo) -> None: """Ensure branch is up to date if it has a remote""" - if repo.active_branch.tracking_branch(): - repo.git.merge(repo.active_branch.tracking_branch().name) + tracking_branch = repo.active_branch.tracking_branch() + if tracking_branch: + repo.git.merge(tracking_branch.name) def get_changes_for_version(wanted_version: version.Version) -> str: @@ -536,7 +558,9 @@ def get_changes_for_version(wanted_version: version.Version) -> str: return "\n".join(version_changelog) -def generate_and_write_changelog(current_version: version.Version, new_version: str): +def generate_and_write_changelog( + current_version: version.Version, new_version: str +) -> None: # We do this by getting a draft so that we can edit it before writing to the # changelog. result = run_until_successful( @@ -558,8 +582,8 @@ def generate_and_write_changelog(current_version: version.Version, new_version: f.write(existing_content) # Remove all the news fragments - for f in glob.iglob("changelog.d/*.*"): - os.remove(f) + for filename in glob.iglob("changelog.d/*.*"): + os.remove(filename) if __name__ == "__main__": diff --git a/scripts-dev/sign_json.py b/scripts-dev/sign_json.py index 9459543106..bb217799fb 100755 --- a/scripts-dev/sign_json.py +++ b/scripts-dev/sign_json.py @@ -27,7 +27,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.util import json_encoder -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="""Adds a signature to a JSON object. diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index e18d617281..3a4f9c3076 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -115,9 +115,7 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]): def __getitem__(self, index: slice) -> List[_KT_co]: ... def __delitem__(self, index: Union[int, slice]) -> None: ... -class SortedItemsView( # type: ignore - ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]] -): +class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]): def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ... @overload def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ... diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 37321f9133..d28b87a3f4 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -48,7 +48,6 @@ from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.python.threadpool import ThreadPool -import synapse from synapse.api.constants import MAX_PDU_SIZE from synapse.app import check_bind_error from synapse.app.phone_stats_home import start_phone_stats_home @@ -60,6 +59,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.logging.context import PreserveLoggingContext +from synapse.logging.opentracing import init_tracer from synapse.metrics import install_gc_manager, register_threadpool from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -431,7 +431,7 @@ async def start(hs: "HomeServer") -> None: refresh_certificate(hs) # Start the tracer - synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa + init_tracer(hs) # noqa # Instantiate the modules so they can register their web resources to the module API # before we start the listeners. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 99db9e1e39..470b8b4492 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -110,13 +110,6 @@ loggers: # information such as access tokens. level: INFO - twisted: - # We send the twisted logging directly to the file handler, - # to work around https://github.com/matrix-org/synapse/issues/3471 - # when using "buffer" logger. Use "console" to log to stderr instead. - handlers: [file] - propagate: false - root: level: INFO diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 39e9acb62a..70eb7e6a97 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -43,6 +43,9 @@ class RegistrationConfig(Config): self.registration_requires_token = config.get( "registration_requires_token", False ) + self.enable_registration_token_3pid_bypasss = config.get( + "enable_registration_token_3pid_bypasss", False + ) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -309,6 +312,12 @@ class RegistrationConfig(Config): # #registration_requires_token: true + # Allow users to submit a token during registration to bypass any required 3pid + # steps configured in `registrations_require_3pid`. + # Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow. + # + #enable_registration_token_3pid_bypass: false + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/synapse/config/server.py b/synapse/config/server.py index d771045b52..b6cd326416 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -186,7 +186,7 @@ KNOWN_RESOURCES = { class HttpResourceConfig: names: List[str] = attr.ib( factory=list, - validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore + validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), ) compress: bool = attr.ib( default=False, @@ -231,9 +231,7 @@ class ManholeConfig: class LimitRemoteRoomsConfig: enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False) complexity: Union[float, int] = attr.ib( - validator=attr.validators.instance_of( - (float, int) # type: ignore[arg-type] # noqa - ), + validator=attr.validators.instance_of((float, int)), # noqa default=1.0, ) complexity_error: str = attr.ib( diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f8d3ba5456..a6c48308b3 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -27,7 +27,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError @@ -204,7 +203,9 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None: key_to_move = field.pop(-1) sub_dict = src for sub_field in field: # e.g. sub_field => "content" - if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]: + if sub_field in sub_dict and isinstance( + sub_dict[sub_field], collections.abc.Mapping + ): sub_dict = sub_dict[sub_field] else: return @@ -622,7 +623,7 @@ def validate_canonicaljson(value: Any) -> None: # Note that Infinity, -Infinity, and NaN are also considered floats. raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON) - elif isinstance(value, (dict, frozendict)): + elif isinstance(value, collections.abc.Mapping): for v in value.values(): validate_canonicaljson(v) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index beab1227b8..884b5d60b4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -268,8 +268,8 @@ class FederationServer(FederationBase): transaction_id=transaction_id, destination=destination, origin=origin, - origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore - pdus=transaction_data.get("pdus"), # type: ignore + origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore[arg-type] + pdus=transaction_data.get("pdus"), edus=transaction_data.get("edus"), ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 1421050b9a..9ce06dfa28 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -229,21 +229,21 @@ class TransportLayerClient: """ logger.debug( "send_data dest=%s, txid=%s", - transaction.destination, # type: ignore - transaction.transaction_id, # type: ignore + transaction.destination, + transaction.transaction_id, ) - if transaction.destination == self.server_name: # type: ignore + if transaction.destination == self.server_name: raise RuntimeError("Transport layer cannot send to itself!") # FIXME: This is only used by the tests. The actual json sent is # generated by the json_data_callback. json_data = transaction.get_dict() - path = _create_v1_path("/send/%s", transaction.transaction_id) # type: ignore + path = _create_v1_path("/send/%s", transaction.transaction_id) return await self.client.put_json( - transaction.destination, # type: ignore + transaction.destination, path=path, data=json_data, json_data_callback=json_data_callback, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 1b57840506..b3894666cc 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -416,7 +416,7 @@ class ApplicationServicesHandler: return typing async def _handle_receipts( - self, service: ApplicationService, new_token: Optional[int] + self, service: ApplicationService, new_token: int ) -> List[JsonDict]: """ Return the latest read receipts that the given application service should receive. @@ -447,7 +447,7 @@ class ApplicationServicesHandler: receipts_source = self.event_sources.sources.receipt receipts, _ = await receipts_source.get_new_events_as( - service=service, from_key=from_key + service=service, from_key=from_key, to_key=new_token ) return receipts diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 86991d26ce..22678d486d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -481,7 +481,7 @@ class AuthHandler: sid = authdict["session"] # Convert the URI and method to strings. - uri = request.uri.decode("utf-8") # type: ignore + uri = request.uri.decode("utf-8") method = request.method.decode("utf-8") # If there's no session ID, create a new session. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 724b9cfcb4..f6ffb7d18d 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -966,7 +966,7 @@ class OidcProvider: "Mapping provider does not support de-duplicating Matrix IDs" ) - attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore + attributes = await self._user_mapping_provider.map_user_attributes( userinfo, token ) diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py new file mode 100644 index 0000000000..2599160bcc --- /dev/null +++ b/synapse/handlers/push_rules.py @@ -0,0 +1,138 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, List, Optional, Union + +import attr + +from synapse.api.errors import SynapseError, UnrecognizedRequestError +from synapse.push.baserules import BASE_RULE_IDS +from synapse.storage.push_rule import RuleNotFoundException +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RuleSpec: + scope: str + template: str + rule_id: str + attr: Optional[str] + + +class PushRulesHandler: + """A class to handle changes in push rules for users.""" + + def __init__(self, hs: "HomeServer"): + self._notifier = hs.get_notifier() + self._main_store = hs.get_datastores().main + + async def set_rule_attr( + self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] + ) -> None: + """Set an attribute (enabled or actions) on an existing push rule. + + Notifies listeners (e.g. sync handler) of the change. + + Args: + user_id: the user for which to modify the push rule. + spec: the spec of the push rule to modify. + val: the value to change the attribute to. + + Raises: + RuleNotFoundException if the rule being modified doesn't exist. + SynapseError(400) if the value is malformed. + UnrecognizedRequestError if the attribute to change is unknown. + InvalidRuleException if we're trying to change the actions on a rule but + the provided actions aren't compliant with the spec. + """ + if spec.attr not in ("enabled", "actions"): + # for the sake of potential future expansion, shouldn't report + # 404 in the case of an unknown request so check it corresponds to + # a known attribute first. + raise UnrecognizedRequestError() + + namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}" + rule_id = spec.rule_id + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise RuleNotFoundException("Unknown rule %r" % (namespaced_rule_id,)) + if spec.attr == "enabled": + if isinstance(val, dict) and "enabled" in val: + val = val["enabled"] + if not isinstance(val, bool): + # Legacy fallback + # This should *actually* take a dict, but many clients pass + # bools directly, so let's not break them. + raise SynapseError(400, "Value for 'enabled' must be boolean") + await self._main_store.set_push_rule_enabled( + user_id, namespaced_rule_id, val, is_default_rule + ) + elif spec.attr == "actions": + if not isinstance(val, dict): + raise SynapseError(400, "Value must be a dict") + actions = val.get("actions") + if not isinstance(actions, list): + raise SynapseError(400, "Value for 'actions' must be dict") + check_actions(actions) + rule_id = spec.rule_id + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise RuleNotFoundException( + "Unknown rule %r" % (namespaced_rule_id,) + ) + await self._main_store.set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule + ) + else: + raise UnrecognizedRequestError() + + self.notify_user(user_id) + + def notify_user(self, user_id: str) -> None: + """Notify listeners about a push rule change. + + Args: + user_id: the user ID the change is for. + """ + stream_id = self._main_store.get_max_push_rules_stream_id() + self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) + + +def check_actions(actions: List[Union[str, JsonDict]]) -> None: + """Check if the given actions are spec compliant. + + Args: + actions: the actions to check. + + Raises: + InvalidRuleException if the rules aren't compliant with the spec. + """ + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + + for a in actions: + if a in ["notify", "dont_notify", "coalesce"]: + pass + elif isinstance(a, dict) and "set_tweak" in a: + pass + else: + raise InvalidRuleException("Unrecognised action %s" % a) + + +class InvalidRuleException(Exception): + pass diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6250bb3bdf..cfe860decc 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -239,13 +239,14 @@ class ReceiptEventSource(EventSource[int, JsonDict]): return events, to_key async def get_new_events_as( - self, from_key: int, service: ApplicationService + self, from_key: int, to_key: int, service: ApplicationService ) -> Tuple[List[JsonDict], int]: """Returns a set of new read receipt events that an appservice may be interested in. Args: from_key: the stream position at which events should be fetched from + to_key: the stream position up to which events should be fetched to service: The appservice which may be interested Returns: @@ -255,7 +256,6 @@ class ReceiptEventSource(EventSource[int, JsonDict]): * The current read receipt stream token. """ from_key = int(from_key) - to_key = self.get_current_key() if from_key == to_key: return [], to_key diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 5efb561273..b5dc9f74b3 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections.abc import logging from typing import ( TYPE_CHECKING, @@ -24,7 +25,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError @@ -380,7 +380,7 @@ class RelationsHandler: # Do not bundle aggregations for an event which represents an edit or an # annotation. It does not make sense for them to have related events. relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): + if isinstance(relates_to, collections.abc.Mapping): relation_type = relates_to.get("rel_type") if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): continue diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 102dd4b57d..5619f8f50e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -357,7 +357,7 @@ class SearchHandler: itertools.chain( # The events_before and events_after for each context. itertools.chain.from_iterable( - itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type] + itertools.chain(context["events_before"], context["events_after"]) for context in contexts.values() ), # The returned events. @@ -373,10 +373,10 @@ class SearchHandler: for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + context["events_before"], time_now, bundle_aggregations=aggregations ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + context["events_after"], time_now, bundle_aggregations=aggregations ) results = [ diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 472b029af3..e2a441066d 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -256,7 +256,9 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self._enabled = bool(hs.config.registration.registration_requires_token) + self._enabled = bool( + hs.config.registration.registration_requires_token + ) or bool(hs.config.registration.enable_registration_token_3pid_bypasss) self.store = hs.get_datastores().main def is_enabled(self) -> bool: diff --git a/synapse/http/server.py b/synapse/http/server.py index 31ca841889..1cf49830e8 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -295,7 +295,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): if isawaitable(raw_callback_return): callback_return = await raw_callback_return else: - callback_return = raw_callback_return # type: ignore + callback_return = raw_callback_return return callback_return @@ -469,7 +469,7 @@ class JsonResource(DirectServeJsonResource): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return # type: ignore + callback_return = raw_callback_return return callback_return @@ -683,6 +683,9 @@ def respond_with_json( Returns: twisted.web.server.NOT_DONE_YET if the request is still active. """ + # The response code must always be set, for logging purposes. + request.setResponseCode(code) + # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. @@ -697,7 +700,6 @@ def respond_with_json( else: encoder = _encode_json_bytes - request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") @@ -728,13 +730,15 @@ def respond_with_json_bytes( Returns: twisted.web.server.NOT_DONE_YET if the request is still active. """ + # The response code must always be set, for logging purposes. + request.setResponseCode(code) + if request._disconnected: logger.warning( "Not sending response to request %s, already disconnected.", request ) return None - request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") @@ -840,6 +844,9 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> N code: The HTTP response code. html_bytes: The HTML bytes to use as the response body. """ + # The response code must always be set, for logging purposes. + request.setResponseCode(code) + # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. @@ -849,7 +856,6 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> N ) return None - request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 88cd8a9e1c..fd9cb97920 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -722,6 +722,11 @@ P = ParamSpec("P") R = TypeVar("R") +async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R: + """Unwraps an arbitrary awaitable by awaiting it.""" + return await awaitable + + @overload def preserve_fn( # type: ignore[misc] f: Callable[P, Awaitable[R]], @@ -802,17 +807,20 @@ def run_in_background( # type: ignore[misc] # by synchronous exceptions, so let's turn them into Failures. return defer.fail() + # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain + # value. Convert it to a `Deferred`. if isinstance(res, typing.Coroutine): + # Wrap the coroutine in a `Deferred`. res = defer.ensureDeferred(res) - - # At this point we should have a Deferred, if not then f was a synchronous - # function, wrap it in a Deferred for consistency. - if not isinstance(res, defer.Deferred): - # `res` is not a `Deferred` and not a `Coroutine`. - # There are no other types of `Awaitable`s we expect to encounter in Synapse. - assert not isinstance(res, Awaitable) - - return defer.succeed(res) + elif isinstance(res, defer.Deferred): + pass + elif isinstance(res, Awaitable): + # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` + # or `Future` from `make_awaitable`. + res = defer.ensureDeferred(_unwrap_awaitable(res)) + else: + # `res` is a plain value. Wrap it in a `Deferred`. + res = defer.succeed(res) if res.called and not res.paused: # The function should have maintained the logcontext, so we can diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8f9e629274..834fe1b62c 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -82,6 +82,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( DirectServeHtmlResource, @@ -109,6 +110,7 @@ from synapse.storage.state import StateFilter from synapse.types import ( DomainSpecificString, JsonDict, + JsonMapping, Requester, StateMap, UserID, @@ -151,6 +153,7 @@ __all__ = [ "PRESENCE_ALL_USERS", "LoginResponse", "JsonDict", + "JsonMapping", "EventBase", "StateMap", "ProfileInfo", @@ -193,6 +196,7 @@ class ModuleApi: self._clock: Clock = hs.get_clock() self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() + self._push_rules_handler = hs.get_push_rules_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -1350,6 +1354,68 @@ class ModuleApi: """ await self._store.add_user_bound_threepid(user_id, medium, address, id_server) + def check_push_rule_actions( + self, actions: List[Union[str, Dict[str, str]]] + ) -> None: + """Checks if the given push rule actions are valid according to the Matrix + specification. + + See https://spec.matrix.org/v1.2/client-server-api/#actions for the list of valid + actions. + + Added in Synapse v1.58.0. + + Args: + actions: the actions to check. + + Raises: + synapse.module_api.errors.InvalidRuleException if the actions are invalid. + """ + check_actions(actions) + + async def set_push_rule_action( + self, + user_id: str, + scope: str, + kind: str, + rule_id: str, + actions: List[Union[str, Dict[str, str]]], + ) -> None: + """Changes the actions of an existing push rule for the given user. + + See https://spec.matrix.org/v1.2/client-server-api/#push-rules for more + information about push rules and their syntax. + + Can only be called on the main process. + + Added in Synapse v1.58.0. + + Args: + user_id: the user for which to change the push rule's actions. + scope: the push rule's scope, currently only "global" is allowed. + kind: the push rule's kind. + rule_id: the push rule's identifier. + actions: the actions to run when the rule's conditions match. + + Raises: + RuntimeError if this method is called on a worker or `scope` is invalid. + synapse.module_api.errors.RuleNotFoundException if the rule being modified + can't be found. + synapse.module_api.errors.InvalidRuleException if the actions are invalid. + """ + if self.worker_app is not None: + raise RuntimeError("module tried to change push rule actions on a worker") + + if scope != "global": + raise RuntimeError( + "invalid scope %s, only 'global' is currently allowed" % scope + ) + + spec = RuleSpec(scope, kind, rule_id, "actions") + await self._push_rules_handler.set_rule_attr( + user_id, spec, {"actions": actions} + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room @@ -1419,7 +1485,7 @@ class AccountDataManager: f"{user_id} is not local to this homeserver; can't access account data for remote users." ) - async def get_global(self, user_id: str, data_type: str) -> Optional[JsonDict]: + async def get_global(self, user_id: str, data_type: str) -> Optional[JsonMapping]: """ Gets some global account data, of a specified type, for the specified user. diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py index 1db900e41f..e58e0e60fe 100644 --- a/synapse/module_api/errors.py +++ b/synapse/module_api/errors.py @@ -20,10 +20,14 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config._base import ConfigError +from synapse.handlers.push_rules import InvalidRuleException +from synapse.storage.push_rule import RuleNotFoundException __all__ = [ "InvalidClientCredentialsError", "RedirectException", "SynapseError", "ConfigError", + "InvalidRuleException", + "RuleNotFoundException", ] diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index a93f6fd5e0..b98640b14a 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union - -import attr +from typing import TYPE_CHECKING, List, Sequence, Tuple, Union from synapse.api.errors import ( NotFoundError, @@ -22,6 +20,7 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) +from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -29,7 +28,6 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns @@ -40,14 +38,6 @@ if TYPE_CHECKING: from synapse.server import HomeServer -@attr.s(slots=True, frozen=True, auto_attribs=True) -class RuleSpec: - scope: str - template: str - rule_id: str - attr: Optional[str] - - class PushRuleRestServlet(RestServlet): PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( @@ -60,6 +50,7 @@ class PushRuleRestServlet(RestServlet): self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None + self._push_rules_handler = hs.get_push_rules_handler() async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: @@ -81,8 +72,13 @@ class PushRuleRestServlet(RestServlet): user_id = requester.user.to_string() if spec.attr: - await self.set_rule_attr(user_id, spec, content) - self.notify_user(user_id) + try: + await self._push_rules_handler.set_rule_attr(user_id, spec, content) + except InvalidRuleException as e: + raise SynapseError(400, "Invalid actions: %s" % e) + except RuleNotFoundException: + raise NotFoundError("Unknown rule") + return 200, {} if spec.rule_id.startswith("."): @@ -98,23 +94,23 @@ class PushRuleRestServlet(RestServlet): before = parse_string(request, "before") if before: - before = _namespaced_rule_id(spec, before) + before = f"global/{spec.template}/{before}" after = parse_string(request, "after") if after: - after = _namespaced_rule_id(spec, after) + after = f"global/{spec.template}/{after}" try: await self.store.add_push_rule( user_id=user_id, - rule_id=_namespaced_rule_id_from_spec(spec), + rule_id=f"global/{spec.template}/{spec.rule_id}", priority_class=priority_class, conditions=conditions, actions=actions, before=before, after=after, ) - self.notify_user(user_id) + self._push_rules_handler.notify_user(user_id) except InconsistentRuleException as e: raise SynapseError(400, str(e)) except RuleNotFoundException as e: @@ -133,11 +129,11 @@ class PushRuleRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}" try: await self.store.delete_push_rule(user_id, namespaced_rule_id) - self.notify_user(user_id) + self._push_rules_handler.notify_user(user_id) return 200, {} except StoreError as e: if e.code == 404: @@ -172,55 +168,6 @@ class PushRuleRestServlet(RestServlet): else: raise UnrecognizedRequestError() - def notify_user(self, user_id: str) -> None: - stream_id = self.store.get_max_push_rules_stream_id() - self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - - async def set_rule_attr( - self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] - ) -> None: - if spec.attr not in ("enabled", "actions"): - # for the sake of potential future expansion, shouldn't report - # 404 in the case of an unknown request so check it corresponds to - # a known attribute first. - raise UnrecognizedRequestError() - - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec.rule_id - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: - raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) - if spec.attr == "enabled": - if isinstance(val, dict) and "enabled" in val: - val = val["enabled"] - if not isinstance(val, bool): - # Legacy fallback - # This should *actually* take a dict, but many clients pass - # bools directly, so let's not break them. - raise SynapseError(400, "Value for 'enabled' must be boolean") - await self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val, is_default_rule - ) - elif spec.attr == "actions": - if not isinstance(val, dict): - raise SynapseError(400, "Value must be a dict") - actions = val.get("actions") - if not isinstance(actions, list): - raise SynapseError(400, "Value for 'actions' must be dict") - _check_actions(actions) - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec.rule_id - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: - raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - await self.store.set_push_rule_actions( - user_id, namespaced_rule_id, actions, is_default_rule - ) - else: - raise UnrecognizedRequestError() - def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec @@ -291,24 +238,11 @@ def _rule_tuple_from_request_object( raise InvalidRuleException("No actions found") actions = req_obj["actions"] - _check_actions(actions) + check_actions(actions) return conditions, actions -def _check_actions(actions: List[Union[str, JsonDict]]) -> None: - if not isinstance(actions, list): - raise InvalidRuleException("No actions found") - - for a in actions: - if a in ["notify", "dont_notify", "coalesce"]: - pass - elif isinstance(a, dict) and "set_tweak" in a: - pass - else: - raise InvalidRuleException("Unrecognised action") - - def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict: if path == []: raise UnrecognizedRequestError( @@ -357,17 +291,5 @@ def _priority_class_from_spec(spec: RuleSpec) -> int: return pc -def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str: - return _namespaced_rule_id(spec, spec.rule_id) - - -def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str: - return "global/%s/%s" % (spec.template, rule_id) - - -class InvalidRuleException(Exception): - pass - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 70baf50fa4..13ef6b35a0 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -929,6 +929,10 @@ def _calculate_registration_flows( # always let users provide both MSISDN & email flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) + # Add a flow that doesn't require any 3pids, if the config requests it. + if config.registration.enable_registration_token_3pid_bypasss: + flows.append([LoginType.REGISTRATION_TOKEN]) + # Prepend m.login.terms to all flows if we're requiring consent if config.consent.user_consent_at_registration: for flow in flows: @@ -942,7 +946,8 @@ def _calculate_registration_flows( # Prepend registration token to all flows if we're requiring a token if config.registration.registration_requires_token: for flow in flows: - flow.insert(0, LoginType.REGISTRATION_TOKEN) + if LoginType.REGISTRATION_TOKEN not in flow: + flow.insert(0, LoginType.REGISTRATION_TOKEN) return flows diff --git a/synapse/server.py b/synapse/server.py index 37c72bd83a..d49c76518a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -91,6 +91,7 @@ from synapse.handlers.presence import ( WorkerPresenceHandler, ) from synapse.handlers.profile import ProfileHandler +from synapse.handlers.push_rules import PushRulesHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler @@ -811,6 +812,10 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountHandler(self) @cache_in_self + def get_push_rules_handler(self) -> PushRulesHandler: + return PushRulesHandler(self) + + @cache_in_self def get_outbound_redis_connection(self) -> "ConnectionHandler": """ The Redis connection used for replication. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 951031af50..5895b89202 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,12 +15,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.stats import UserSortOrder -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( IdGenerator, MultiWriterIdGenerator, @@ -266,7 +271,9 @@ class DataStore( A tuple of a list of mappings from user to information and a count of total users. """ - def get_users_paginate_txn(txn): + def get_users_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: filters = [] args = [self.hs.config.server.server_name] @@ -301,7 +308,7 @@ class DataStore( """ sql = "SELECT COUNT(*) as total_users " + sql_base txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, @@ -338,7 +345,9 @@ class DataStore( ) -def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig): +def check_database_before_upgrade( + cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig +) -> None: """Called before upgrading an existing database to check that it is broadly sane compared with the configuration. """ diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index fa732edcca..945707b0ec 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast from synapse.appservice import ( ApplicationService, @@ -83,7 +83,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): txn.execute( "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns" ) - return txn.fetchone()[0] # type: ignore + return cast(Tuple[int], txn.fetchone())[0] self._as_txn_seq_gen = build_sequence_generator( db_conn, diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index b4a1b041b1..599b418383 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + cast, +) from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore): prefilled_cache=device_outbox_prefill, ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[ToDeviceStream.ToDeviceStreamRow], + ) -> None: if stream_name == ToDeviceStream.NAME: # If replication is happening than postgres must be being used. assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator) @@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) return super().process_replication_rows(stream_name, instance_name, token, rows) - def get_to_device_stream_token(self): + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() async def get_messages_for_user_devices( @@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): if not user_ids_to_query: return {}, to_stream_id - def get_device_messages_txn(txn: LoggingTransaction): + def get_device_messages_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: # Build a query to select messages from any of the given devices that # are between the given stream id bounds. @@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": "No changes in cache since last check"}) return 0 - def delete_messages_for_device_txn(txn): + def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: sql = ( "DELETE FROM device_inbox" " WHERE user_id = ? AND device_id = ?" @@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore): @trace async def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit - ) -> Tuple[List[dict], int]: + self, destination: str, last_stream_id: int, current_stream_id: int, limit: int + ) -> Tuple[List[JsonDict], int]: """ Args: - destination(str): The name of the remote server. - last_stream_id(int|long): The last position of the device message stream + destination: The name of the remote server. + last_stream_id: The last position of the device message stream that the server sent up to. - current_stream_id(int|long): The current position of the device - message stream. + current_stream_id: The current position of the device message stream. Returns: A list of messages for the device and where in the stream the messages got to. """ @@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): return [], last_stream_id @trace - def get_new_messages_for_remote_destination_txn(txn): + def get_new_messages_for_remote_destination_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: sql = ( "SELECT stream_id, messages_json FROM device_federation_outbox" " WHERE destination = ?" @@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): up_to_stream_id: Where to delete messages up to. """ - def delete_messages_for_remote_destination_txn(txn): + def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None: sql = ( "DELETE FROM device_federation_outbox" " WHERE destination = ?" @@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_new_device_messages_txn(txn): + def get_all_new_device_messages_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We limit like this as we might have multiple rows per stream_id, and # we want to make sure we always get all entries for any stream_id # we return. @@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): @trace async def add_messages_to_device_inbox( self, - local_messages_by_user_then_device: dict, - remote_messages_by_destination: dict, + local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], + remote_messages_by_destination: Dict[str, JsonDict], ) -> int: """Used to send messages from this server. @@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): assert self._can_write_to_device - def add_messages_txn(txn, now_ms, stream_id): + def add_messages_txn( + txn: LoggingTransaction, now_ms: int, stream_id: int + ) -> None: # Add the local messages directly to the local inbox. self._add_messages_to_local_device_inbox_txn( txn, stream_id, local_messages_by_user_then_device @@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): return self._device_inbox_id_gen.get_current_token() async def add_messages_from_remote_to_device_inbox( - self, origin: str, message_id: str, local_messages_by_user_then_device: dict + self, + origin: str, + message_id: str, + local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], ) -> int: assert self._can_write_to_device - def add_messages_txn(txn, now_ms, stream_id): + def add_messages_txn( + txn: LoggingTransaction, now_ms: int, stream_id: int + ) -> None: # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. @@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): return stream_id def _add_messages_to_local_device_inbox_txn( - self, txn, stream_id, messages_by_user_then_device - ): + self, + txn: LoggingTransaction, + stream_id: int, + messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], + ) -> None: assert self._can_write_to_device local_by_user_then_device = {} @@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self._remove_dead_devices_from_device_inbox, ) - async def _background_drop_index_device_inbox(self, progress, batch_size): - def reindex_txn(conn): + async def _background_drop_index_device_inbox( + self, progress: JsonDict, batch_size: int + ) -> int: + def reindex_txn(conn: LoggingDatabaseConnection) -> None: txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 483dd80406..2df4dd4ed4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -25,6 +25,7 @@ from typing import ( Optional, Set, Tuple, + cast, ) from synapse.api.errors import Codes, StoreError @@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore): Number of devices of this users. """ - def count_devices_by_users_txn(txn, user_ids): + def count_devices_by_users_txn( + txn: LoggingTransaction, user_ids: List[str] + ) -> int: sql = """ SELECT count(*) FROM devices @@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore): ) txn.execute(sql + clause, args) - return txn.fetchone()[0] + return cast(Tuple[int], txn.fetchone())[0] if not user_ids: return 0 @@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) - return list(txn) + return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall()) async def _get_device_update_edus_by_remote( self, @@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore): async def _get_last_device_update_for_remote_user( self, destination: str, user_id: str, from_stream_id: int ) -> int: - def f(txn): + def f(txn: LoggingTransaction) -> int: prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id FROM device_lists_outbound_last_success @@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore): if not user_ids_to_check: return set() - def _get_users_whose_devices_changed_txn(txn): + def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: changes = set() stream_id_where_clause = "stream_id > ?" @@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore): async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user.""" - def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + def _mark_remote_user_device_list_as_unsubscribed_txn( + txn: LoggingTransaction, + ) -> None: self.db_pool.simple_delete_txn( txn, table="device_lists_remote_extremeties", @@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore): ) def _store_dehydrated_device_txn( - self, txn, user_id: str, device_id: str, device_data: str + self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str ) -> Optional[str]: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, @@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore): """ yesterday = self._clock.time_msec() - prune_age - def _prune_txn(txn): + def _prune_txn(txn: LoggingTransaction) -> None: # look for (user, destination) pairs which have an update older than # the cutoff. # @@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): "drop_device_lists_outbound_last_success_non_unique_idx", ) - async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): - def f(conn): + async def _drop_device_list_streams_non_unique_indexes( + self, progress: JsonDict, batch_size: int + ) -> int: + def f(conn: LoggingDatabaseConnection) -> None: txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") @@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) return 1 - async def _remove_duplicate_outbound_pokes(self, progress, batch_size): + async def _remove_duplicate_outbound_pokes( + self, progress: JsonDict, batch_size: int + ) -> int: # for some reason, we have accumulated duplicate entries in # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less # efficient. @@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, ) - def _txn(txn): + def _txn(txn: LoggingTransaction) -> int: clause, args = make_tuple_comparison_clause( [(x, last_row[x]) for x in KEY_COLS] ) @@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - def add_device_changes_txn(txn, stream_ids): + def add_device_changes_txn( + txn: LoggingTransaction, stream_ids: List[int] + ) -> None: self._add_device_change_to_stream_txn( txn, user_id, @@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn: LoggingTransaction, user_id: str, device_ids: Collection[str], - stream_ids: List[str], - ): + stream_ids: List[int], + ) -> None: txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, @@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_ids: Iterable[str], room_ids: Collection[str], - stream_ids: List[str], + stream_ids: List[int], context: Dict[str, str], ) -> None: """Record the user in the room has updated their device.""" @@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): LIMIT ? """ - def get_uncoverted_outbound_room_pokes_txn(txn): + def get_uncoverted_outbound_room_pokes_txn( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: txn.execute(sql, (limit,)) return [ @@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Marks the associated row in `device_lists_changes_in_room` as handled. """ - def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + def add_device_list_outbound_pokes_txn( + txn: LoggingTransaction, stream_ids: List[int] + ) -> None: if hosts: self._add_device_outbound_poke_to_stream_txn( txn, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 2a1e567ce0..9a6c2fd47a 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -47,6 +47,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry +from synapse.storage.engines.postgres import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id @@ -364,6 +365,20 @@ class PersistEventsStore: min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + # We check that the room still exists for events we're trying to + # persist. This is to protect against races with deleting a room. + # + # Annoyingly SQLite doesn't support row level locking. + if isinstance(self.database_engine, PostgresEngine): + for room_id in {e.room_id for e, _ in events_and_contexts}: + txn.execute( + "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE", + (room_id,), + ) + row = txn.fetchone() + if row is None: + raise Exception(f"Room does not exist {room_id}") + # stream orderings should have been assigned by now assert min_stream_order assert max_stream_order diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 0aef121d83..04efad9e9a 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -522,7 +522,9 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_joined_groups", ) - async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: + async def get_all_groups_for_user( + self, user_id: str, now_token: int + ) -> List[JsonDict]: def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, type, membership, u.content diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 6990f3ed1d..0a19f607bd 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -15,11 +15,12 @@ import itertools import logging -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction from synapse.storage.keys import FetchKeyResult from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -35,7 +36,9 @@ class KeyStore(SQLBaseStore): """Persistence for signature verification keys""" @cached() - def _get_server_verify_key(self, server_name_and_key_id): + def _get_server_verify_key( + self, server_name_and_key_id: Tuple[str, str] + ) -> FetchKeyResult: raise NotImplementedError() @cachedList( @@ -179,19 +182,21 @@ class KeyStore(SQLBaseStore): async def get_server_keys_json( self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: """Retrieve the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. The JSON is returned as a byte array so that it can be efficiently used in an HTTP response. Args: - server_keys (list): List of (server_name, key_id, source) triplets. + server_keys: List of (server_name, key_id, source) triplets. Returns: A mapping from (server_name, key_id, source) triplets to a list of dicts """ - def _get_server_keys_json_txn(txn): + def _get_server_keys_json_txn( + txn: LoggingTransaction, + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: results = {} for server_name, key_id, from_server in server_keys: keyvalues = {"server_name": server_name} diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 322ed05390..40ac377ca9 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -388,7 +388,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) async def store_url_cache( - self, url, response_code, etag, expires_ts, og, media_id, download_ts + self, + url: str, + response_code: int, + etag: Optional[str], + expires_ts: int, + og: Optional[str], + media_id: str, + download_ts: int, ) -> None: await self.db_pool.simple_insert( "local_media_repository_url_cache", @@ -441,7 +448,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def get_cached_remote_media( - self, origin, media_id: str + self, origin: str, media_id: str ) -> Optional[Dict[str, Any]]: return await self.db_pool.simple_select_one( "remote_media_cache", @@ -608,7 +615,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def delete_remote_media(self, media_origin: str, media_id: str) -> None: - def delete_remote_media_txn(txn): + def delete_remote_media_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, "remote_media_cache", diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 4f1c22c71b..5beb8f1d4b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -232,10 +232,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): # is racy. # Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self._invalidate_all_cache_and_stream( # type: ignore[attr-defined] + self._invalidate_all_cache_and_stream( txn, self.user_last_seen_monthly_active ) - self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined] + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) reserved_users = await self.get_registered_reserved_users() await self.db_pool.runInteraction( @@ -363,7 +363,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = await self.is_guest(user_id) # type: ignore[attr-defined] + is_guest = await self.is_guest(user_id) if is_guest: return is_trial = await self.is_trial_user(user_id) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index d3c4611686..b47c511450 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -103,7 +103,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): prefilled_cache=presence_cache_prefill, ) - async def update_presence(self, presence_states) -> Tuple[int, int]: + async def update_presence( + self, presence_states: List[UserPresenceState] + ) -> Tuple[int, int]: assert self._can_persist_presence stream_ordering_manager = self._presence_id_gen.get_next_mult( @@ -121,7 +123,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): return stream_orderings[-1], self._presence_id_gen.get_current_token() def _update_presence_txn( - self, txn: LoggingTransaction, stream_orderings, presence_states + self, + txn: LoggingTransaction, + stream_orderings: List[int], + presence_states: List[UserPresenceState], ) -> None: for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( @@ -405,7 +410,13 @@ class PresenceStore(PresenceBackgroundUpdateStore): self._presence_on_startup = [] return active_on_startup - def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 2e3818e432..bfc85b3add 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -324,7 +324,12 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: - # First we fetch all the state groups that should be deleted, before + # We *immediately* delete the room from the rooms table. This ensures + # that we don't race when persisting events (as that transaction checks + # that the room exists). + txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,)) + + # Next, we fetch all the state groups that should be deleted, before # we delete that information. txn.execute( """ @@ -403,7 +408,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "room_stats_state", "room_stats_current", "room_stats_earliest_token", - "rooms", "stream_ordering_to_exterm", "users_in_public_rooms", "users_who_share_private_rooms", diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 92539f5d41..eb85bbd392 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -16,7 +16,7 @@ import abc import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union -from synapse.api.errors import NotFoundError, StoreError +from synapse.api.errors import StoreError from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -618,7 +618,7 @@ class PushRuleStore(PushRulesWorkerStore): are always stored in the database `push_rules` table). Raises: - NotFoundError if the rule does not exist. + RuleNotFoundException if the rule does not exist. """ async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() @@ -668,8 +668,7 @@ class PushRuleStore(PushRulesWorkerStore): ) txn.execute(sql, (user_id, rule_id)) if txn.fetchone() is None: - # needed to set NOT_FOUND code. - raise NotFoundError("Push rule does not exist.") + raise RuleNotFoundException("Push rule does not exist.") self.db_pool.simple_upsert_txn( txn, @@ -698,9 +697,6 @@ class PushRuleStore(PushRulesWorkerStore): """ Sets the `actions` state of a push rule. - Will throw NotFoundError if the rule does not exist; the Code for this - is NOT_FOUND. - Args: user_id: the user ID of the user who wishes to enable/disable the rule e.g. '@tina:example.org' @@ -712,6 +708,9 @@ class PushRuleStore(PushRulesWorkerStore): is_default_rule: True if and only if this is a server-default rule. This skips the check for existence (as only user-created rules are always stored in the database `push_rules` table). + + Raises: + RuleNotFoundException if the rule does not exist. """ actions_json = json_encoder.encode(actions) @@ -744,7 +743,7 @@ class PushRuleStore(PushRulesWorkerStore): except StoreError as serr: if serr.code == 404: # this sets the NOT_FOUND error Code - raise NotFoundError("Push rule does not exist") + raise RuleNotFoundException("Push rule does not exist") else: raise diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index cf64cd63a4..91286c9b65 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -14,11 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + cast, +) from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder @@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore): return self._decode_pushers_rows(ret) async def get_all_pushers(self) -> Iterator[PusherConfig]: - def get_pushers(txn): + def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: txn.execute("SELECT * FROM pushers") rows = self.db_pool.cursor_to_dict(txn) @@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_pushers_rows_txn(txn): + def get_all_updated_pushers_rows_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT id, user_name, app_id, pushkey FROM pushers @@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore): ORDER BY id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (user_name, app_id, pushkey, False)) - for stream_id, user_name, app_id, pushkey in txn - ] + updates = cast( + List[Tuple[int, tuple]], + [ + (stream_id, (user_name, app_id, pushkey, False)) + for stream_id, user_name, app_id, pushkey in txn + ], + ) sql = """ SELECT stream_id, user_id, app_id, pushkey @@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore): ) @cached(num_args=1, max_entries=15000) - async def get_if_user_has_pusher(self, user_id: str): + async def get_if_user_has_pusher(self, user_id: str) -> None: # This only exists for the cachedList decorator raise NotImplementedError() async def update_pusher_last_stream_ordering( - self, app_id, pushkey, user_id, last_stream_ordering + self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int ) -> None: await self.db_pool.simple_update_one( "pushers", @@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore): last_user = progress.get("last_user", "") - def _delete_pushers(txn) -> int: + def _delete_pushers(txn: LoggingTransaction) -> int: sql = """ SELECT name FROM users @@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) - def _delete_pushers(txn) -> int: + def _delete_pushers(txn: LoggingTransaction) -> int: sql = """ SELECT p.id, access_token FROM pushers AS p @@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) - def _delete_pushers(txn) -> int: + def _delete_pushers(txn: LoggingTransaction) -> int: sql = """ SELECT p.id, p.user_name, p.app_id, p.pushkey @@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore): async def delete_pusher_by_app_id_pushkey_user_id( self, app_id: str, pushkey: str, user_id: str ) -> None: - def delete_pusher_txn(txn, stream_id): + def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None: self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) @@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore): # account. pushers = list(await self.get_pushers_by_user_id(user_id)) - def delete_pushers_txn(txn, stream_ids): + def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None: self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 332e901dda..7d96f4feda 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -122,10 +122,21 @@ class ReceiptsWorkerStore(SQLBaseStore): receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} - @cached(num_args=2) + @cached() async def get_receipts_for_room( self, room_id: str, receipt_type: str ) -> List[Dict[str, Any]]: + """ + Fetch the event IDs for the latest receipt for all users in a room with the given receipt type. + + Args: + room_id: The room ID to fetch the receipt for. + receipt_type: The receipt type to fetch. + + Returns: + A list of dictionaries, one for each user ID. Each dictionary + contains a user ID and the event ID of that user's latest receipt. + """ return await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, @@ -133,10 +144,21 @@ class ReceiptsWorkerStore(SQLBaseStore): desc="get_receipts_for_room", ) - @cached(num_args=3) + @cached() async def get_last_receipt_event_id_for_user( self, user_id: str, room_id: str, receipt_type: str ) -> Optional[str]: + """ + Fetch the event ID for the latest receipt in a room with the given receipt type. + + Args: + user_id: The user to fetch receipts for. + room_id: The room ID to fetch the receipt for. + receipt_type: The receipt type to fetch. + + Returns: + The event ID of the latest receipt, if one exists; otherwise `None`. + """ return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ @@ -149,10 +171,23 @@ class ReceiptsWorkerStore(SQLBaseStore): allow_none=True, ) - @cached(num_args=2) + @cached() async def get_receipts_for_user( self, user_id: str, receipt_type: str ) -> Dict[str, str]: + """ + Fetch the event IDs for the latest receipts sent by the given user. + + Args: + user_id: The user to fetch receipts for. + receipt_type: The receipt type to fetch. + + Returns: + A map of room ID to the event ID of the latest receipt for that room. + + If the user has not sent a receipt to a room then it will not appear + in the returned dictionary. + """ rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, @@ -165,6 +200,17 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_receipts_for_user_with_orderings( self, user_id: str, receipt_type: str ) -> JsonDict: + """ + Fetch receipts for all rooms that the given user is joined to. + + Args: + user_id: The user to fetch receipts for. + receipt_type: The receipt type to fetch. + + Returns: + A map of room ID to the latest receipt information. + """ + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," @@ -241,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) - @cached(num_args=3, tree=True) + @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None ) -> List[JsonDict]: @@ -541,7 +587,7 @@ class ReceiptsWorkerStore(SQLBaseStore): data: JsonDict, stream_id: int, ) -> Optional[int]: - """Inserts a read-receipt into the database if it's newer than the current RR + """Inserts a receipt into the database if it's newer than the current one. Returns: None if the RR is older than the current RR diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index e653841fe5..18ae8aee29 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections.abc import logging from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple -from frozendict import frozendict - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -160,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): predecessor = create_event.content.get("predecessor", None) # Ensure the key is a dictionary - if not isinstance(predecessor, (dict, frozendict)): + if not isinstance(predecessor, collections.abc.Mapping): return None # The keys must be strings since the data is JSON. @@ -370,10 +369,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _update_state_for_partial_state_event_txn( self, - txn, + txn: LoggingTransaction, event: EventBase, context: EventContext, - ): + ) -> None: # we shouldn't have any outliers here assert not event.internal_metadata.is_outlier() diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 2d339b6008..f38bedbbcd 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -131,7 +131,7 @@ class UIAuthWorkerStore(SQLBaseStore): session_id: str, stage_type: str, result: Union[str, bool, JsonDict], - ): + ) -> None: """ Mark a session stage as completed. @@ -200,7 +200,9 @@ class UIAuthWorkerStore(SQLBaseStore): desc="set_ui_auth_client_dict", ) - async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): + async def set_ui_auth_session_data( + self, session_id: str, key: str, value: Any + ) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by @@ -223,7 +225,7 @@ class UIAuthWorkerStore(SQLBaseStore): def _set_ui_auth_session_data_txn( self, txn: LoggingTransaction, session_id: str, key: str, value: Any - ): + ) -> None: # Get the current value. result = cast( Dict[str, Any], @@ -275,7 +277,7 @@ class UIAuthWorkerStore(SQLBaseStore): session_id: str, user_agent: str, ip: str, - ): + ) -> None: """Add the given user agent / IP to the tracking table""" await self.db_pool.simple_upsert( table="ui_auth_sessions_ips", @@ -318,7 +320,7 @@ class UIAuthWorkerStore(SQLBaseStore): def _delete_old_ui_auth_sessions_txn( self, txn: LoggingTransaction, expiration_time: int - ): + ) -> None: # Get the expired sessions. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" txn.execute(sql, [expiration_time]) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e3153d1a4a..546d6bae6e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -501,11 +501,11 @@ def _upgrade_existing_database( if hasattr(module, "run_create"): logger.info("Running %s:run_create", relative_path) - module.run_create(cur, database_engine) # type: ignore + module.run_create(cur, database_engine) if not is_empty and hasattr(module, "run_upgrade"): logger.info("Running %s:run_upgrade", relative_path) - module.run_upgrade(cur, database_engine, config=config) # type: ignore + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc" or file_name == "__pycache__": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 0b9ac26b69..f6b3ee31e4 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -107,7 +107,7 @@ class TTLCache(Generic[KT, VT]): self._metrics.inc_hits() return e.value, e.expiry_time, e.ttl - def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore + def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: """Remove a value from the cache If key is in the cache, remove it and return its value, else return default. diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 9c405eb4d7..7223af1a36 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections.abc from typing import Any from frozendict import frozendict @@ -35,7 +36,7 @@ def freeze(o: Any) -> Any: def unfreeze(o: Any) -> Any: - if isinstance(o, (dict, frozendict)): + if isinstance(o, collections.abc.Mapping): return {k: unfreeze(v) for k, v in o.items()} if isinstance(o, (bytes, str)): diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index ec8864dafe..268a48d7ba 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase): ) # mock up the response, and have the agent return it - self._mock_agent.request.return_value = defer.succeed( + self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( _mock_response( { "pdus": [ diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 91f982518e..6b26353d5e 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # Send the server a device list EDU for the other user, this will cause # it to try and resync the device lists. self.hs.get_federation_transport_client().query_user_devices.return_value = ( - defer.succeed( + make_awaitable( { "stream_id": "1", "user_id": "@user2:host2", diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 8c72cf6b30..5b0cd1ab86 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -411,6 +411,88 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) + def test_sending_read_receipt_batches_to_application_services(self): + """Tests that a large batch of read receipts are sent correctly to + interested application services. + """ + # Register an application service that's interested in a certain user + # and room prefix + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + ApplicationService.NS_ROOMS: [ + { + "regex": "!fakeroom_.*", + "exclusive": True, + } + ], + }, + ) + + # "Complete" a transaction. + # All this really does for us is make an entry in the application_services_state + # database table, which tracks the current stream_token per stream ID per AS. + self.get_success( + self.hs.get_datastores().main.complete_appservice_txn( + 0, + interested_appservice, + ) + ) + + # Now, pretend that we receive a large burst of read receipts (300 total) that + # all come in at once. + for i in range(300): + self.get_success( + # Insert a fake read receipt into the database + self.hs.get_datastores().main.insert_receipt( + # We have to use unique room ID + user ID combinations here, as the db query + # is an upsert. + room_id=f"!fakeroom_{i}:test", + receipt_type="m.read", + user_id=self.local_user, + event_ids=[f"$eventid_{i}"], + data={}, + ) + ) + + # Now notify the appservice handler that 300 read receipts have all arrived + # at once. What will it do! + # note: stream tokens start at 2 + for stream_token in range(2, 303): + self.get_success( + self.hs.get_application_service_handler()._notify_interested_services_ephemeral( + services=[interested_appservice], + stream_key="receipt_key", + new_token=stream_token, + users=[self.exclusive_as_user], + ) + ) + + # Using our txn send mock, we can see what the AS received. After iterating over every + # transaction, we'd like to see all 300 read receipts accounted for. + # No more, no less. + all_ephemeral_events = [] + for call in self.send_mock.call_args_list: + ephemeral_events = call[0][2] + all_ephemeral_events += ephemeral_events + + # Ensure that no duplicate events were sent + self.assertEqual(len(all_ephemeral_events), 300) + + # Check that the ephemeral event is a read receipt with the expected structure + latest_read_receipt = all_ephemeral_events[-1] + self.assertEqual(latest_read_receipt["type"], "m.receipt") + + event_id = list(latest_read_receipt["content"].keys())[0] + self.assertEqual( + latest_read_receipt["content"][event_id]["m.read"], {self.local_user: {}} + ) + @unittest.override_config( {"experimental_features": {"msc2409_to_device_messages_enabled": True}} ) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index a54aa29cf1..751025c5da 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -201,4 +201,16 @@ class CasHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) + mock = Mock( + spec=[ + "finish", + "getClientIP", + "getHeader", + "setHeader", + "setResponseCode", + "write", + ] + ) + # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. + mock._disconnected = False + return mock diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8c74ed1fcf..1e6ad4b663 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -19,7 +19,6 @@ from unittest import mock from parameterized import parameterized from signedjson import key as key, sign as sign -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms @@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" self.hs.get_federation_client().query_client_keys = mock.Mock( - return_value=defer.succeed( + return_value=make_awaitable( { "device_keys": {remote_user_id: {}}, "master_keys": { @@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # Pretend we're sharing a room with the user we're querying. If not, # `_query_devices_for_destination` will return early. self.store.get_rooms_for_user = mock.Mock( - return_value=defer.succeed({"some_room_id"}) + return_value=make_awaitable({"some_room_id"}) ) remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" self.hs.get_federation_client().query_user_devices = mock.Mock( - return_value=defer.succeed( + return_value=make_awaitable( { "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index d401fda938..addf14fa2b 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -17,8 +17,6 @@ from typing import Any, Type, Union from unittest.mock import Mock -from twisted.internet import defer - import synapse from synapse.api.constants import LoginType from synapse.api.errors import Codes @@ -190,7 +188,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) @@ -226,13 +224,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.get_success(module_api.register_user("u")) # log in twice, to get two devices - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) tok1 = self.login("u", "p") self.login("u", "p", device_id="dev2") mock_password_provider.reset_mock() # have the auth provider deny the request to start with - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) # make the initial request which returns a 401 session = self._start_delete_device_session(tok1, "dev2") @@ -246,7 +244,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # Finally, check the request goes through when we allow it - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @@ -260,7 +258,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, 403, channel.result) @@ -277,7 +275,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # have the auth provider deny the request - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) # log in twice, to get two devices tok1 = self.login("localuser", "localpass") @@ -320,7 +318,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -342,7 +340,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # allow login via the auth provider - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) # log in twice, to get two devices tok1 = self.login("localuser", "p") @@ -359,7 +357,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_not_called() # now try deleting with the local password - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._authed_delete_device( tok1, "dev2", session, "localuser", "localpass" ) @@ -413,7 +411,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) channel = self._send_login("test.login_type", "u", test_field="y") @@ -427,7 +425,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # try a weird username. Again, it's unclear what we *expect* to happen # in these cases, but at least we can guard against the API changing # unexpectedly - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") @@ -477,7 +475,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) body["auth"]["test_field"] = "foo" @@ -490,7 +488,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # and finally, succeed - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) @@ -508,9 +506,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.custom_auth_provider_callback_test_body() def custom_auth_provider_callback_test_body(self): - callback = Mock(return_value=defer.succeed(None)) + callback = Mock(return_value=make_awaitable(None)) - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") @@ -646,7 +644,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5081b97573..65ab7db0c8 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,7 +15,7 @@ from typing import List -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.types import JsonDict from tests import unittest @@ -35,7 +35,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@rikj:jki.re": { "ts": 1436451550453, "hidden": True, @@ -56,7 +56,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916hfgh4394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@me:server.org": { "ts": 1436451550453, "hidden": True, @@ -72,7 +72,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916hfgh4394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@me:server.org": { "ts": 1436451550453, ReadReceiptEventFields.MSC2285_HIDDEN: True, @@ -92,7 +92,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@rikj:jki.re": { "ts": 1436451550453, "hidden": True, @@ -111,7 +111,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -130,7 +130,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@rikj:jki.re": { "ts": 1436451550453, "hidden": True, @@ -138,7 +138,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -153,7 +153,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -171,9 +171,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [ { "content": { - "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}}, + "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -187,9 +187,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [ { "content": { - "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}}, + "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -209,7 +209,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): "content": { "$143564gdfg6114394fHBLK:matrix.org": {}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -225,7 +225,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): "content": { "$143564gdfg6114394fHBLK:matrix.org": {}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -244,7 +244,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@rikj:jki.re": { "ts": 1436451550453, "hidden": True, @@ -258,7 +258,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -273,7 +273,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -297,7 +297,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@rikj:jki.re": "string", } }, diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 45fd30cf43..b6ba19c739 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -193,8 +193,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - # Type ignore: mypy doesn't like us assigning to methods. - self.store.count_monthly_users = Mock( # type: ignore[assignment] + self.store.count_monthly_users = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception @@ -202,8 +201,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - # Type ignore: mypy doesn't like us assigning to methods. - self.store.get_monthly_active_count = Mock( # type: ignore[assignment] + self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) self.get_failure( @@ -211,8 +209,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - # Type ignore: mypy doesn't like us assigning to methods. - self.store.get_monthly_active_count = Mock( # type: ignore[assignment] + self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 8d4404eda1..e2f0f90ef1 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -349,4 +349,16 @@ class SamlHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) + mock = Mock( + spec=[ + "finish", + "getClientIP", + "getHeader", + "setHeader", + "setResponseCode", + "write", + ] + ) + # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. + mock._disconnected = False + return mock diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ffd5c4cb93..5f2e26a5fc 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) - mock_keyring.verify_json_for_server.return_value = defer.succeed(True) + mock_keyring.verify_json_for_server.return_value = make_awaitable(True) # we mock out the federation client too mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) @@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( - return_value=defer.succeed(None) + return_value=make_awaitable(None) ) self.datastore.get_device_updates_by_remote = Mock( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c6e501c7be..96e2e3039b 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -15,7 +15,6 @@ from typing import Tuple from unittest.mock import Mock, patch from urllib.parse import quote -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,6 +29,7 @@ from synapse.util import Clock from tests import unittest from tests.storage.test_user_directory import GetUserDirectoryTables +from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config @@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): @@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store.register_user(user_id=r_user_id, password_hash=None) ) - mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py index bec018d9e7..89009bea8c 100644 --- a/tests/module_api/test_account_data_manager.py +++ b/tests/module_api/test_account_data_manager.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import SynapseError from synapse.rest import admin +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -22,7 +26,9 @@ class ModuleApiTestCase(HomeserverTestCase): admin.register_servlets, ] - def prepare(self, reactor, clock, homeserver) -> None: + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self._store = homeserver.get_datastores().main self._module_api = homeserver.get_module_api() self._account_data_mgr = self._module_api.account_data_manager @@ -91,7 +97,7 @@ class ModuleApiTestCase(HomeserverTestCase): ) with self.assertRaises(TypeError): # This throws an exception because it's a frozen dict. - the_data["wombat"] = False + the_data["wombat"] = False # type: ignore[index] def test_put_global(self) -> None: """ @@ -143,15 +149,14 @@ class ModuleApiTestCase(HomeserverTestCase): with self.assertRaises(TypeError): # The account data type must be a string. self.get_success_or_raise( - self._module_api.account_data_manager.put_global( - self.user_id, 42, {} # type: ignore - ) + self._module_api.account_data_manager.put_global(self.user_id, 42, {}) # type: ignore[arg-type] ) with self.assertRaises(TypeError): # The account data dict must be a dict. + # noinspection PyTypeChecker self.get_success_or_raise( self._module_api.account_data_manager.put_global( - self.user_id, "test.data", 42 # type: ignore + self.user_id, "test.data", 42 # type: ignore[arg-type] ) ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9fd5d59c55..8bc84aaaca 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes from synapse.events import EventBase from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState +from synapse.handlers.push_rules import InvalidRuleException from synapse.rest import admin -from synapse.rest.client import login, presence, profile, room +from synapse.rest.client import login, notifications, presence, profile, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence @@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase): room.register_servlets, presence.register_servlets, profile.register_servlets, + notifications.register_servlets, ] def prepare(self, reactor, clock, homeserver): @@ -553,6 +555,86 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertEqual(state[("org.matrix.test", "")].state_key, "") self.assertEqual(state[("org.matrix.test", "")].content, {}) + def test_set_push_rules_action(self) -> None: + """Test that a module can change the actions of an existing push rule for a user.""" + + # Create a room with 2 users in it. Push rules must not match if the user is the + # event's sender, so we need one user to send messages and one user to receive + # notifications. + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + + room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok) + + user_id2 = self.register_user("user2", "password") + tok2 = self.login("user2", "password") + self.helper.join(room_id, user_id2, tok=tok2) + + # Register a 3rd user and join them to the room, so that we don't accidentally + # trigger 1:1 push rules. + user_id3 = self.register_user("user3", "password") + tok3 = self.login("user3", "password") + self.helper.join(room_id, user_id3, tok=tok3) + + # Send a message as the second user and check that it notifies. + res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2) + event_id = res["event_id"] + + channel = self.make_request( + "GET", + "/notifications", + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + self.assertEqual( + channel.json_body["notifications"][0]["event"]["event_id"], + event_id, + channel.json_body, + ) + + # Change the .m.rule.message actions to not notify on new messages. + self.get_success( + defer.ensureDeferred( + self.module_api.set_push_rule_action( + user_id=user_id, + scope="global", + kind="underride", + rule_id=".m.rule.message", + actions=["dont_notify"], + ) + ) + ) + + # Send another message as the second user and check that the number of + # notifications didn't change. + self.helper.send(room_id=room_id, body="here's another message", tok=tok2) + + channel = self.make_request( + "GET", + "/notifications?from=", + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + + def test_check_push_rules_actions(self) -> None: + """Test that modules can check whether a list of push rules actions are spec + compliant. + """ + with self.assertRaises(InvalidRuleException): + self.module_api.check_push_rule_actions(["foo"]) + + with self.assertRaises(InvalidRuleException): + self.module_api.check_push_rule_actions({"foo": "bar"}) + + self.module_api.check_push_rule_actions(["notify"]) + + self.module_api.check_push_rule_actions( + [{"set_tweak": "sound", "value": "default"}] + ) + class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): """For testing ModuleApi functionality in a multi-worker setup""" diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index f47d94f690..de19e75b9d 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from ._base import BaseSlavedStoreTestCase @@ -26,9 +27,13 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedReceiptsStore def test_receipt(self): - self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) + self.check("get_receipts_for_user", [USER_ID, ReceiptTypes.READ], {}) self.get_success( - self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}) + self.master_store.insert_receipt( + ROOM_ID, ReceiptTypes.READ, USER_ID, [EVENT_ID], {} + ) ) self.replicate() - self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}) + self.check( + "get_receipts_for_user", [USER_ID, ReceiptTypes.READ], {ROOM_ID: EVENT_ID} + ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index ba1a63c0d6..6104a55aa1 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -102,8 +102,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() # type: ignore[attr-defined] - mock_client2.reset_mock() # type: ignore[attr-defined] + mock_client1.reset_mock() + mock_client2.reset_mock() self.create_and_send_event(room, UserID.from_string(user)) self.replicate() @@ -167,8 +167,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() # type: ignore[attr-defined] - mock_client2.reset_mock() # type: ignore[attr-defined] + mock_client1.reset_mock() + mock_client2.reset_mock() self.get_success( typing_handler.started_typing( diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 0abe378fe4..b3738a0304 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -14,7 +14,6 @@ from http import HTTPStatus from unittest.mock import Mock -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler @@ -24,6 +23,7 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class PresenceTestCase(unittest.HomeserverTestCase): @@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = defer.succeed(None) + presence_handler.set_state.return_value = make_awaitable(None) hs = self.setup_test_homeserver( "red", diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 6ff79b9e2e..9443daa056 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined] - {} - ) + self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] search_filter = {"generic_search_term": "foobar"} @@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(404, "Not Found", b""), - defer.succeed({}), + make_awaitable({}), ) search_filter = {"generic_search_term": "foobar"} diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 773c16a54c..cb765455c1 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -24,6 +24,7 @@ from synapse.api.constants import ( EventContentFields, EventTypes, ReadReceiptEventFields, + ReceiptTypes, RelationTypes, ) from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync @@ -560,7 +561,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8") channel = self.make_request( "POST", "/rooms/%s/read_markers" % self.room_id, diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 8d8251b2ac..21a1ca2a68 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionC from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import MockClock @@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_executes_given_function(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" ) @@ -47,7 +48,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_deduplicates_based_on_key(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg", changing_args=i @@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cleans_up(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 02b96c9e6e..9ee9509d3a 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -14,8 +14,6 @@ from unittest.mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin @@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( - return_value=defer.succeed(Mock()) + return_value=make_awaitable(Mock()) ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.user_id = "@user_id:test" self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - return_value=defer.succeed("!something:localhost") + return_value=make_awaitable("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) + self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @override_config({"hs_disabled": True}) @@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) @@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") + return_value=make_awaitable(None), + side_effect=ResourceLimitError(403, "foo"), ) mock_event = Mock( @@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user does not have blocked notice, but should have one """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") + return_value=make_awaitable(None), + side_effect=ResourceLimitError(403, "foo"), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=make_awaitable(None) ) @@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): an alert message is not sent into the room """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), @@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test that when a server is disabled, that MAU limit alerting is ignored. """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED ), @@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): is suppressed that the room is returned to an unblocked state. """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( - return_value=defer.succeed((True, [])) + return_value=make_awaitable((True, [])) ) mock_event = Mock( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 60c8d37594..0fbf465670 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -14,7 +14,6 @@ from typing import Any, Dict, List from unittest.mock import Mock -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes @@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(None) + return_value=make_awaitable(None) ) d = self.store.populate_monthly_active_users("user_id") self.get_success(d) @@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_not_update(self): self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) d = self.store.populate_monthly_active_users("user_id") diff --git a/tests/test_federation.py b/tests/test_federation.py index c39816de85..0cbef70bfa 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. federation_client = self.homeserver.get_federation_client() federation_client.query_user_devices = Mock( - return_value=succeed( + return_value=make_awaitable( { "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index f05a373aa0..0d0d6faf0d 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -52,7 +52,7 @@ def make_awaitable(result: TV) -> Awaitable[TV]: This uses Futures as they can be awaited multiple times so can be returned to multiple callers. """ - future = Future() # type: ignore + future: Future[TV] = Future() future.set_result(result) return future @@ -69,7 +69,7 @@ def setup_awaitable_errors() -> Callable[[], None]: # State shared between unraisablehook and check_for_unraisable_exceptions. unraisable_exceptions = [] - orig_unraisablehook = sys.unraisablehook # type: ignore + orig_unraisablehook = sys.unraisablehook def unraisablehook(unraisable): unraisable_exceptions.append(unraisable.exc_value) @@ -78,11 +78,11 @@ def setup_awaitable_errors() -> Callable[[], None]: """ A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. """ - sys.unraisablehook = orig_unraisablehook # type: ignore + sys.unraisablehook = orig_unraisablehook if unraisable_exceptions: raise unraisable_exceptions.pop() - sys.unraisablehook = unraisablehook # type: ignore + sys.unraisablehook = unraisablehook return cleanup diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 51a197a8c6..9228454c9e 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") - self.tx_log.emit( # type: ignore + self.tx_log.emit( twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry ) |