diff options
184 files changed, 2495 insertions, 1004 deletions
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 124b17458f..d20d30c035 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -34,32 +34,24 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - # TODO: consider using https://github.com/docker/metadata-action instead of this - # custom magic - name: Calculate docker image tag id: set-tag - run: | - case "${GITHUB_REF}" in - refs/heads/develop) - tag=develop - ;; - refs/heads/master|refs/heads/main) - tag=latest - ;; - refs/tags/*) - tag=${GITHUB_REF#refs/tags/} - ;; - *) - tag=${GITHUB_SHA} - ;; - esac - echo "::set-output name=tag::$tag" + uses: docker/metadata-action@master + with: + images: matrixdotorg/synapse + flavor: | + latest=false + tags: | + type=raw,value=develop,enable=${{ github.ref == 'refs/heads/develop' }} + type=raw,value=latest,enable=${{ github.ref == 'refs/heads/master' }} + type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }} + type=pep440,pattern={{raw}} - name: Build and push all platforms uses: docker/build-push-action@v2 with: push: true labels: "gitsha1=${{ github.sha }}" - tags: "matrixdotorg/synapse:${{ steps.set-tag.outputs.tag }}" + tags: "${{ steps.set-tag.outputs.tags }}" file: "docker/Dockerfile" platforms: linux/amd64,linux/arm64 diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml index 1a61d179d9..c537a5a60f 100644 --- a/.github/workflows/latest_deps.yml +++ b/.github/workflows/latest_deps.yml @@ -32,12 +32,15 @@ jobs: with: python-version: "3.x" poetry-version: "1.2.0b1" + extras: "all" # Dump installed versions for debugging. - run: poetry run pip list > before.txt # Upgrade all runtime dependencies only. This is intended to mimic a fresh # `pip install matrix-synapse[all]` as closely as possible. - run: poetry update --no-dev - run: poetry run pip list > after.txt && (diff -u before.txt after.txt || true) + - name: Remove warn_unused_ignores from mypy config + run: sed '/warn_unused_ignores = True/d' -i mypy.ini - run: poetry run mypy trial: runs-on: ubuntu-latest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cad4cb6d77..efa35b71df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,13 +20,9 @@ jobs: - run: scripts-dev/config-lint.sh lint: - # This does a vanilla `poetry install` - no extras. I'm slightly anxious - # that we might skip some typechecks on code that uses extras. However, - # I think the right way to fix this is to mark any extras needed for - # typechecking as development dependencies. To detect this, we ought to - # turn up mypy's strictness: disallow unknown imports and be accept fewer - # uses of `Any`. uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v1" + with: + typechecking-extras: "all" lint-crlf: runs-on: ubuntu-latest diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 8fc1affb77..5f0671f350 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -24,6 +24,8 @@ jobs: poetry remove twisted poetry add --extras tls git+https://github.com/twisted/twisted.git#trunk poetry install --no-interaction --extras "all test" + - name: Remove warn_unused_ignores from mypy config + run: sed '/warn_unused_ignores = True/d' -i mypy.ini - run: poetry run mypy trial: diff --git a/CHANGES.md b/CHANGES.md index 982f09b1db..88b053897e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +Synapse 1.59.0 +============== + +The non-standard `m.login.jwt` login type has been removed from Synapse. It can be replaced with `org.matrix.login.jwt` for identical behaviour. This is only used if `jwt_config.enabled` is set to `true` in the configuration. + + Synapse 1.58.1 (2022-05-05) =========================== diff --git a/README.rst b/README.rst index d71d733679..219e32de8e 100644 --- a/README.rst +++ b/README.rst @@ -55,7 +55,7 @@ solutions. The hope is for Matrix to act as the building blocks for a new generation of fully open and interoperable messaging and VoIP apps for the internet. -Synapse is a Matrix "homeserver" implementation developed by the matrix.org core +Synapse is a Matrix "homeserver" implementation developed by the matrix.org core team, written in Python 3/Twisted. In Matrix, every user runs one or more Matrix clients, which connect through to @@ -294,13 +294,13 @@ directory of your choice:: cd synapse Synapse has a number of external dependencies. We maintain a fixed development -environment using [poetry](https://python-poetry.org/). First, install poetry. We recommend +environment using `Poetry <https://python-poetry.org/>`_. First, install poetry. We recommend:: pip install --user pipx pipx install poetry as described `here <https://python-poetry.org/docs/#installing-with-pipx>`_. -(See `poetry's installation docs <https://python-poetry.org/docs/#installation>` +(See `poetry's installation docs <https://python-poetry.org/docs/#installation>`_ for other installation methods.) Then ask poetry to create a virtual environment from the project and install Synapse's dependencies:: @@ -309,11 +309,11 @@ from the project and install Synapse's dependencies:: This will run a process of downloading and installing all the needed dependencies into a virtual env. -We recommend using the demo which starts 3 federated instances running on ports `8080` - `8082` +We recommend using the demo which starts 3 federated instances running on ports `8080` - `8082`:: poetry run ./demo/start.sh -(to stop, you can use `poetry run ./demo/stop.sh`) +(to stop, you can use ``poetry run ./demo/stop.sh``) See the `demo documentation <https://matrix-org.github.io/synapse/develop/development/demo.html>`_ for more information. diff --git a/changelog.d/11507.feature b/changelog.d/11507.feature new file mode 100644 index 0000000000..72c5690cca --- /dev/null +++ b/changelog.d/11507.feature @@ -0,0 +1 @@ +Support [MSC3266](https://github.com/matrix-org/matrix-doc/pull/3266) room summaries over federation. diff --git a/changelog.d/12168.feature b/changelog.d/12168.feature new file mode 100644 index 0000000000..cd5c45029e --- /dev/null +++ b/changelog.d/12168.feature @@ -0,0 +1 @@ +Implement [changes](https://github.com/matrix-org/matrix-spec-proposals/pull/2285/commits/4a77139249c2e830aec3c7d6bd5501a514d1cc27) to [MSC2285 (hidden read receipts)](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). Contributed by @SimonBrandner. diff --git a/changelog.d/12273.bugfix b/changelog.d/12273.bugfix new file mode 100644 index 0000000000..f8d7b6c889 --- /dev/null +++ b/changelog.d/12273.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.48.0 where latest thread reply provided failed to include the proper bundled aggregations. diff --git a/changelog.d/12356.misc b/changelog.d/12356.misc new file mode 100644 index 0000000000..43e1929106 --- /dev/null +++ b/changelog.d/12356.misc @@ -0,0 +1 @@ +Fix scripts-dev to pass typechecking. \ No newline at end of file diff --git a/changelog.d/12406.feature b/changelog.d/12406.feature new file mode 100644 index 0000000000..e345afdee7 --- /dev/null +++ b/changelog.d/12406.feature @@ -0,0 +1 @@ +Add a module API to allow modules to change actions for existing push rules of local users. diff --git a/changelog.d/12480.misc b/changelog.d/12480.misc new file mode 100644 index 0000000000..18a85e7b15 --- /dev/null +++ b/changelog.d/12480.misc @@ -0,0 +1 @@ +Use supervisord to supervise Postgres and Caddy in the Complement image to reduce restart time. \ No newline at end of file diff --git a/changelog.d/12485.misc b/changelog.d/12485.misc new file mode 100644 index 0000000000..e793d08e5e --- /dev/null +++ b/changelog.d/12485.misc @@ -0,0 +1 @@ +Add some type hints to datastore. \ No newline at end of file diff --git a/changelog.d/12505.misc b/changelog.d/12505.misc new file mode 100644 index 0000000000..a691d7962f --- /dev/null +++ b/changelog.d/12505.misc @@ -0,0 +1 @@ +Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests. diff --git a/changelog.d/12526.feature b/changelog.d/12526.feature new file mode 100644 index 0000000000..c01596282c --- /dev/null +++ b/changelog.d/12526.feature @@ -0,0 +1 @@ +Add new `enable_registration_token_3pid_bypass` configuration option to allow registrations via token as an alternative to verifying a 3pid. \ No newline at end of file diff --git a/changelog.d/12531.misc b/changelog.d/12531.misc new file mode 100644 index 0000000000..412fc9b6dc --- /dev/null +++ b/changelog.d/12531.misc @@ -0,0 +1 @@ +Remove unused `# type: ignore`s. diff --git a/changelog.d/12541.docker b/changelog.d/12541.docker new file mode 100644 index 0000000000..c3b9c31657 --- /dev/null +++ b/changelog.d/12541.docker @@ -0,0 +1 @@ +Explicitly opt-in to using [BuildKit-specific features](https://github.com/moby/buildkit/blob/master/frontend/dockerfile/docs/syntax.md) in the Dockerfile. This fixes issues with building images in some GitLab CI environments. diff --git a/changelog.d/12544.bugfix b/changelog.d/12544.bugfix new file mode 100644 index 0000000000..b5169cd831 --- /dev/null +++ b/changelog.d/12544.bugfix @@ -0,0 +1 @@ +Fix a bug where attempting to send a large amount of read receipts to an application service all at once would result in duplicate content and abnormally high memory usage. Contributed by Brad & Nick @ Beeper. diff --git a/changelog.d/12556.misc b/changelog.d/12556.misc new file mode 100644 index 0000000000..dc245397fb --- /dev/null +++ b/changelog.d/12556.misc @@ -0,0 +1 @@ +Release script: confirm the commit to be tagged before tagging. diff --git a/changelog.d/12564.misc b/changelog.d/12564.misc new file mode 100644 index 0000000000..207c322464 --- /dev/null +++ b/changelog.d/12564.misc @@ -0,0 +1 @@ +Consistently check if an object is a `frozendict`. diff --git a/changelog.d/12570.bugfix b/changelog.d/12570.bugfix new file mode 100644 index 0000000000..1038646f35 --- /dev/null +++ b/changelog.d/12570.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.57 which could cause `Failed to calculate hosts in room` errors to be logged for outbound federation. diff --git a/changelog.d/12573.docker b/changelog.d/12573.docker new file mode 100644 index 0000000000..5cc8de50ac --- /dev/null +++ b/changelog.d/12573.docker @@ -0,0 +1 @@ +Update the "Build docker images" GitHub Actions workflow to use `docker/metadata-action` to generate docker image tags, instead of a custom shell script. Contributed by henryclw. \ No newline at end of file diff --git a/changelog.d/12576.misc b/changelog.d/12576.misc new file mode 100644 index 0000000000..71022c8633 --- /dev/null +++ b/changelog.d/12576.misc @@ -0,0 +1 @@ +Allow unused `#type: ignore` comments in bleeding edge CI jobs. diff --git a/changelog.d/12577.misc b/changelog.d/12577.misc new file mode 100644 index 0000000000..8c4c47ad52 --- /dev/null +++ b/changelog.d/12577.misc @@ -0,0 +1 @@ +Improve comments and error messages around access tokens. \ No newline at end of file diff --git a/changelog.d/12579.doc b/changelog.d/12579.doc new file mode 100644 index 0000000000..bcec5fe1af --- /dev/null +++ b/changelog.d/12579.doc @@ -0,0 +1 @@ +Add missing linebreak to pipx install instructions. diff --git a/changelog.d/12580.bugfix b/changelog.d/12580.bugfix new file mode 100644 index 0000000000..bedce405e2 --- /dev/null +++ b/changelog.d/12580.bugfix @@ -0,0 +1 @@ +Fix a long standing bug where status codes would almost always get logged as 200!, irrespective of the actual status code, when clients disconnect before a request has finished processing. diff --git a/changelog.d/12581.misc b/changelog.d/12581.misc new file mode 100644 index 0000000000..38d40b262b --- /dev/null +++ b/changelog.d/12581.misc @@ -0,0 +1 @@ +Improve docstrings for the receipts store. diff --git a/changelog.d/12582.misc b/changelog.d/12582.misc new file mode 100644 index 0000000000..5fa9c9afe8 --- /dev/null +++ b/changelog.d/12582.misc @@ -0,0 +1 @@ +Use constants for read-receipts in tests. diff --git a/changelog.d/12587.misc b/changelog.d/12587.misc new file mode 100644 index 0000000000..d26e332305 --- /dev/null +++ b/changelog.d/12587.misc @@ -0,0 +1 @@ +Add `@cancellable` decorator, for use on endpoint methods that can be cancelled when clients disconnect. diff --git a/changelog.d/12589.misc b/changelog.d/12589.misc new file mode 100644 index 0000000000..d362828d2e --- /dev/null +++ b/changelog.d/12589.misc @@ -0,0 +1 @@ +Remove special-case for `twisted` logger from default log config. diff --git a/changelog.d/12594.bugfix b/changelog.d/12594.bugfix new file mode 100644 index 0000000000..7411d9c079 --- /dev/null +++ b/changelog.d/12594.bugfix @@ -0,0 +1 @@ +Fix race when persisting an event and deleting a room that could lead to outbound federation breaking. diff --git a/changelog.d/12596.removal b/changelog.d/12596.removal new file mode 100644 index 0000000000..14fbfb3954 --- /dev/null +++ b/changelog.d/12596.removal @@ -0,0 +1 @@ +Remove unstable identifiers from [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069). diff --git a/changelog.d/12597.removal b/changelog.d/12597.removal new file mode 100644 index 0000000000..7927f1d68d --- /dev/null +++ b/changelog.d/12597.removal @@ -0,0 +1,2 @@ +Remove the unspecified `m.login.jwt` login type and the unstable `uk.half-shot.msc2778.login.application_service` from +[MSC2778](https://github.com/matrix-org/matrix-doc/pull/2778). diff --git a/changelog.d/12599.misc b/changelog.d/12599.misc new file mode 100644 index 0000000000..d01278bbce --- /dev/null +++ b/changelog.d/12599.misc @@ -0,0 +1 @@ +Use `getClientAddress` instead of the deprecated `getClientIP`. diff --git a/changelog.d/12608.misc b/changelog.d/12608.misc new file mode 100644 index 0000000000..38272118fb --- /dev/null +++ b/changelog.d/12608.misc @@ -0,0 +1 @@ +Remove redundant lines of config from `mypy.ini`. \ No newline at end of file diff --git a/changelog.d/12610.misc b/changelog.d/12610.misc new file mode 100644 index 0000000000..02efe0c72f --- /dev/null +++ b/changelog.d/12610.misc @@ -0,0 +1 @@ +Reduce log spam when running multiple event persisters. diff --git a/changelog.d/12612.bugfix b/changelog.d/12612.bugfix new file mode 100644 index 0000000000..c39e97f0cb --- /dev/null +++ b/changelog.d/12612.bugfix @@ -0,0 +1 @@ +Fix a typo in the announcement text generated by the Synapse release development script. \ No newline at end of file diff --git a/changelog.d/12613.removal b/changelog.d/12613.removal new file mode 100644 index 0000000000..b1a9e207b0 --- /dev/null +++ b/changelog.d/12613.removal @@ -0,0 +1 @@ +Synapse now requires at least Python 3.7.1 (up from 3.7.0), for compatibility with the latest Twisted trunk. diff --git a/changelog.d/12614.misc b/changelog.d/12614.misc new file mode 100644 index 0000000000..79022df127 --- /dev/null +++ b/changelog.d/12614.misc @@ -0,0 +1 @@ +Add extra debug logging to federation sender. diff --git a/changelog.d/12616.misc b/changelog.d/12616.misc new file mode 100644 index 0000000000..d17ce24cdf --- /dev/null +++ b/changelog.d/12616.misc @@ -0,0 +1 @@ +Prevent remote homeservers from requesting local user device names by default. \ No newline at end of file diff --git a/changelog.d/12619.feature b/changelog.d/12619.feature new file mode 100644 index 0000000000..b0fc0f5fed --- /dev/null +++ b/changelog.d/12619.feature @@ -0,0 +1 @@ +Add new `mau_appservice_trial_days` configuration option to specify a different trial period for users registered via an appservice. diff --git a/changelog.d/12620.misc b/changelog.d/12620.misc new file mode 100644 index 0000000000..63f8e540c3 --- /dev/null +++ b/changelog.d/12620.misc @@ -0,0 +1 @@ +Add a consistency check on events which we read from the database. diff --git a/changelog.d/12624.misc b/changelog.d/12624.misc new file mode 100644 index 0000000000..8772d40fa7 --- /dev/null +++ b/changelog.d/12624.misc @@ -0,0 +1 @@ +Remove use of constantly library and switch to enums for EventRedactBehaviour. Contributed by @andrewdoh. diff --git a/changelog.d/12627.doc b/changelog.d/12627.doc new file mode 100644 index 0000000000..3a787dfef2 --- /dev/null +++ b/changelog.d/12627.doc @@ -0,0 +1 @@ +Fixes to the formatting of README.rst. diff --git a/changelog.d/12632.misc b/changelog.d/12632.misc new file mode 100644 index 0000000000..9e4ba79c79 --- /dev/null +++ b/changelog.d/12632.misc @@ -0,0 +1 @@ +Remove unused code related to receipts. diff --git a/changelog.d/12633.bugfix b/changelog.d/12633.bugfix new file mode 100644 index 0000000000..32332acd9a --- /dev/null +++ b/changelog.d/12633.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.53.0 where bundled aggregations for annotations/edits were incorrectly calculated. diff --git a/changelog.d/12635.feature b/changelog.d/12635.feature new file mode 100644 index 0000000000..cd5c45029e --- /dev/null +++ b/changelog.d/12635.feature @@ -0,0 +1 @@ +Implement [changes](https://github.com/matrix-org/matrix-spec-proposals/pull/2285/commits/4a77139249c2e830aec3c7d6bd5501a514d1cc27) to [MSC2285 (hidden read receipts)](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). Contributed by @SimonBrandner. diff --git a/changelog.d/12636.feature b/changelog.d/12636.feature new file mode 100644 index 0000000000..cd5c45029e --- /dev/null +++ b/changelog.d/12636.feature @@ -0,0 +1 @@ +Implement [changes](https://github.com/matrix-org/matrix-spec-proposals/pull/2285/commits/4a77139249c2e830aec3c7d6bd5501a514d1cc27) to [MSC2285 (hidden read receipts)](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). Contributed by @SimonBrandner. diff --git a/changelog.d/12639.bugfix b/changelog.d/12639.bugfix new file mode 100644 index 0000000000..c01596282c --- /dev/null +++ b/changelog.d/12639.bugfix @@ -0,0 +1 @@ +Add new `enable_registration_token_3pid_bypass` configuration option to allow registrations via token as an alternative to verifying a 3pid. \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index 4523c60645..ccc6a9f778 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,3 +1,4 @@ +# syntax=docker/dockerfile:1 # Dockerfile to build the matrixdotorg/synapse docker images. # # Note that it uses features which are only available in BuildKit - see diff --git a/docker/Dockerfile-workers b/docker/Dockerfile-workers index 9ccb2b22a7..24b03585f9 100644 --- a/docker/Dockerfile-workers +++ b/docker/Dockerfile-workers @@ -20,6 +20,9 @@ RUN rm /etc/nginx/sites-enabled/default # Copy Synapse worker, nginx and supervisord configuration template files COPY ./docker/conf-workers/* /conf/ +# Copy a script to prefix log lines with the supervisor program name +COPY ./docker/prefix-log /usr/local/bin/ + # Expose nginx listener port EXPOSE 8080/tcp diff --git a/docker/complement/SynapseWorkers.Dockerfile b/docker/complement/SynapseWorkers.Dockerfile index 65df2d114d..9a4438e730 100644 --- a/docker/complement/SynapseWorkers.Dockerfile +++ b/docker/complement/SynapseWorkers.Dockerfile @@ -34,13 +34,16 @@ WORKDIR /data # Copy the caddy config COPY conf-workers/caddy.complement.json /root/caddy.json +COPY conf-workers/postgres.supervisord.conf /etc/supervisor/conf.d/postgres.conf +COPY conf-workers/caddy.supervisord.conf /etc/supervisor/conf.d/caddy.conf + # Copy the entrypoint COPY conf-workers/start-complement-synapse-workers.sh / # Expose caddy's listener ports EXPOSE 8008 8448 -ENTRYPOINT /start-complement-synapse-workers.sh +ENTRYPOINT ["/start-complement-synapse-workers.sh"] # Update the healthcheck to have a shorter check interval HEALTHCHECK --start-period=5s --interval=1s --timeout=1s \ diff --git a/docker/complement/conf-workers/caddy.supervisord.conf b/docker/complement/conf-workers/caddy.supervisord.conf new file mode 100644 index 0000000000..d9ddb51dac --- /dev/null +++ b/docker/complement/conf-workers/caddy.supervisord.conf @@ -0,0 +1,7 @@ +[program:caddy] +command=/usr/local/bin/prefix-log /root/caddy run --config /root/caddy.json +autorestart=unexpected +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 diff --git a/docker/complement/conf-workers/postgres.supervisord.conf b/docker/complement/conf-workers/postgres.supervisord.conf new file mode 100644 index 0000000000..5608342d1a --- /dev/null +++ b/docker/complement/conf-workers/postgres.supervisord.conf @@ -0,0 +1,16 @@ +[program:postgres] +command=/usr/local/bin/prefix-log /usr/bin/pg_ctlcluster 13 main start --foreground + +# Lower priority number = starts first +priority=1 + +autorestart=unexpected +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 + +# Use 'Fast Shutdown' mode which aborts current transactions and closes connections quickly. +# (Default (TERM) is 'Smart Shutdown' which stops accepting new connections but +# lets existing connections close gracefully.) +stopsignal=INT diff --git a/docker/complement/conf-workers/start-complement-synapse-workers.sh b/docker/complement/conf-workers/start-complement-synapse-workers.sh index 2c1e05bd62..b9a6b55bbe 100755 --- a/docker/complement/conf-workers/start-complement-synapse-workers.sh +++ b/docker/complement/conf-workers/start-complement-synapse-workers.sh @@ -12,12 +12,6 @@ function log { # Replace the server name in the caddy config sed -i "s/{{ server_name }}/${SERVER_NAME}/g" /root/caddy.json -log "starting postgres" -pg_ctlcluster 13 main start - -log "starting caddy" -/root/caddy start --config /root/caddy.json - # Set the server name of the homeserver export SYNAPSE_SERVER_NAME=${SERVER_NAME} diff --git a/docker/conf/log.config b/docker/conf/log.config index 7a216a36a0..dc8c70befd 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -2,11 +2,7 @@ version: 1 formatters: precise: -{% if worker_name %} - format: '%(asctime)s - worker:{{ worker_name }} - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' -{% else %} format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' -{% endif %} handlers: {% if LOG_FILE_PATH %} diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 3bda6c300b..33fc20d218 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -171,7 +171,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { # Templates for sections that may be inserted multiple times in config files SUPERVISORD_PROCESS_CONFIG_BLOCK = """ [program:synapse_{name}] -command=/usr/local/bin/python -m {app} \ +command=/usr/local/bin/prefix-log /usr/local/bin/python -m {app} \ --config-path="{config_path}" \ --config-path=/conf/workers/shared.yaml \ --config-path=/conf/workers/{name}.yaml diff --git a/docker/prefix-log b/docker/prefix-log new file mode 100755 index 0000000000..0e26a4f19d --- /dev/null +++ b/docker/prefix-log @@ -0,0 +1,12 @@ +#!/bin/bash +# +# Prefixes all lines on stdout and stderr with the process name (as determined by +# the SUPERVISOR_PROCESS_NAME env var, which is automatically set by Supervisor). +# +# Usage: +# prefix-log command [args...] +# + +exec 1> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&1) +exec 2> >(awk '{print "'"${SUPERVISOR_PROCESS_NAME}"' | "$0}' >&2) +exec "$@" diff --git a/docs/jwt.md b/docs/jwt.md index 32f58cc0cb..346daf78ad 100644 --- a/docs/jwt.md +++ b/docs/jwt.md @@ -17,9 +17,6 @@ follows: } ``` -Note that the login type of `m.login.jwt` is supported, but is deprecated. This -will be removed in a future version of Synapse. - The `token` field should include the JSON web token with the following claims: * A claim that encodes the local part of the user ID is required. By default, diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b8d8c0dbf0..a803b8261d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -407,6 +407,11 @@ manhole_settings: # sign up in a short space of time never to return after their initial # session. # +# The option `mau_appservice_trial_days` is similar to `mau_trial_days`, but +# applies a different trial number if the user was registered by an appservice. +# A value of 0 means no trial days are applied. Appservices not listed in this +# dictionary use the value of `mau_trial_days` instead. +# # 'mau_limit_alerting' is a means of limiting client side alerting # should the mau limit be reached. This is useful for small instances # where the admin has 5 mau seats (say) for 5 specific people and no @@ -417,6 +422,8 @@ manhole_settings: #max_mau_value: 50 #mau_trial_days: 2 #mau_limit_alerting: false +#mau_appservice_trial_days: +# "appservice-id": 1 # If enabled, the metrics for the number of monthly active users will # be populated, however no one will be limited. If limit_usage_by_mau @@ -709,11 +716,11 @@ retention: # #allow_profile_lookup_over_federation: false -# Uncomment to disable device display name lookup over federation. By default, the -# Federation API allows other homeservers to obtain device display names of any user -# on this homeserver. Defaults to 'true'. +# Uncomment to allow device display name lookup over federation. By default, the +# Federation API prevents other homeservers from obtaining the display names of +# user devices on this homeserver. Defaults to 'false'. # -#allow_device_name_lookup_over_federation: false +#allow_device_name_lookup_over_federation: true ## Caching ## @@ -1323,6 +1330,12 @@ oembed: # #registration_requires_token: true +# Allow users to submit a token during registration to bypass any required 3pid +# steps configured in `registrations_require_3pid`. +# Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow. +# +#enable_registration_token_3pid_bypass: false + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 2485ad25ed..3065a0e2d9 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -62,13 +62,6 @@ loggers: # information such as access tokens. level: INFO - twisted: - # We send the twisted logging directly to the file handler, - # to work around https://github.com/matrix-org/synapse/issues/3471 - # when using "buffer" logger. Use "console" to log to stderr instead. - handlers: [file] - propagate: false - root: level: INFO diff --git a/docs/upgrade.md b/docs/upgrade.md index 3a8aeb0395..b40cac86f0 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -89,6 +89,17 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.59.0 + +## Device name lookup over federation has been disabled by default + +The names of user devices are no longer visible to users on other homeservers by default. +Device IDs are unaffected, as these are necessary to facilitate end-to-end encryption. + +To re-enable this functionality, set the +[`allow_device_name_lookup_over_federation`](https://matrix-org.github.io/synapse/v1.59/usage/configuration/config_documentation.html#federation) +homeserver config option to `true`. + # Upgrading to v1.58.0 ## Groups/communities feature has been disabled by default diff --git a/docs/usage/administration/request_log.md b/docs/usage/administration/request_log.md index 316304c734..adb5f4f5f3 100644 --- a/docs/usage/administration/request_log.md +++ b/docs/usage/administration/request_log.md @@ -28,7 +28,7 @@ See the following for how to decode the dense data available from the default lo | NNNN | Total time waiting for response to DB queries across all parallel DB work from this request | | OOOO | Count of DB transactions performed | | PPPP | Response body size | -| QQQQ | Response status code (prefixed with ! if the socket was closed before the response was generated) | +| QQQQ | Response status code<br/>Suffixed with `!` if the socket was closed before the response was generated.<br/>A `499!` status code indicates that Synapse also cancelled request processing after the socket was closed.<br/> | | RRRR | Request | | SSSS | User-agent | | TTTT | Events fetched from DB to service this request (note that this does not include events fetched from the cache) | diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 968b0fbfaf..21dad0ac41 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -627,6 +627,20 @@ Example configuration: mau_trial_days: 5 ``` --- +Config option: `mau_appservice_trial_days` + +The option `mau_appservice_trial_days` is similar to `mau_trial_days`, but applies a different +trial number if the user was registered by an appservice. A value +of 0 means no trial days are applied. Appservices not listed in this dictionary +use the value of `mau_trial_days` instead. + +Example configuration: +```yaml +mau_appservice_trial_days: + my_appservice_id: 3 + another_appservice_id: 6 +``` +--- Config option: `mau_limit_alerting` The option `mau_limit_alerting` is a means of limiting client-side alerting @@ -1035,13 +1049,13 @@ allow_profile_lookup_over_federation: false --- Config option: `allow_device_name_lookup_over_federation` -Set this option to false to disable device display name lookup over federation. By default, the -Federation API allows other homeservers to obtain device display names of any user +Set this option to true to allow device display name lookup over federation. By default, the +Federation API prevents other homeservers from obtaining the display names of any user devices on this homeserver. Example configuration: ```yaml -allow_device_name_lookup_over_federation: false +allow_device_name_lookup_over_federation: true ``` --- ## Caching ## diff --git a/mypy.ini b/mypy.ini index a663bf6975..ba0de419f5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -7,6 +7,7 @@ show_error_codes = True show_traceback = True mypy_path = stubs warn_unreachable = True +warn_unused_ignores = True local_partial_types = True no_implicit_optional = True @@ -23,10 +24,6 @@ files = # https://docs.python.org/3/library/re.html#re.X exclude = (?x) ^( - |scripts-dev/build_debian_packages.py - |scripts-dev/federation_client.py - |scripts-dev/release.py - |synapse/storage/databases/__init__.py |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py @@ -134,6 +131,11 @@ disallow_untyped_defs = True [mypy-synapse.metrics.*] disallow_untyped_defs = True +[mypy-synapse.metrics._reactor_metrics] +# This module imports select.epoll. That exists on Linux, but doesn't on macOS. +# See https://github.com/matrix-org/synapse/pull/11771. +warn_unused_ignores = False + [mypy-synapse.module_api.*] disallow_untyped_defs = True @@ -239,63 +241,26 @@ disallow_untyped_defs = True [mypy-authlib.*] ignore_missing_imports = True -[mypy-bcrypt] -ignore_missing_imports = True - [mypy-canonicaljson] ignore_missing_imports = True -[mypy-constantly] -ignore_missing_imports = True - -[mypy-daemonize] -ignore_missing_imports = True - -[mypy-h11] -ignore_missing_imports = True - -[mypy-hiredis] -ignore_missing_imports = True - -[mypy-hyperlink] -ignore_missing_imports = True - [mypy-ijson.*] ignore_missing_imports = True -[mypy-importlib_metadata.*] -ignore_missing_imports = True - -[mypy-jaeger_client.*] -ignore_missing_imports = True - -[mypy-josepy.*] -ignore_missing_imports = True - -[mypy-jwt.*] -ignore_missing_imports = True - [mypy-lxml] ignore_missing_imports = True [mypy-msgpack] ignore_missing_imports = True -[mypy-nacl.*] -ignore_missing_imports = True - +# Note: WIP stubs available at +# https://github.com/microsoft/python-type-stubs/tree/64934207f523ad6b611e6cfe039d85d7175d7d0d/netaddr [mypy-netaddr] ignore_missing_imports = True [mypy-parameterized.*] ignore_missing_imports = True -[mypy-phonenumbers.*] -ignore_missing_imports = True - -[mypy-prometheus_client.*] -ignore_missing_imports = True - [mypy-pymacaroons.*] ignore_missing_imports = True @@ -308,23 +273,14 @@ ignore_missing_imports = True [mypy-saml2.*] ignore_missing_imports = True -[mypy-sentry_sdk] -ignore_missing_imports = True - [mypy-service_identity.*] ignore_missing_imports = True -[mypy-signedjson.*] +[mypy-srvlookup.*] ignore_missing_imports = True [mypy-treq.*] ignore_missing_imports = True -[mypy-twisted.*] -ignore_missing_imports = True - -[mypy-zope] -ignore_missing_imports = True - [mypy-incremental.*] ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index 8c7af1fa1e..89e78576b7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -309,14 +309,15 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.14" -description = "Python Git Library" +version = "3.1.27" +description = "GitPython is a python library used to interact with Git repositories" category = "dev" optional = false -python-versions = ">=3.4" +python-versions = ">=3.7" [package.dependencies] gitdb = ">=4.0.1,<5" +typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\""} [[package]] name = "hiredis" @@ -1316,6 +1317,14 @@ optional = false python-versions = "*" [[package]] +name = "types-commonmark" +version = "0.9.2" +description = "Typing stubs for commonmark" +category = "dev" +optional = false +python-versions = "*" + +[[package]] name = "types-cryptography" version = "3.3.15" description = "Typing stubs for cryptography" @@ -1552,8 +1561,8 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" -python-versions = "^3.7" -content-hash = "f482a4f594a165dfe01ce253a22510d5faf38647ab0dcebc35789350cafd9bf0" +python-versions = "^3.7.1" +content-hash = "2bda1a7cfc8cc02832b4a7d16bf7e1615cb05e0639bdb30688aadf692d851942" [metadata.files] attrs = [ @@ -1766,8 +1775,8 @@ gitdb = [ {file = "gitdb-4.0.9.tar.gz", hash = "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa"}, ] gitpython = [ - {file = "GitPython-3.1.14-py3-none-any.whl", hash = "sha256:3283ae2fba31c913d857e12e5ba5f9a7772bbc064ae2bb09efafa71b0dd4939b"}, - {file = "GitPython-3.1.14.tar.gz", hash = "sha256:be27633e7509e58391f10207cd32b2a6cf5b908f92d9cd30da2e514e1137af61"}, + {file = "GitPython-3.1.27-py3-none-any.whl", hash = "sha256:5b68b000463593e05ff2b261acff0ff0972df8ab1b70d3cdbd41b546c8b8fc3d"}, + {file = "GitPython-3.1.27.tar.gz", hash = "sha256:1c885ce809e8ba2d88a29befeb385fcea06338d3640712b59ca623c220bb5704"}, ] hiredis = [ {file = "hiredis-2.0.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b4c8b0bc5841e578d5fb32a16e0c305359b987b850a06964bd5a62739d688048"}, @@ -2588,6 +2597,10 @@ types-bleach = [ {file = "types-bleach-4.1.4.tar.gz", hash = "sha256:2d30c2c4fb6854088ac636471352c9a51bf6c089289800d2a8060820a01cd43a"}, {file = "types_bleach-4.1.4-py3-none-any.whl", hash = "sha256:edffe173ed6d7b6f3543036a96204a9319c3bf6c3645917b14274e43f000cc9b"}, ] +types-commonmark = [ + {file = "types-commonmark-0.9.2.tar.gz", hash = "sha256:b894b67750c52fd5abc9a40a9ceb9da4652a391d75c1b480bba9cef90f19fc86"}, + {file = "types_commonmark-0.9.2-py3-none-any.whl", hash = "sha256:56f20199a1f9a2924443211a0ef97f8b15a8a956a7f4e9186be6950bf38d6d02"}, +] types-cryptography = [ {file = "types-cryptography-3.3.15.tar.gz", hash = "sha256:a7983a75a7b88a18f88832008f0ef140b8d1097888ec1a0824ec8fb7e105273b"}, {file = "types_cryptography-3.3.15-py3-none-any.whl", hash = "sha256:d9b0dd5465d7898d400850e7f35e5518aa93a7e23d3e11757cd81b4777089046"}, diff --git a/pyproject.toml b/pyproject.toml index 4167abb233..446895711b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ synapse_review_recent_signups = "synapse._scripts.review_recent_signups:main" update_synapse_database = "synapse._scripts.update_synapse_database:main" [tool.poetry.dependencies] -python = "^3.7" +python = "^3.7.1" # Mandatory Dependencies # ---------------------- @@ -251,6 +251,7 @@ flake8 = "*" mypy = "==0.931" mypy-zope = "==0.3.5" types-bleach = ">=4.1.0" +types-commonmark = ">=0.9.2" types-jsonschema = ">=3.2.0" types-opentracing = ">=2.4.2" types-Pillow = ">=8.3.4" @@ -270,7 +271,8 @@ idna = ">=2.5" # The following are used by the release script click = "==8.1.0" -GitPython = "==3.1.14" +# GitPython was == 3.1.14; bumped to 3.1.20, the first release with type hints. +GitPython = ">=3.1.20" commonmark = "==0.9.1" pygithub = "==1.55" # The following are executed as commands by the release script. diff --git a/scripts-dev/build_debian_packages.py b/scripts-dev/build_debian_packages.py index e3e6878686..38564893e9 100755 --- a/scripts-dev/build_debian_packages.py +++ b/scripts-dev/build_debian_packages.py @@ -17,7 +17,8 @@ import subprocess import sys import threading from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Sequence +from types import FrameType +from typing import Collection, Optional, Sequence, Set DISTS = ( "debian:buster", # oldstable: EOL 2022-08 @@ -41,15 +42,17 @@ projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) class Builder(object): def __init__( - self, redirect_stdout=False, docker_build_args: Optional[Sequence[str]] = None + self, + redirect_stdout: bool = False, + docker_build_args: Optional[Sequence[str]] = None, ): self.redirect_stdout = redirect_stdout self._docker_build_args = tuple(docker_build_args or ()) - self.active_containers = set() + self.active_containers: Set[str] = set() self._lock = threading.Lock() self._failed = False - def run_build(self, dist, skip_tests=False): + def run_build(self, dist: str, skip_tests: bool = False) -> None: """Build deb for a single distribution""" if self._failed: @@ -63,7 +66,7 @@ class Builder(object): self._failed = True raise - def _inner_build(self, dist, skip_tests=False): + def _inner_build(self, dist: str, skip_tests: bool = False) -> None: tag = dist.split(":", 1)[1] # Make the dir where the debs will live. @@ -138,7 +141,7 @@ class Builder(object): stdout.close() print("Completed build of %s" % (dist,)) - def kill_containers(self): + def kill_containers(self) -> None: with self._lock: active = list(self.active_containers) @@ -156,8 +159,10 @@ class Builder(object): self.active_containers.remove(c) -def run_builds(builder, dists, jobs=1, skip_tests=False): - def sig(signum, _frame): +def run_builds( + builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False +) -> None: + def sig(signum: int, _frame: Optional[FrameType]) -> None: print("Caught SIGINT") builder.kill_containers() diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 079d2f5ed0..763dd02c47 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -38,7 +38,7 @@ import argparse import base64 import json import sys -from typing import Any, Optional +from typing import Any, Dict, Optional, Tuple from urllib import parse as urlparse import requests @@ -47,13 +47,14 @@ import signedjson.types import srvlookup import yaml from requests.adapters import HTTPAdapter +from urllib3 import HTTPConnectionPool # uncomment the following to enable debug logging of http requests # from httplib import HTTPConnection # HTTPConnection.debuglevel = 1 -def encode_base64(input_bytes): +def encode_base64(input_bytes: bytes) -> str: """Encode bytes as a base64 string without any padding.""" input_len = len(input_bytes) @@ -63,7 +64,7 @@ def encode_base64(input_bytes): return output_string -def encode_canonical_json(value): +def encode_canonical_json(value: object) -> bytes: return json.dumps( value, # Encode code-points outside of ASCII as UTF-8 rather than \u escapes @@ -130,7 +131,7 @@ def request( sig, destination, ) - authorization_headers.append(header.encode("ascii")) + authorization_headers.append(header) print("Authorization: %s" % header, file=sys.stderr) dest = "matrix://%s%s" % (destination, path) @@ -139,7 +140,10 @@ def request( s = requests.Session() s.mount("matrix://", MatrixConnectionAdapter()) - headers = {"Host": destination, "Authorization": authorization_headers[0]} + headers: Dict[str, str] = { + "Host": destination, + "Authorization": authorization_headers[0], + } if method == "POST": headers["Content-Type"] = "application/json" @@ -154,7 +158,7 @@ def request( ) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Signs and sends a federation request to a matrix homeserver" ) @@ -212,6 +216,7 @@ def main(): if not args.server_name or not args.signing_key: read_args_from_config(args) + assert isinstance(args.signing_key, str) algorithm, version, key_base64 = args.signing_key.split() key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64) @@ -233,7 +238,7 @@ def main(): print("") -def read_args_from_config(args): +def read_args_from_config(args: argparse.Namespace) -> None: with open(args.config, "r") as fh: config = yaml.safe_load(fh) @@ -250,7 +255,7 @@ def read_args_from_config(args): class MatrixConnectionAdapter(HTTPAdapter): @staticmethod - def lookup(s, skip_well_known=False): + def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]: if s[-1] == "]": # ipv6 literal (with no port) return s, 8448 @@ -276,7 +281,7 @@ class MatrixConnectionAdapter(HTTPAdapter): return s, 8448 @staticmethod - def get_well_known(server_name): + def get_well_known(server_name: str) -> Optional[str]: uri = "https://%s/.well-known/matrix/server" % (server_name,) print("fetching %s" % (uri,), file=sys.stderr) @@ -299,7 +304,9 @@ class MatrixConnectionAdapter(HTTPAdapter): print("Invalid response from %s: %s" % (uri, e), file=sys.stderr) return None - def get_connection(self, url, proxies=None): + def get_connection( + self, url: str, proxies: Optional[Dict[str, str]] = None + ) -> HTTPConnectionPool: parsed = urlparse.urlparse(url) (host, port) = self.lookup(parsed.netloc) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 1217e14874..c775865212 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -16,7 +16,7 @@ can crop up, e.g the cache descriptors. """ -from typing import Callable, Optional +from typing import Callable, Optional, Type from mypy.nodes import ARG_NAMED_OPT from mypy.plugin import MethodSigContext, Plugin @@ -94,7 +94,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: return signature -def plugin(version: str): +def plugin(version: str) -> Type[SynapsePlugin]: # This is the entry point of the plugin, and let's us deal with the fact # that the mypy plugin interface is *not* stable by looking at the version # string. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 9d7c7c445f..0031ba3e4b 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -25,7 +25,7 @@ import sys import urllib.request from os import path from tempfile import TemporaryDirectory -from typing import List, Optional +from typing import Any, List, Optional, cast import attr import click @@ -36,7 +36,9 @@ from github import Github from packaging import version -def run_until_successful(command, *args, **kwargs): +def run_until_successful( + command: str, *args: Any, **kwargs: Any +) -> subprocess.CompletedProcess: while True: completed_process = subprocess.run(command, *args, **kwargs) exit_code = completed_process.returncode @@ -50,7 +52,7 @@ def run_until_successful(command, *args, **kwargs): @click.group() -def cli(): +def cli() -> None: """An interactive script to walk through the parts of creating a release. Requires the dev dependencies be installed, which can be done via: @@ -81,19 +83,13 @@ def cli(): @cli.command() -def prepare(): +def prepare() -> None: """Do the initial stages of creating a release, including creating release branch, updating changelog and pushing to GitHub. """ # Make sure we're in a git repo. - try: - repo = git.Repo() - except git.InvalidGitRepositoryError: - raise click.ClickException("Not in Synapse repo.") - - if repo.is_dirty(): - raise click.ClickException("Uncommitted changes exist.") + repo = get_repo_and_check_clean_checkout() click.secho("Updating git repo...") repo.remote().fetch() @@ -161,22 +157,21 @@ def prepare(): click.get_current_context().abort() # Switch to the release branch. - parsed_new_version: version.Version = version.parse(new_version) + # Cast safety: parse() won't return a version.LegacyVersion from our + # version string format. + parsed_new_version = cast(version.Version, version.parse(new_version)) # We assume for debian changelogs that we only do RCs or full releases. assert not parsed_new_version.is_devrelease assert not parsed_new_version.is_postrelease - release_branch_name = ( - f"release-v{parsed_new_version.major}.{parsed_new_version.minor}" - ) + release_branch_name = get_release_branch_name(parsed_new_version) release_branch = find_ref(repo, release_branch_name) if release_branch: if release_branch.is_remote(): # If the release branch only exists on the remote we check it out # locally. repo.git.checkout(release_branch_name) - release_branch = repo.active_branch else: # If a branch doesn't exist we create one. We ask which one branch it # should be based off, defaulting to sensible values depending on the @@ -198,13 +193,15 @@ def prepare(): click.get_current_context().abort() # Check out the base branch and ensure it's up to date - repo.head.reference = base_branch + repo.head.set_reference(base_branch, "check out the base branch") repo.head.reset(index=True, working_tree=True) if not base_branch.is_remote(): update_branch(repo) # Create the new release branch - release_branch = repo.create_head(release_branch_name, commit=base_branch) + # Type ignore will no longer be needed after GitPython 3.1.28. + # See https://github.com/gitpython-developers/GitPython/pull/1419 + repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type] # Switch to the release branch and ensure it's up to date. repo.git.checkout(release_branch_name) @@ -265,17 +262,11 @@ def prepare(): @cli.command() @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"]) -def tag(gh_token: Optional[str]): +def tag(gh_token: Optional[str]) -> None: """Tags the release and generates a draft GitHub release""" # Make sure we're in a git repo. - try: - repo = git.Repo() - except git.InvalidGitRepositoryError: - raise click.ClickException("Not in Synapse repo.") - - if repo.is_dirty(): - raise click.ClickException("Uncommitted changes exist.") + repo = get_repo_and_check_clean_checkout() click.secho("Updating git repo...") repo.remote().fetch() @@ -288,12 +279,26 @@ def tag(gh_token: Optional[str]): if tag_name in repo.tags: raise click.ClickException(f"Tag {tag_name} already exists!\n") + # Check we're on the right release branch + release_branch = get_release_branch_name(current_version) + if repo.active_branch.name != release_branch: + click.echo( + f"Need to be on the release branch ({release_branch}) before tagging. " + f"Currently on ({repo.active_branch.name})." + ) + click.get_current_context().abort() + # Get the appropriate changelogs and tag. changes = get_changes_for_version(current_version) click.echo_via_pager(changes) if click.confirm("Edit text?", default=False): - changes = click.edit(changes, require_save=False) + edited_changes = click.edit(changes, require_save=False) + # This assert is for mypy's benefit. click's docs are a little unclear, but + # when `require_save=False`, not saving the temp file in the editor returns + # the original string. + assert edited_changes is not None + changes = edited_changes repo.create_tag(tag_name, message=changes, sign=True) @@ -347,22 +352,16 @@ def tag(gh_token: Optional[str]): @cli.command() @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=True) -def publish(gh_token: str): - """Publish release.""" +def publish(gh_token: str) -> None: + """Publish release on GitHub.""" # Make sure we're in a git repo. - try: - repo = git.Repo() - except git.InvalidGitRepositoryError: - raise click.ClickException("Not in Synapse repo.") - - if repo.is_dirty(): - raise click.ClickException("Uncommitted changes exist.") + get_repo_and_check_clean_checkout() current_version = get_package_version() tag_name = f"v{current_version}" - if not click.confirm(f"Publish {tag_name}?", default=True): + if not click.confirm(f"Publish release {tag_name} on GitHub?", default=True): return # Publish the draft release @@ -390,12 +389,19 @@ def publish(gh_token: str): @cli.command() -def upload(): +def upload() -> None: """Upload release to pypi.""" current_version = get_package_version() tag_name = f"v{current_version}" + # Check we have the right tag checked out. + repo = get_repo_and_check_clean_checkout() + tag = repo.tag(f"refs/tags/{tag_name}") + if repo.head.commit != tag.commit: + click.echo("Tag {tag_name} (tag.commit) is not currently checked out!") + click.get_current_context().abort() + pypi_asset_names = [ f"matrix_synapse-{current_version}-py3-none-any.whl", f"matrix-synapse-{current_version}.tar.gz", @@ -418,7 +424,7 @@ def upload(): @cli.command() -def announce(): +def announce() -> None: """Generate markdown to announce the release.""" current_version = get_package_version() @@ -428,7 +434,7 @@ def announce(): f""" Hi everyone. Synapse {current_version} has just been released. -[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) |\ +[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) | \ [docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \ [debs](https://packages.matrix.org/debian/) | \ [pypi](https://pypi.org/project/matrix-synapse/{current_version}/)""" @@ -459,20 +465,36 @@ def get_package_version() -> version.Version: return version.Version(version_string) +def get_release_branch_name(version_number: version.Version) -> str: + return f"release-v{version_number.major}.{version_number.minor}" + + +def get_repo_and_check_clean_checkout() -> git.Repo: + """Get the project repo and check it's not got any uncommitted changes.""" + try: + repo = git.Repo() + except git.InvalidGitRepositoryError: + raise click.ClickException("Not in Synapse repo.") + if repo.is_dirty(): + raise click.ClickException("Uncommitted changes exist.") + return repo + + def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]: """Find the branch/ref, looking first locally then in the remote.""" - if ref_name in repo.refs: - return repo.refs[ref_name] + if ref_name in repo.references: + return repo.references[ref_name] elif ref_name in repo.remote().refs: return repo.remote().refs[ref_name] else: return None -def update_branch(repo: git.Repo): +def update_branch(repo: git.Repo) -> None: """Ensure branch is up to date if it has a remote""" - if repo.active_branch.tracking_branch(): - repo.git.merge(repo.active_branch.tracking_branch().name) + tracking_branch = repo.active_branch.tracking_branch() + if tracking_branch: + repo.git.merge(tracking_branch.name) def get_changes_for_version(wanted_version: version.Version) -> str: @@ -536,7 +558,9 @@ def get_changes_for_version(wanted_version: version.Version) -> str: return "\n".join(version_changelog) -def generate_and_write_changelog(current_version: version.Version, new_version: str): +def generate_and_write_changelog( + current_version: version.Version, new_version: str +) -> None: # We do this by getting a draft so that we can edit it before writing to the # changelog. result = run_until_successful( @@ -558,8 +582,8 @@ def generate_and_write_changelog(current_version: version.Version, new_version: f.write(existing_content) # Remove all the news fragments - for f in glob.iglob("changelog.d/*.*"): - os.remove(f) + for filename in glob.iglob("changelog.d/*.*"): + os.remove(filename) if __name__ == "__main__": diff --git a/scripts-dev/sign_json.py b/scripts-dev/sign_json.py index 9459543106..bb217799fb 100755 --- a/scripts-dev/sign_json.py +++ b/scripts-dev/sign_json.py @@ -27,7 +27,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.util import json_encoder -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="""Adds a signature to a JSON object. diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index e18d617281..3a4f9c3076 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -115,9 +115,7 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]): def __getitem__(self, index: slice) -> List[_KT_co]: ... def __delitem__(self, index: Union[int, slice]) -> None: ... -class SortedItemsView( # type: ignore - ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]] -): +class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]): def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ... @overload def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ... diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 01c32417d8..931750668e 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -187,7 +187,7 @@ class Auth: Once get_user_by_req has set up the opentracing span, this does the actual work. """ try: - ip_addr = request.getClientIP() + ip_addr = request.getClientAddress().host user_agent = get_request_user_agent(request) access_token = self.get_access_token_from_request(request) @@ -356,7 +356,7 @@ class Auth: return None, None, None if app_service.ip_range_whitelist: - ip_address = IPAddress(request.getClientIP()) + ip_address = IPAddress(request.getClientAddress().host) if ip_address not in app_service.ip_range_whitelist: return None, None, None @@ -417,7 +417,8 @@ class Auth: """ if rights == "access": - # first look in the database + # First look in the database to see if the access token is present + # as an opaque token. r = await self.store.get_user_by_access_token(token) if r: valid_until_ms = r.valid_until_ms @@ -434,7 +435,8 @@ class Auth: return r - # otherwise it needs to be a valid macaroon + # If the token isn't found in the database, then it could still be a + # macaroon, so we check that here. try: user_id, guest = self._parse_and_validate_macaroon(token, rights) @@ -482,8 +484,12 @@ class Auth: TypeError, ValueError, ) as e: - logger.warning("Invalid macaroon in auth: %s %s", type(e), e) - raise InvalidClientTokenError("Invalid macaroon passed.") + logger.warning( + "Invalid access token in auth: %s %s.", + type(e), + e, + ) + raise InvalidClientTokenError("Invalid access token passed.") def _parse_and_validate_macaroon( self, token: str, rights: str = "access" @@ -504,10 +510,7 @@ class Auth: try: macaroon = pymacaroons.Macaroon.deserialize(token) except Exception: # deserialize can throw more-or-less anything - # doesn't look like a macaroon: treat it as an opaque token which - # must be in the database. - # TODO: it would be nice to get rid of this, but apparently some - # people use access tokens which aren't macaroons + # The access token doesn't look like a macaroon. raise _InvalidMacaroonException() try: diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 0172eb60b8..0ccd4c9558 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -255,7 +255,5 @@ class GuestAccess: class ReceiptTypes: READ: Final = "m.read" - - -class ReadReceiptEventFields: - MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" + READ_PRIVATE: Final = "org.matrix.msc2285.read.private" + FULLY_READ: Final = "m.fully_read" diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 37321f9133..d28b87a3f4 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -48,7 +48,6 @@ from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.python.threadpool import ThreadPool -import synapse from synapse.api.constants import MAX_PDU_SIZE from synapse.app import check_bind_error from synapse.app.phone_stats_home import start_phone_stats_home @@ -60,6 +59,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.logging.context import PreserveLoggingContext +from synapse.logging.opentracing import init_tracer from synapse.metrics import install_gc_manager, register_threadpool from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -431,7 +431,7 @@ async def start(hs: "HomeServer") -> None: refresh_certificate(hs) # Start the tracer - synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa + init_tracer(hs) # noqa # Instantiate the modules so they can register their web resources to the module API # before we start the listeners. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 421ed7481b..abed5e7edb 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,7 +32,7 @@ class ExperimentalConfig(Config): # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) - # MSC2285 (hidden read receipts) + # MSC2285 (private read receipts) self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) # MSC3244 (room version capabilities) diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 0e74f70784..f83f93c0ef 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -46,7 +46,7 @@ class FederationConfig(Config): ) self.allow_device_name_lookup_over_federation = config.get( - "allow_device_name_lookup_over_federation", True + "allow_device_name_lookup_over_federation", False ) def generate_config_section(self, **kwargs: Any) -> str: @@ -81,11 +81,11 @@ class FederationConfig(Config): # #allow_profile_lookup_over_federation: false - # Uncomment to disable device display name lookup over federation. By default, the - # Federation API allows other homeservers to obtain device display names of any user - # on this homeserver. Defaults to 'true'. + # Uncomment to allow device display name lookup over federation. By default, the + # Federation API prevents other homeservers from obtaining the display names of + # user devices on this homeserver. Defaults to 'false'. # - #allow_device_name_lookup_over_federation: false + #allow_device_name_lookup_over_federation: true """ diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 99db9e1e39..470b8b4492 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -110,13 +110,6 @@ loggers: # information such as access tokens. level: INFO - twisted: - # We send the twisted logging directly to the file handler, - # to work around https://github.com/matrix-org/synapse/issues/3471 - # when using "buffer" logger. Use "console" to log to stderr instead. - handlers: [file] - propagate: false - root: level: INFO diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 39e9acb62a..d2d0425e62 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -43,6 +43,9 @@ class RegistrationConfig(Config): self.registration_requires_token = config.get( "registration_requires_token", False ) + self.enable_registration_token_3pid_bypass = config.get( + "enable_registration_token_3pid_bypass", False + ) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -309,6 +312,12 @@ class RegistrationConfig(Config): # #registration_requires_token: true + # Allow users to submit a token during registration to bypass any required 3pid + # steps configured in `registrations_require_3pid`. + # Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow. + # + #enable_registration_token_3pid_bypass: false + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/synapse/config/server.py b/synapse/config/server.py index d771045b52..1e709c7cf5 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -186,7 +186,7 @@ KNOWN_RESOURCES = { class HttpResourceConfig: names: List[str] = attr.ib( factory=list, - validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore + validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), ) compress: bool = attr.ib( default=False, @@ -231,9 +231,7 @@ class ManholeConfig: class LimitRemoteRoomsConfig: enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False) complexity: Union[float, int] = attr.ib( - validator=attr.validators.instance_of( - (float, int) # type: ignore[arg-type] # noqa - ), + validator=attr.validators.instance_of((float, int)), # noqa default=1.0, ) complexity_error: str = attr.ib( @@ -415,6 +413,7 @@ class ServerConfig(Config): ) self.mau_trial_days = config.get("mau_trial_days", 0) + self.mau_appservice_trial_days = config.get("mau_appservice_trial_days", {}) self.mau_limit_alerting = config.get("mau_limit_alerting", True) # How long to keep redacted events in the database in unredacted form @@ -1107,6 +1106,11 @@ class ServerConfig(Config): # sign up in a short space of time never to return after their initial # session. # + # The option `mau_appservice_trial_days` is similar to `mau_trial_days`, but + # applies a different trial number if the user was registered by an appservice. + # A value of 0 means no trial days are applied. Appservices not listed in this + # dictionary use the value of `mau_trial_days` instead. + # # 'mau_limit_alerting' is a means of limiting client side alerting # should the mau limit be reached. This is useful for small instances # where the admin has 5 mau seats (say) for 5 specific people and no @@ -1117,6 +1121,8 @@ class ServerConfig(Config): #max_mau_value: 50 #mau_trial_days: 2 #mau_limit_alerting: false + #mau_appservice_trial_days: + # "appservice-id": 1 # If enabled, the metrics for the number of monthly active users will # be populated, however no one will be limited. If limit_usage_by_mau diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 9acb3c0cc4..c238376caf 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -213,10 +213,17 @@ class _EventInternalMetadata: return self.outlier def is_out_of_band_membership(self) -> bool: - """Whether this is an out of band membership, like an invite or an invite - rejection. This is needed as those events are marked as outliers, but - they still need to be processed as if they're new events (e.g. updating - invite state in the database, relaying to clients, etc). + """Whether this event is an out-of-band membership. + + OOB memberships are a special case of outlier events: they are membership events + for federated rooms that we aren't full members of. Examples include invites + received over federation, and rejections for such invites. + + The concept of an OOB membership is needed because these events need to be + processed as if they're new regular events (e.g. updating membership state in + the database, relaying to clients via /sync, etc) despite being outliers. + + See also https://matrix-org.github.io/synapse/develop/development/room-dag-concepts.html#out-of-band-membership-events. (Added in synapse 0.99.0, so may be unreliable for events received before that) """ diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f8d3ba5456..026dcde8d8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -27,7 +27,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError @@ -204,7 +203,9 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None: key_to_move = field.pop(-1) sub_dict = src for sub_field in field: # e.g. sub_field => "content" - if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]: + if sub_field in sub_dict and isinstance( + sub_dict[sub_field], collections.abc.Mapping + ): sub_dict = sub_dict[sub_field] else: return @@ -425,13 +426,12 @@ class EventClientSerializer: # Check if there are any bundled aggregations to include with the event. if bundle_aggregations: - event_aggregations = bundle_aggregations.get(event.event_id) - if event_aggregations: + if event.event_id in bundle_aggregations: self._inject_bundled_aggregations( event, time_now, config, - event_aggregations, + bundle_aggregations, serialized_event, apply_edits=apply_edits, ) @@ -470,7 +470,7 @@ class EventClientSerializer: event: EventBase, time_now: int, config: SerializeEventConfig, - aggregations: "BundledAggregations", + bundled_aggregations: Dict[str, "BundledAggregations"], serialized_event: JsonDict, apply_edits: bool, ) -> None: @@ -480,22 +480,37 @@ class EventClientSerializer: event: The event being serialized. time_now: The current time in milliseconds config: Event serialization config - aggregations: The bundled aggregation to serialize. + bundled_aggregations: Bundled aggregations to be injected. + A map from event_id to aggregation data. Must contain at least an + entry for `event`. + + While serializing the bundled aggregations this map may be searched + again for additional events in a recursive manner. serialized_event: The serialized event which may be modified. apply_edits: Whether the content of the event should be modified to reflect any replacement in `aggregations.replace`. """ + + # We have already checked that aggregations exist for this event. + event_aggregations = bundled_aggregations[event.event_id] + + # The JSON dictionary to be added under the unsigned property of the event + # being serialized. serialized_aggregations = {} - if aggregations.annotations: - serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations + if event_aggregations.annotations: + serialized_aggregations[ + RelationTypes.ANNOTATION + ] = event_aggregations.annotations - if aggregations.references: - serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references + if event_aggregations.references: + serialized_aggregations[ + RelationTypes.REFERENCE + ] = event_aggregations.references - if aggregations.replace: + if event_aggregations.replace: # If there is an edit, optionally apply it to the event. - edit = aggregations.replace + edit = event_aggregations.replace if apply_edits: self._apply_edit(event, serialized_event, edit) @@ -506,19 +521,16 @@ class EventClientSerializer: "sender": edit.sender, } - # If this event is the start of a thread, include a summary of the replies. - if aggregations.thread: - thread = aggregations.thread + # Include any threaded replies to this event. + if event_aggregations.thread: + thread = event_aggregations.thread - # Don't bundle aggregations as this could recurse forever. - serialized_latest_event = serialize_event( - thread.latest_event, time_now, config=config + serialized_latest_event = self.serialize_event( + thread.latest_event, + time_now, + config=config, + bundle_aggregations=bundled_aggregations, ) - # Manually apply an edit, if one exists. - if thread.latest_edit: - self._apply_edit( - thread.latest_event, serialized_latest_event, thread.latest_edit - ) thread_summary = { "latest_event": serialized_latest_event, @@ -622,7 +634,7 @@ def validate_canonicaljson(value: Any) -> None: # Note that Infinity, -Infinity, and NaN are also considered floats. raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON) - elif isinstance(value, (dict, frozendict)): + elif isinstance(value, collections.abc.Mapping): for v in value.values(): validate_canonicaljson(v) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 6a59cb4b71..b5e0b84cbc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1426,6 +1426,8 @@ class FederationClient(FederationBase): room = res.get("room") if not isinstance(room, dict): raise InvalidResponseError("'room' must be a dict") + if room.get("room_id") != room_id: + raise InvalidResponseError("wrong room returned in hierarchy response") # Validate children_state of the room. children_state = room.pop("children_state", []) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index beab1227b8..884b5d60b4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -268,8 +268,8 @@ class FederationServer(FederationBase): transaction_id=transaction_id, destination=destination, origin=origin, - origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore - pdus=transaction_data.get("pdus"), # type: ignore + origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore[arg-type] + pdus=transaction_data.get("pdus"), edus=transaction_data.get("edus"), ) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 30e2421efc..6d2f46318b 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -343,9 +343,16 @@ class FederationSender(AbstractFederationSender): last_token, self._last_poked_id, limit=100 ) - logger.debug("Handling %s -> %s", last_token, next_token) + logger.debug( + "Handling %i -> %i: %i events to send (current id %i)", + last_token, + next_token, + len(events), + self._last_poked_id, + ) if not events and next_token >= self._last_poked_id: + logger.debug("All events processed") break async def handle_event(event: EventBase) -> None: @@ -353,9 +360,53 @@ class FederationSender(AbstractFederationSender): send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.sender) if not is_mine and send_on_behalf_of is None: + logger.debug("Not sending remote-origin event %s", event) + return + + # We also want to not send out-of-band membership events. + # + # OOB memberships are used in three (and a half) situations: + # + # (1) invite events which we have received over federation. Those + # will have a `sender` on a different server, so will be + # skipped by the "is_mine" test above anyway. + # + # (2) rejections of invites to federated rooms - either remotely + # or locally generated. (Such rejections are normally + # created via federation, in which case the remote server is + # responsible for sending out the rejection. If that fails, + # we'll create a leave event locally, but that's only really + # for the benefit of the invited user - we don't have enough + # information to send it out over federation). + # + # (2a) rescinded knocks. These are identical to rejected invites. + # + # (3) knock events which we have sent over federation. As with + # invite rejections, the remote server should send them out to + # the federation. + # + # So, in all the above cases, we want to ignore such events. + # + # OOB memberships are always(?) outliers anyway, so if we *don't* + # ignore them, we'll get an exception further down when we try to + # fetch the membership list for the room. + # + # Arguably, we could equivalently ignore all outliers here, since + # in theory the only way for an outlier with a local `sender` to + # exist is by being an OOB membership (via one of (2), (2a) or (3) + # above). + # + if event.internal_metadata.is_out_of_band_membership(): + logger.debug("Not sending OOB membership event %s", event) return + # Finally, there are some other events that we should not send out + # until someone asks for them. They are explicitly flagged as such + # with `proactively_send: False`. if not event.internal_metadata.should_proactively_send(): + logger.debug( + "Not sending event with proactively_send=false: %s", event + ) return destinations: Optional[Set[str]] = None @@ -419,7 +470,10 @@ class FederationSender(AbstractFederationSender): "federation_sender" ).observe((now - ts) / 1000) - async def handle_room_events(events: Iterable[EventBase]) -> None: + async def handle_room_events(events: List[EventBase]) -> None: + logger.debug( + "Handling %i events in room %s", len(events), events[0].room_id + ) with Measure(self.clock, "handle_room_events"): for event in events: await handle_event(event) @@ -438,6 +492,7 @@ class FederationSender(AbstractFederationSender): ) ) + logger.debug("Successfully handled up to %i", next_token) await self.store.update_federation_out_pos("events", next_token) if events: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 1421050b9a..9ce06dfa28 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -229,21 +229,21 @@ class TransportLayerClient: """ logger.debug( "send_data dest=%s, txid=%s", - transaction.destination, # type: ignore - transaction.transaction_id, # type: ignore + transaction.destination, + transaction.transaction_id, ) - if transaction.destination == self.server_name: # type: ignore + if transaction.destination == self.server_name: raise RuntimeError("Transport layer cannot send to itself!") # FIXME: This is only used by the tests. The actual json sent is # generated by the json_data_callback. json_data = transaction.get_dict() - path = _create_v1_path("/send/%s", transaction.transaction_id) # type: ignore + path = _create_v1_path("/send/%s", transaction.transaction_id) return await self.client.put_json( - transaction.destination, # type: ignore + transaction.destination, path=path, data=json_data, json_data_callback=json_data_callback, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 1b57840506..b3894666cc 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -416,7 +416,7 @@ class ApplicationServicesHandler: return typing async def _handle_receipts( - self, service: ApplicationService, new_token: Optional[int] + self, service: ApplicationService, new_token: int ) -> List[JsonDict]: """ Return the latest read receipts that the given application service should receive. @@ -447,7 +447,7 @@ class ApplicationServicesHandler: receipts_source = self.event_sources.sources.receipt receipts, _ = await receipts_source.get_new_events_as( - service=service, from_key=from_key + service=service, from_key=from_key, to_key=new_token ) return receipts diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 86991d26ce..ad41337b28 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -481,7 +481,7 @@ class AuthHandler: sid = authdict["session"] # Convert the URI and method to strings. - uri = request.uri.decode("utf-8") # type: ignore + uri = request.uri.decode("utf-8") method = request.method.decode("utf-8") # If there's no session ID, create a new session. @@ -551,7 +551,7 @@ class AuthHandler: await self.store.set_ui_auth_clientdict(sid, clientdict) user_agent = get_request_user_agent(request) - clientip = request.getClientIP() + clientip = request.getClientAddress().host await self.store.add_user_agent_ip_to_ui_auth_session( session.session_id, user_agent, clientip diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5b94b00bc3..82a5aac3dd 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -164,7 +164,7 @@ class EventHandler: event. """ redact_behaviour = ( - EventRedactBehaviour.AS_IS if show_redacted else EventRedactBehaviour.REDACT + EventRedactBehaviour.as_is if show_redacted else EventRedactBehaviour.redact ) event = await self.store.get_event( event_id, check_room_id=room_id, redact_behaviour=redact_behaviour diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d2ba70a814..38dc5b1f6e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -316,7 +316,7 @@ class FederationHandler: events_to_check = await self.store.get_events_as_list( event_ids_to_check, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, ) @@ -1494,7 +1494,7 @@ class FederationHandler: events = await self.store.get_events_as_list( batch, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, allow_rejected=True, ) for event in events: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 693b544286..6cf927e4ff 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -860,7 +860,7 @@ class FederationEventHandler: evs = await self._store.get_events( list(state_map.values()), get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, ) event_map.update(evs) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c183e9c465..9bca2bc4b2 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -92,7 +92,7 @@ class IdentityHandler: """ await self._3pid_validation_ratelimiter_ip.ratelimit( - None, (medium, request.getClientIP()) + None, (medium, request.getClientAddress().host) ) await self._3pid_validation_ratelimiter_address.ratelimit( None, (medium, address) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index a7db8feb57..7b94770f97 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -143,7 +143,7 @@ class InitialSyncHandler: to_key=int(now_token.receipt_key), ) if self.hs.config.experimental.msc2285_enabled: - receipt = ReceiptEventSource.filter_out_hidden(receipt, user_id) + receipt = ReceiptEventSource.filter_out_private(receipt, user_id) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -449,7 +449,7 @@ class InitialSyncHandler: if not receipts: return [] if self.hs.config.experimental.msc2285_enabled: - receipts = ReceiptEventSource.filter_out_hidden(receipts, user_id) + receipts = ReceiptEventSource.filter_out_private(receipts, user_id) return receipts presence, receipts, (messages, token) = await make_deferred_yieldable( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1b092e900e..95a89ac01f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1407,7 +1407,7 @@ class EventCreationHandler: original_event = await self.store.get_event( event.redacts, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -1504,7 +1504,7 @@ class EventCreationHandler: original_event = await self.store.get_event( event.redacts, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, allow_rejected=False, allow_none=True, diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 724b9cfcb4..f6ffb7d18d 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -966,7 +966,7 @@ class OidcProvider: "Mapping provider does not support de-duplicating Matrix IDs" ) - attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore + attributes = await self._user_mapping_provider.map_user_attributes( userinfo, token ) diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py new file mode 100644 index 0000000000..2599160bcc --- /dev/null +++ b/synapse/handlers/push_rules.py @@ -0,0 +1,138 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, List, Optional, Union + +import attr + +from synapse.api.errors import SynapseError, UnrecognizedRequestError +from synapse.push.baserules import BASE_RULE_IDS +from synapse.storage.push_rule import RuleNotFoundException +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RuleSpec: + scope: str + template: str + rule_id: str + attr: Optional[str] + + +class PushRulesHandler: + """A class to handle changes in push rules for users.""" + + def __init__(self, hs: "HomeServer"): + self._notifier = hs.get_notifier() + self._main_store = hs.get_datastores().main + + async def set_rule_attr( + self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] + ) -> None: + """Set an attribute (enabled or actions) on an existing push rule. + + Notifies listeners (e.g. sync handler) of the change. + + Args: + user_id: the user for which to modify the push rule. + spec: the spec of the push rule to modify. + val: the value to change the attribute to. + + Raises: + RuleNotFoundException if the rule being modified doesn't exist. + SynapseError(400) if the value is malformed. + UnrecognizedRequestError if the attribute to change is unknown. + InvalidRuleException if we're trying to change the actions on a rule but + the provided actions aren't compliant with the spec. + """ + if spec.attr not in ("enabled", "actions"): + # for the sake of potential future expansion, shouldn't report + # 404 in the case of an unknown request so check it corresponds to + # a known attribute first. + raise UnrecognizedRequestError() + + namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}" + rule_id = spec.rule_id + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise RuleNotFoundException("Unknown rule %r" % (namespaced_rule_id,)) + if spec.attr == "enabled": + if isinstance(val, dict) and "enabled" in val: + val = val["enabled"] + if not isinstance(val, bool): + # Legacy fallback + # This should *actually* take a dict, but many clients pass + # bools directly, so let's not break them. + raise SynapseError(400, "Value for 'enabled' must be boolean") + await self._main_store.set_push_rule_enabled( + user_id, namespaced_rule_id, val, is_default_rule + ) + elif spec.attr == "actions": + if not isinstance(val, dict): + raise SynapseError(400, "Value must be a dict") + actions = val.get("actions") + if not isinstance(actions, list): + raise SynapseError(400, "Value for 'actions' must be dict") + check_actions(actions) + rule_id = spec.rule_id + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise RuleNotFoundException( + "Unknown rule %r" % (namespaced_rule_id,) + ) + await self._main_store.set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule + ) + else: + raise UnrecognizedRequestError() + + self.notify_user(user_id) + + def notify_user(self, user_id: str) -> None: + """Notify listeners about a push rule change. + + Args: + user_id: the user ID the change is for. + """ + stream_id = self._main_store.get_max_push_rules_stream_id() + self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) + + +def check_actions(actions: List[Union[str, JsonDict]]) -> None: + """Check if the given actions are spec compliant. + + Args: + actions: the actions to check. + + Raises: + InvalidRuleException if the rules aren't compliant with the spec. + """ + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + + for a in actions: + if a in ["notify", "dont_notify", "coalesce"]: + pass + elif isinstance(a, dict) and "set_tweak" in a: + pass + else: + raise InvalidRuleException("Unrecognised action %s" % a) + + +class InvalidRuleException(Exception): + pass diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6250bb3bdf..43d615357b 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes +from synapse.api.constants import ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -112,7 +112,7 @@ class ReceiptsHandler: ) if not res: - # res will be None if this read receipt is 'old' + # res will be None if this receipt is 'old' continue stream_id, max_persisted_id = res @@ -138,7 +138,7 @@ class ReceiptsHandler: return True async def received_client_receipt( - self, room_id: str, receipt_type: str, user_id: str, event_id: str, hidden: bool + self, room_id: str, receipt_type: str, user_id: str, event_id: str ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. @@ -148,16 +148,14 @@ class ReceiptsHandler: receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], - data={"ts": int(self.clock.time_msec()), "hidden": hidden}, + data={"ts": int(self.clock.time_msec())}, ) is_new = await self._handle_new_receipts([receipt]) if not is_new: return - if self.federation_sender and not ( - self.hs.config.experimental.msc2285_enabled and hidden - ): + if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: await self.federation_sender.send_read_receipt(receipt) @@ -167,46 +165,37 @@ class ReceiptEventSource(EventSource[int, JsonDict]): self.config = hs.config @staticmethod - def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: + def filter_out_private(events: List[JsonDict], user_id: str) -> List[JsonDict]: + """ + This method takes in what is returned by + get_linearized_receipts_for_rooms() and goes through read receipts + filtering out m.read.private receipts if they were not sent by the + current user. + """ + visible_events = [] - # filter out hidden receipts the user shouldn't see + # filter out private receipts the user shouldn't see for event in events: content = event.get("content", {}) new_event = event.copy() new_event["content"] = {} - for event_id in content.keys(): - event_content = content.get(event_id, {}) - m_read = event_content.get(ReceiptTypes.READ, {}) - - # If m_read is missing copy over the original event_content as there is nothing to process here - if not m_read: - new_event["content"][event_id] = event_content.copy() - continue - - new_users = {} - for rr_user_id, user_rr in m_read.items(): - try: - hidden = user_rr.get("hidden") - except AttributeError: - # Due to https://github.com/matrix-org/synapse/issues/10376 - # there are cases where user_rr is a string, in those cases - # we just ignore the read receipt - continue - - if hidden is not True or rr_user_id == user_id: - new_users[rr_user_id] = user_rr.copy() - # If hidden has a value replace hidden with the correct prefixed key - if hidden is not None: - new_users[rr_user_id].pop("hidden") - new_users[rr_user_id][ - ReadReceiptEventFields.MSC2285_HIDDEN - ] = hidden - - # Set new users unless empty - if len(new_users.keys()) > 0: - new_event["content"][event_id] = {ReceiptTypes.READ: new_users} + for event_id, event_content in content.items(): + receipt_event = {} + for receipt_type, receipt_content in event_content.items(): + if receipt_type == ReceiptTypes.READ_PRIVATE: + user_rr = receipt_content.get(user_id, None) + if user_rr: + receipt_event[ReceiptTypes.READ_PRIVATE] = { + user_id: user_rr.copy() + } + else: + receipt_event[receipt_type] = receipt_content.copy() + + # Only include the receipt event if it is non-empty. + if receipt_event: + new_event["content"][event_id] = receipt_event # Append new_event to visible_events unless empty if len(new_event["content"].keys()) > 0: @@ -234,18 +223,19 @@ class ReceiptEventSource(EventSource[int, JsonDict]): ) if self.config.experimental.msc2285_enabled: - events = ReceiptEventSource.filter_out_hidden(events, user.to_string()) + events = ReceiptEventSource.filter_out_private(events, user.to_string()) return events, to_key async def get_new_events_as( - self, from_key: int, service: ApplicationService + self, from_key: int, to_key: int, service: ApplicationService ) -> Tuple[List[JsonDict], int]: """Returns a set of new read receipt events that an appservice may be interested in. Args: from_key: the stream position at which events should be fetched from + to_key: the stream position up to which events should be fetched to service: The appservice which may be interested Returns: @@ -255,7 +245,6 @@ class ReceiptEventSource(EventSource[int, JsonDict]): * The current read receipt stream token. """ from_key = int(from_key) - to_key = self.get_current_key() if from_key == to_key: return [], to_key diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 5efb561273..c2754ec918 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections.abc import logging from typing import ( TYPE_CHECKING, @@ -24,7 +25,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError @@ -44,8 +44,6 @@ logger = logging.getLogger(__name__) class _ThreadAggregation: # The latest event in the thread. latest_event: EventBase - # The latest edit to the latest event in the thread. - latest_edit: Optional[EventBase] # The total number of events in the thread. count: int # True if the current user has sent an event to the thread. @@ -295,7 +293,7 @@ class RelationsHandler: for event_id, summary in summaries.items(): if summary: - thread_count, latest_thread_event, edit = summary + thread_count, latest_thread_event = summary # Subtract off the count of any ignored users. for ignored_user in ignored_users: @@ -340,7 +338,6 @@ class RelationsHandler: results[event_id] = _ThreadAggregation( latest_event=latest_thread_event, - latest_edit=edit, count=thread_count, # If there's a thread summary it must also exist in the # participated dictionary. @@ -359,15 +356,37 @@ class RelationsHandler: user_id: The user requesting the bundled aggregations. Returns: - A map of event ID to the bundled aggregation for the event. Not all - events may have bundled aggregations in the results. + A map of event ID to the bundled aggregations for the event. + + Not all requested events may exist in the results (if they don't have + bundled aggregations). + + The results may include additional events which are related to the + requested events. """ - # De-duplicate events by ID to handle the same event requested multiple times. - # - # State events do not get bundled aggregations. - events_by_id = { - event.event_id: event for event in events if not event.is_state() - } + # De-duplicated events by ID to handle the same event requested multiple times. + events_by_id = {} + # A map of event ID to the relation in that event, if there is one. + relations_by_id: Dict[str, str] = {} + for event in events: + # State events do not get bundled aggregations. + if event.is_state(): + continue + + relates_to = event.content.get("m.relates_to") + relation_type = None + if isinstance(relates_to, collections.abc.Mapping): + relation_type = relates_to.get("rel_type") + # An event which is a replacement (ie edit) or annotation (ie, + # reaction) may not have any other event related to it. + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + continue + + # The event should get bundled aggregations. + events_by_id[event.event_id] = event + # Track the event's relation information for later. + if isinstance(relation_type, str): + relations_by_id[event.event_id] = relation_type # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} @@ -375,16 +394,34 @@ class RelationsHandler: # Fetch any ignored users of the requesting user. ignored_users = await self._main_store.ignored_users(user_id) + # Threads are special as the latest event of a thread might cause additional + # events to be fetched. Thus, we check those first! + + # Fetch thread summaries (but only for the directly requested events). + threads = await self.get_threads_for_events( + # It is not valid to start a thread on an event which itself relates to another event. + [eid for eid in events_by_id.keys() if eid not in relations_by_id], + user_id, + ignored_users, + ) + for event_id, thread in threads.items(): + results.setdefault(event_id, BundledAggregations()).thread = thread + + # If the latest event in a thread is not already being fetched, + # add it. This ensures that the bundled aggregations for the + # latest thread event is correct. + latest_thread_event = thread.latest_event + if latest_thread_event and latest_thread_event.event_id not in events_by_id: + events_by_id[latest_thread_event.event_id] = latest_thread_event + # Keep relations_by_id in sync with events_by_id: + # + # We know that the latest event in a thread has a thread relation + # (as that is what makes it part of the thread). + relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD + # Fetch other relations per event. for event in events_by_id.values(): - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - continue - + # Fetch any annotations (ie, reactions) to bundle with this event. annotations = await self.get_annotations_for_event( event.event_id, event.room_id, ignored_users=ignored_users ) @@ -393,6 +430,7 @@ class RelationsHandler: event.event_id, BundledAggregations() ).annotations = {"chunk": annotations} + # Fetch any references to bundle with this event. references, next_token = await self.get_relations_for_event( event.event_id, event, @@ -425,10 +463,4 @@ class RelationsHandler: for event_id, edit in edits.items(): results.setdefault(event_id, BundledAggregations()).replace = edit - threads = await self.get_threads_for_events( - events_by_id.keys(), user_id, ignored_users - ) - for event_id, thread in threads.items(): - results.setdefault(event_id, BundledAggregations()).thread = thread - return results diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 486145f48a..ff24ec8063 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -105,6 +105,7 @@ class RoomSummaryHandler: hs.get_clock(), "get_room_hierarchy", ) + self._msc3266_enabled = hs.config.experimental.msc3266_enabled async def get_room_hierarchy( self, @@ -630,7 +631,7 @@ class RoomSummaryHandler: return False async def _is_remote_room_accessible( - self, requester: str, room_id: str, room: JsonDict + self, requester: Optional[str], room_id: str, room: JsonDict ) -> bool: """ Calculate whether the room received over federation should be shown to the requester. @@ -645,7 +646,8 @@ class RoomSummaryHandler: due to an invite, etc. Args: - requester: The user requesting the summary. + requester: The user requesting the summary. If not passed only world + readability is checked. room_id: The room ID returned over federation. room: The summary of the room returned over federation. @@ -659,6 +661,8 @@ class RoomSummaryHandler: or room.get("world_readable") is True ): return True + elif not requester: + return False # Check if the user is a member of any of the allowed rooms from the response. allowed_rooms = room.get("allowed_room_ids") @@ -715,6 +719,10 @@ class RoomSummaryHandler: "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), } + if self._msc3266_enabled: + entry["im.nheko.summary.version"] = stats["version"] + entry["im.nheko.summary.encryption"] = stats["encryption"] + # Federation requests need to provide additional information so the # requested server is able to filter the response appropriately. if for_federation: @@ -812,9 +820,45 @@ class RoomSummaryHandler: room_summary["membership"] = membership or "leave" else: - # TODO federation API, descoped from initial unstable implementation - # as MSC needs more maturing on that side. - raise SynapseError(400, "Federation is not currently supported.") + # Reuse the hierarchy query over federation + if remote_room_hosts is None: + raise SynapseError(400, "Missing via to query remote room") + + ( + room_entry, + children_room_entries, + inaccessible_children, + ) = await self._summarize_remote_room_hierarchy( + _RoomQueueEntry(room_id, remote_room_hosts), + suggested_only=True, + ) + + # The results over federation might include rooms that we, as the + # requesting server, are allowed to see, but the requesting user is + # not permitted to see. + # + # Filter the returned results to only what is accessible to the user. + if not room_entry or not await self._is_remote_room_accessible( + requester, room_entry.room_id, room_entry.room + ): + raise NotFoundError("Room not found or is not accessible") + + room = dict(room_entry.room) + room.pop("allowed_room_ids", None) + + # If there was a requester, add their membership. + # We keep the membership in the local membership table unless the + # room is purged even for remote rooms. + if requester: + ( + membership, + _, + ) = await self._store.get_local_current_membership_for_user_in_room( + requester, room_id + ) + room["membership"] = membership or "leave" + + return room return room_summary diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 102dd4b57d..5619f8f50e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -357,7 +357,7 @@ class SearchHandler: itertools.chain( # The events_before and events_after for each context. itertools.chain.from_iterable( - itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type] + itertools.chain(context["events_before"], context["events_after"]) for context in contexts.values() ), # The returned events. @@ -373,10 +373,10 @@ class SearchHandler: for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + context["events_before"], time_now, bundle_aggregations=aggregations ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + context["events_after"], time_now, bundle_aggregations=aggregations ) results = [ diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e4fe94e557..1e171f3f71 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -468,7 +468,7 @@ class SsoHandler: auth_provider_id, remote_user_id, get_request_user_agent(request), - request.getClientIP(), + request.getClientAddress().host, ) new_user = True elif self._sso_update_profile_information: @@ -928,7 +928,7 @@ class SsoHandler: session.auth_provider_id, session.remote_user_id, get_request_user_agent(request), - request.getClientIP(), + request.getClientAddress().host, ) logger.info( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5125126a80..2c555a66d0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1045,7 +1045,7 @@ class SyncHandler: last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type=ReceiptTypes.READ, + receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) return await self.store.get_unread_event_push_actions_by_room_for_user( diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 472b029af3..05cebb5d4d 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -256,7 +256,9 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self._enabled = bool(hs.config.registration.registration_requires_token) + self._enabled = bool( + hs.config.registration.registration_requires_token + ) or bool(hs.config.registration.enable_registration_token_3pid_bypass) self.store = hs.get_datastores().main def is_enabled(self) -> bool: diff --git a/synapse/http/server.py b/synapse/http/server.py index 31ca841889..657bffcddd 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -43,6 +43,7 @@ from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces +from twisted.internet.defer import CancelledError from twisted.python import failure from twisted.web import resource from twisted.web.server import NOT_DONE_YET, Request @@ -82,6 +83,14 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> </html> """ +# A fictional HTTP status code for requests where the client has disconnected and we +# successfully cancelled the request. Used only for logging purposes. Clients will never +# observe this code unless cancellations leak across requests or we raise a +# `CancelledError` ourselves. +# Analogous to nginx's 499 status code: +# https://github.com/nginx/nginx/blob/release-1.21.6/src/http/ngx_http_request.h#L128-L134 +HTTP_STATUS_REQUEST_CANCELLED = 499 + def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: """Sends a JSON error response to clients.""" @@ -93,6 +102,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: error_dict = exc.error_dict() logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) + elif f.check(CancelledError): + error_code = HTTP_STATUS_REQUEST_CANCELLED + error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN} + + if not request._disconnected: + logger.error( + "Got cancellation before client disconnection from %r: %r", + request.request_metrics.name, + request, + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] + ) else: error_code = 500 error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN} @@ -155,6 +175,16 @@ def return_html_error( request, exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] ) + elif f.check(CancelledError): + code = HTTP_STATUS_REQUEST_CANCELLED + msg = "Request cancelled" + + if not request._disconnected: + logger.error( + "Got cancellation before client disconnection when handling request %r", + request, + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] + ) else: code = HTTPStatus.INTERNAL_SERVER_ERROR msg = "Internal server error" @@ -295,7 +325,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): if isawaitable(raw_callback_return): callback_return = await raw_callback_return else: - callback_return = raw_callback_return # type: ignore + callback_return = raw_callback_return return callback_return @@ -469,7 +499,7 @@ class JsonResource(DirectServeJsonResource): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return # type: ignore + callback_return = raw_callback_return return callback_return @@ -683,6 +713,9 @@ def respond_with_json( Returns: twisted.web.server.NOT_DONE_YET if the request is still active. """ + # The response code must always be set, for logging purposes. + request.setResponseCode(code) + # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. @@ -697,7 +730,6 @@ def respond_with_json( else: encoder = _encode_json_bytes - request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") @@ -728,13 +760,15 @@ def respond_with_json_bytes( Returns: twisted.web.server.NOT_DONE_YET if the request is still active. """ + # The response code must always be set, for logging purposes. + request.setResponseCode(code) + if request._disconnected: logger.warning( "Not sending response to request %s, already disconnected.", request ) return None - request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") @@ -840,6 +874,9 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> N code: The HTTP response code. html_bytes: The HTML bytes to use as the response body. """ + # The response code must always be set, for logging purposes. + request.setResponseCode(code) + # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. @@ -849,7 +886,6 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> N ) return None - request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) diff --git a/synapse/http/site.py b/synapse/http/site.py index 40f6c04894..0b85a57d77 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -238,7 +238,7 @@ class SynapseRequest(Request): request_id, request=ContextRequest( request_id=request_id, - ip_address=self.getClientIP(), + ip_address=self.getClientAddress().host, site_tag=self.synapse_site.site_tag, # The requester is going to be unknown at this point. requester=None, @@ -381,7 +381,7 @@ class SynapseRequest(Request): self.synapse_site.access_logger.debug( "%s - %s - Received request: %s %s", - self.getClientIP(), + self.getClientAddress().host, self.synapse_site.site_tag, self.get_method(), self.get_redacted_uri(), @@ -429,7 +429,7 @@ class SynapseRequest(Request): "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" ' %sB %s "%s %s %s" "%s" [%d dbevts]', - self.getClientIP(), + self.getClientAddress().host, self.synapse_site.site_tag, requester, processing_time, diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 88cd8a9e1c..fd9cb97920 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -722,6 +722,11 @@ P = ParamSpec("P") R = TypeVar("R") +async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R: + """Unwraps an arbitrary awaitable by awaiting it.""" + return await awaitable + + @overload def preserve_fn( # type: ignore[misc] f: Callable[P, Awaitable[R]], @@ -802,17 +807,20 @@ def run_in_background( # type: ignore[misc] # by synchronous exceptions, so let's turn them into Failures. return defer.fail() + # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain + # value. Convert it to a `Deferred`. if isinstance(res, typing.Coroutine): + # Wrap the coroutine in a `Deferred`. res = defer.ensureDeferred(res) - - # At this point we should have a Deferred, if not then f was a synchronous - # function, wrap it in a Deferred for consistency. - if not isinstance(res, defer.Deferred): - # `res` is not a `Deferred` and not a `Coroutine`. - # There are no other types of `Awaitable`s we expect to encounter in Synapse. - assert not isinstance(res, Awaitable) - - return defer.succeed(res) + elif isinstance(res, defer.Deferred): + pass + elif isinstance(res, Awaitable): + # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` + # or `Future` from `make_awaitable`. + res = defer.ensureDeferred(_unwrap_awaitable(res)) + else: + # `res` is a plain value. Wrap it in a `Deferred`. + res = defer.succeed(res) if res.called and not res.paused: # The function should have maintained the logcontext, so we can diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index f86ee9aac7..a02b5bf6bd 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -884,7 +884,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False): tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, tags.HTTP_METHOD: request.get_method(), tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientIP(), + tags.PEER_HOST_IPV6: request.getClientAddress().host, } request_name = request.request_metrics.name diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8f9e629274..834fe1b62c 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -82,6 +82,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( DirectServeHtmlResource, @@ -109,6 +110,7 @@ from synapse.storage.state import StateFilter from synapse.types import ( DomainSpecificString, JsonDict, + JsonMapping, Requester, StateMap, UserID, @@ -151,6 +153,7 @@ __all__ = [ "PRESENCE_ALL_USERS", "LoginResponse", "JsonDict", + "JsonMapping", "EventBase", "StateMap", "ProfileInfo", @@ -193,6 +196,7 @@ class ModuleApi: self._clock: Clock = hs.get_clock() self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() + self._push_rules_handler = hs.get_push_rules_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -1350,6 +1354,68 @@ class ModuleApi: """ await self._store.add_user_bound_threepid(user_id, medium, address, id_server) + def check_push_rule_actions( + self, actions: List[Union[str, Dict[str, str]]] + ) -> None: + """Checks if the given push rule actions are valid according to the Matrix + specification. + + See https://spec.matrix.org/v1.2/client-server-api/#actions for the list of valid + actions. + + Added in Synapse v1.58.0. + + Args: + actions: the actions to check. + + Raises: + synapse.module_api.errors.InvalidRuleException if the actions are invalid. + """ + check_actions(actions) + + async def set_push_rule_action( + self, + user_id: str, + scope: str, + kind: str, + rule_id: str, + actions: List[Union[str, Dict[str, str]]], + ) -> None: + """Changes the actions of an existing push rule for the given user. + + See https://spec.matrix.org/v1.2/client-server-api/#push-rules for more + information about push rules and their syntax. + + Can only be called on the main process. + + Added in Synapse v1.58.0. + + Args: + user_id: the user for which to change the push rule's actions. + scope: the push rule's scope, currently only "global" is allowed. + kind: the push rule's kind. + rule_id: the push rule's identifier. + actions: the actions to run when the rule's conditions match. + + Raises: + RuntimeError if this method is called on a worker or `scope` is invalid. + synapse.module_api.errors.RuleNotFoundException if the rule being modified + can't be found. + synapse.module_api.errors.InvalidRuleException if the actions are invalid. + """ + if self.worker_app is not None: + raise RuntimeError("module tried to change push rule actions on a worker") + + if scope != "global": + raise RuntimeError( + "invalid scope %s, only 'global' is currently allowed" % scope + ) + + spec = RuleSpec(scope, kind, rule_id, "actions") + await self._push_rules_handler.set_rule_attr( + user_id, spec, {"actions": actions} + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room @@ -1419,7 +1485,7 @@ class AccountDataManager: f"{user_id} is not local to this homeserver; can't access account data for remote users." ) - async def get_global(self, user_id: str, data_type: str) -> Optional[JsonDict]: + async def get_global(self, user_id: str, data_type: str) -> Optional[JsonMapping]: """ Gets some global account data, of a specified type, for the specified user. diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py index 1db900e41f..e58e0e60fe 100644 --- a/synapse/module_api/errors.py +++ b/synapse/module_api/errors.py @@ -20,10 +20,14 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config._base import ConfigError +from synapse.handlers.push_rules import InvalidRuleException +from synapse.storage.push_rule import RuleNotFoundException __all__ = [ "InvalidClientCredentialsError", "RedirectException", "SynapseError", "ConfigError", + "InvalidRuleException", + "RuleNotFoundException", ] diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 957c9b780b..a1bf5b20dd 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -24,7 +24,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ) + my_receipts_by_room = await store.get_receipts_for_user( + user_id, (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) + ) badge = len(invites) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 122892c7bc..350762f494 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory from twisted.python.failure import Failure -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, ReceiptTypes from synapse.federation import send_queue from synapse.federation.sender import FederationSender from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable @@ -401,10 +401,8 @@ class FederationSenderHandler: # we only want to send on receipts for our own users if not self._is_mine_id(receipt.user_id): continue - if ( - receipt.data.get("hidden", False) - and self._hs.config.experimental.msc2285_enabled - ): + # Private read receipts never get sent over federation. + if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: continue receipt_info = ReadReceipt( receipt.room_id, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 615f1828dd..9aba1cd451 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -537,7 +537,7 @@ class ReplicationCommandHandler: # Ignore POSITION that are just our own echoes return - logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) + logger.debug("Handling '%s %s'", cmd.NAME, cmd.to_line()) self._add_command_to_stream_queue(conn, cmd) @@ -567,6 +567,11 @@ class ReplicationCommandHandler: # between then and now. missing_updates = cmd.prev_token != current_token while missing_updates: + # Note: There may very well not be any new updates, but we check to + # make sure. This can particularly happen for the event stream where + # event persisters continuously send `POSITION`. See `resource.py` + # for why this can happen. + logger.info( "Fetching replication rows for '%s' between %i and %i", stream_name, @@ -590,7 +595,7 @@ class ReplicationCommandHandler: [stream.parse_row(row) for row in rows], ) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position( diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index c6870df8f9..99f09669f0 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -204,6 +204,15 @@ class ReplicationStreamer: # turns out that e.g. account data streams share # their "current token" with each other, meaning # that it is *not* safe to send a POSITION. + + # Note: `last_token` may not *actually* be the + # last token we sent out in a RDATA or POSITION. + # This can happen if we sent out an RDATA for + # position X when our current token was say X+1. + # Other workers will see RDATA for X and then a + # POSITION with last token of X+1, which will + # cause them to check if there were any missing + # updates between X and X+1. logger.info( "Sending position: %s -> %s", stream.NAME, diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 5587cae98a..bdc4a9c068 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -882,9 +882,7 @@ class WhoamiRestServlet(RestServlet): response = { "user_id": requester.user.to_string(), - # MSC: https://github.com/matrix-org/matrix-doc/pull/3069 # Entered spec in Matrix 1.2 - "org.matrix.msc3069.is_guest": bool(requester.is_guest), "is_guest": bool(requester.is_guest), } diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index e0b2b80e5b..eb77337044 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -112,7 +112,7 @@ class AuthRestServlet(RestServlet): try: await self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, authdict, request.getClientIP() + LoginType.RECAPTCHA, authdict, request.getClientAddress().host ) except LoginError as e: # Authentication failed, let user try again @@ -132,7 +132,7 @@ class AuthRestServlet(RestServlet): try: await self.auth_handler.add_oob_auth( - LoginType.TERMS, authdict, request.getClientIP() + LoginType.TERMS, authdict, request.getClientAddress().host ) except LoginError as e: # Authentication failed, let user try again @@ -161,7 +161,9 @@ class AuthRestServlet(RestServlet): try: await self.auth_handler.add_oob_auth( - LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP() + LoginType.REGISTRATION_TOKEN, + authdict, + request.getClientAddress().host, ) except LoginError as e: html = self.registration_token_template.render( diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 4a4dbe75de..cf4196ac0a 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -69,9 +69,7 @@ class LoginRestServlet(RestServlet): SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" - JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "m.login.application_service" - APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service" REFRESH_TOKEN_PARAM = "refresh_token" def __init__(self, hs: "HomeServer"): @@ -126,7 +124,6 @@ class LoginRestServlet(RestServlet): flows: List[JsonDict] = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) - flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) if self.cas_enabled: # we advertise CAS for backwards compat, though MSC1721 renamed it @@ -156,7 +153,6 @@ class LoginRestServlet(RestServlet): flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) - flows.append({"type": LoginRestServlet.APPSERVICE_TYPE_UNSTABLE}) return 200, {"flows": flows} @@ -175,15 +171,12 @@ class LoginRestServlet(RestServlet): ) try: - if login_submission["type"] in ( - LoginRestServlet.APPSERVICE_TYPE, - LoginRestServlet.APPSERVICE_TYPE_UNSTABLE, - ): + if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): await self._address_ratelimiter.ratelimit( - None, request.getClientIP() + None, request.getClientAddress().host ) result = await self._do_appservice_login( @@ -191,23 +184,29 @@ class LoginRestServlet(RestServlet): appservice, should_issue_refresh_token=should_issue_refresh_token, ) - elif self.jwt_enabled and ( - login_submission["type"] == LoginRestServlet.JWT_TYPE - or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED + elif ( + self.jwt_enabled + and login_submission["type"] == LoginRestServlet.JWT_TYPE ): - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientAddress().host + ) result = await self._do_jwt_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, ) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientAddress().host + ) result = await self._do_token_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, ) else: - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientAddress().host + ) result = await self._do_other_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index ff040de6b8..24bc7c9095 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -58,7 +58,7 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, ReceiptTypes.READ + user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) notif_event_ids = [pa.event_id for pa in push_actions] diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index a93f6fd5e0..b98640b14a 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union - -import attr +from typing import TYPE_CHECKING, List, Sequence, Tuple, Union from synapse.api.errors import ( NotFoundError, @@ -22,6 +20,7 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) +from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -29,7 +28,6 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns @@ -40,14 +38,6 @@ if TYPE_CHECKING: from synapse.server import HomeServer -@attr.s(slots=True, frozen=True, auto_attribs=True) -class RuleSpec: - scope: str - template: str - rule_id: str - attr: Optional[str] - - class PushRuleRestServlet(RestServlet): PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( @@ -60,6 +50,7 @@ class PushRuleRestServlet(RestServlet): self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None + self._push_rules_handler = hs.get_push_rules_handler() async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: @@ -81,8 +72,13 @@ class PushRuleRestServlet(RestServlet): user_id = requester.user.to_string() if spec.attr: - await self.set_rule_attr(user_id, spec, content) - self.notify_user(user_id) + try: + await self._push_rules_handler.set_rule_attr(user_id, spec, content) + except InvalidRuleException as e: + raise SynapseError(400, "Invalid actions: %s" % e) + except RuleNotFoundException: + raise NotFoundError("Unknown rule") + return 200, {} if spec.rule_id.startswith("."): @@ -98,23 +94,23 @@ class PushRuleRestServlet(RestServlet): before = parse_string(request, "before") if before: - before = _namespaced_rule_id(spec, before) + before = f"global/{spec.template}/{before}" after = parse_string(request, "after") if after: - after = _namespaced_rule_id(spec, after) + after = f"global/{spec.template}/{after}" try: await self.store.add_push_rule( user_id=user_id, - rule_id=_namespaced_rule_id_from_spec(spec), + rule_id=f"global/{spec.template}/{spec.rule_id}", priority_class=priority_class, conditions=conditions, actions=actions, before=before, after=after, ) - self.notify_user(user_id) + self._push_rules_handler.notify_user(user_id) except InconsistentRuleException as e: raise SynapseError(400, str(e)) except RuleNotFoundException as e: @@ -133,11 +129,11 @@ class PushRuleRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}" try: await self.store.delete_push_rule(user_id, namespaced_rule_id) - self.notify_user(user_id) + self._push_rules_handler.notify_user(user_id) return 200, {} except StoreError as e: if e.code == 404: @@ -172,55 +168,6 @@ class PushRuleRestServlet(RestServlet): else: raise UnrecognizedRequestError() - def notify_user(self, user_id: str) -> None: - stream_id = self.store.get_max_push_rules_stream_id() - self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - - async def set_rule_attr( - self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] - ) -> None: - if spec.attr not in ("enabled", "actions"): - # for the sake of potential future expansion, shouldn't report - # 404 in the case of an unknown request so check it corresponds to - # a known attribute first. - raise UnrecognizedRequestError() - - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec.rule_id - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: - raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) - if spec.attr == "enabled": - if isinstance(val, dict) and "enabled" in val: - val = val["enabled"] - if not isinstance(val, bool): - # Legacy fallback - # This should *actually* take a dict, but many clients pass - # bools directly, so let's not break them. - raise SynapseError(400, "Value for 'enabled' must be boolean") - await self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val, is_default_rule - ) - elif spec.attr == "actions": - if not isinstance(val, dict): - raise SynapseError(400, "Value must be a dict") - actions = val.get("actions") - if not isinstance(actions, list): - raise SynapseError(400, "Value for 'actions' must be dict") - _check_actions(actions) - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec.rule_id - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: - raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - await self.store.set_push_rule_actions( - user_id, namespaced_rule_id, actions, is_default_rule - ) - else: - raise UnrecognizedRequestError() - def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec @@ -291,24 +238,11 @@ def _rule_tuple_from_request_object( raise InvalidRuleException("No actions found") actions = req_obj["actions"] - _check_actions(actions) + check_actions(actions) return conditions, actions -def _check_actions(actions: List[Union[str, JsonDict]]) -> None: - if not isinstance(actions, list): - raise InvalidRuleException("No actions found") - - for a in actions: - if a in ["notify", "dont_notify", "coalesce"]: - pass - elif isinstance(a, dict) and "set_tweak" in a: - pass - else: - raise InvalidRuleException("Unrecognised action") - - def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict: if path == []: raise UnrecognizedRequestError( @@ -357,17 +291,5 @@ def _priority_class_from_spec(spec: RuleSpec) -> int: return pc -def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str: - return _namespaced_rule_id(spec, spec.rule_id) - - -def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str: - return "global/%s/%s" % (spec.template, rule_id) - - -class InvalidRuleException(Exception): - pass - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index f51be511d1..1583e903cd 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -15,8 +15,8 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes -from synapse.api.errors import Codes, SynapseError +from synapse.api.constants import ReceiptTypes +from synapse.api.errors import SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -36,6 +36,7 @@ class ReadMarkerRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() + self.config = hs.config self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() @@ -48,27 +49,38 @@ class ReadMarkerRestServlet(RestServlet): await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) - read_event_id = body.get(ReceiptTypes.READ, None) - hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) - if not isinstance(hidden, bool): + valid_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ} + if self.config.experimental.msc2285_enabled: + valid_receipt_types.add(ReceiptTypes.READ_PRIVATE) + + if set(body.keys()) > valid_receipt_types: raise SynapseError( 400, - "Param %s must be a boolean, if given" - % ReadReceiptEventFields.MSC2285_HIDDEN, - Codes.BAD_JSON, + "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'" + if self.config.experimental.msc2285_enabled + else "Receipt type must be 'm.read' or 'm.fully_read'", ) + read_event_id = body.get(ReceiptTypes.READ, None) if read_event_id: await self.receipts_handler.received_client_receipt( room_id, ReceiptTypes.READ, user_id=requester.user.to_string(), event_id=read_event_id, - hidden=hidden, ) - read_marker_event_id = body.get("m.fully_read", None) + read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None) + if read_private_event_id and self.config.experimental.msc2285_enabled: + await self.receipts_handler.received_client_receipt( + room_id, + ReceiptTypes.READ_PRIVATE, + user_id=requester.user.to_string(), + event_id=read_private_event_id, + ) + + read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None) if read_marker_event_id: await self.read_marker_handler.received_client_read_marker( room_id, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index b24ad2d1be..f9caab6635 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,8 +16,8 @@ import logging import re from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes -from synapse.api.errors import Codes, SynapseError +from synapse.api.constants import ReceiptTypes +from synapse.api.errors import SynapseError from synapse.http import get_request_user_agent from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -46,6 +46,7 @@ class ReceiptRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() + self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() async def on_POST( @@ -53,7 +54,19 @@ class ReceiptRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if receipt_type != ReceiptTypes.READ: + if self.hs.config.experimental.msc2285_enabled and receipt_type not in [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.FULLY_READ, + ]: + raise SynapseError( + 400, + "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'", + ) + elif ( + not self.hs.config.experimental.msc2285_enabled + and receipt_type != ReceiptTypes.READ + ): raise SynapseError(400, "Receipt type must be 'm.read'") # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. @@ -62,26 +75,24 @@ class ReceiptRestServlet(RestServlet): if "Android" in user_agent: if pattern.match(user_agent) or "Riot" in user_agent: allow_empty_body = True - body = parse_json_object_from_request(request, allow_empty_body) - hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) - - if not isinstance(hidden, bool): - raise SynapseError( - 400, - "Param %s must be a boolean, if given" - % ReadReceiptEventFields.MSC2285_HIDDEN, - Codes.BAD_JSON, - ) + # This call makes sure possible empty body is handled correctly + parse_json_object_from_request(request, allow_empty_body) await self.presence_handler.bump_presence_active_time(requester.user) - await self.receipts_handler.received_client_receipt( - room_id, - receipt_type, - user_id=requester.user.to_string(), - event_id=event_id, - hidden=hidden, - ) + if receipt_type == ReceiptTypes.FULLY_READ: + await self.read_marker_handler.received_client_read_marker( + room_id, + user_id=requester.user.to_string(), + event_id=event_id, + ) + else: + await self.receipts_handler.received_client_receipt( + room_id, + receipt_type, + user_id=requester.user.to_string(), + event_id=event_id, + ) return 200, {} diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 70baf50fa4..e8e51a9c66 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -352,7 +352,7 @@ class UsernameAvailabilityRestServlet(RestServlet): if self.inhibit_user_in_use_error: return 200, {"available": True} - ip = request.getClientIP() + ip = request.getClientAddress().host with self.ratelimiter.ratelimit(ip) as wait_deferred: await wait_deferred @@ -394,7 +394,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: - await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) + await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,)) if not self.hs.config.registration.enable_registration: raise SynapseError( @@ -441,7 +441,7 @@ class RegisterRestServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) - client_addr = request.getClientIP() + client_addr = request.getClientAddress().host await self.ratelimiter.ratelimit(None, client_addr, update=False) @@ -929,6 +929,10 @@ def _calculate_registration_flows( # always let users provide both MSISDN & email flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) + # Add a flow that doesn't require any 3pids, if the config requests it. + if config.registration.enable_registration_token_3pid_bypass: + flows.append([LoginType.REGISTRATION_TOKEN]) + # Prepend m.login.terms to all flows if we're requiring consent if config.consent.user_consent_at_registration: for flow in flows: @@ -942,7 +946,8 @@ def _calculate_registration_flows( # Prepend registration token to all flows if we're requiring a token if config.registration.registration_requires_token: for flow in flows: - flow.insert(0, LoginType.REGISTRATION_TOKEN) + if LoginType.REGISTRATION_TOKEN not in flow: + flow.insert(0, LoginType.REGISTRATION_TOKEN) return flows diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index bfc1d4ee08..c1bd775fec 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -93,7 +93,7 @@ class VersionsRestServlet(RestServlet): "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private, # Supports the busy presence state described in MSC3026. "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, - # Supports receiving hidden read receipts as per MSC2285 + # Supports receiving private read receipts as per MSC2285 "org.matrix.msc2285": self.config.experimental.msc2285_enabled, # Adds support for importing historical messages as per MSC2716 "org.matrix.msc2716": self.config.experimental.msc2716_enabled, diff --git a/synapse/server.py b/synapse/server.py index 37c72bd83a..d49c76518a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -91,6 +91,7 @@ from synapse.handlers.presence import ( WorkerPresenceHandler, ) from synapse.handlers.profile import ProfileHandler +from synapse.handlers.push_rules import PushRulesHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler @@ -811,6 +812,10 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountHandler(self) @cache_in_self + def get_push_rules_handler(self) -> PushRulesHandler: + return PushRulesHandler(self) + + @cache_in_self def get_outbound_redis_connection(self) -> "ConnectionHandler": """ The Redis connection used for replication. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fbf7ba4600..cad3b42640 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -800,7 +800,7 @@ class StateResolutionStore: return self.store.get_events( event_ids, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, allow_rejected=allow_rejected, ) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 951031af50..5895b89202 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,12 +15,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.stats import UserSortOrder -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( IdGenerator, MultiWriterIdGenerator, @@ -266,7 +271,9 @@ class DataStore( A tuple of a list of mappings from user to information and a count of total users. """ - def get_users_paginate_txn(txn): + def get_users_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: filters = [] args = [self.hs.config.server.server_name] @@ -301,7 +308,7 @@ class DataStore( """ sql = "SELECT COUNT(*) as total_users " + sql_base txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, @@ -338,7 +345,9 @@ class DataStore( ) -def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig): +def check_database_before_upgrade( + cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig +) -> None: """Called before upgrading an existing database to check that it is broadly sane compared with the configuration. """ diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index fa732edcca..945707b0ec 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast from synapse.appservice import ( ApplicationService, @@ -83,7 +83,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): txn.execute( "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns" ) - return txn.fetchone()[0] # type: ignore + return cast(Tuple[int], txn.fetchone())[0] self._as_txn_seq_gen = build_sequence_generator( db_conn, diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index b4a1b041b1..599b418383 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + cast, +) from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore): prefilled_cache=device_outbox_prefill, ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[ToDeviceStream.ToDeviceStreamRow], + ) -> None: if stream_name == ToDeviceStream.NAME: # If replication is happening than postgres must be being used. assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator) @@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) return super().process_replication_rows(stream_name, instance_name, token, rows) - def get_to_device_stream_token(self): + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() async def get_messages_for_user_devices( @@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): if not user_ids_to_query: return {}, to_stream_id - def get_device_messages_txn(txn: LoggingTransaction): + def get_device_messages_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: # Build a query to select messages from any of the given devices that # are between the given stream id bounds. @@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": "No changes in cache since last check"}) return 0 - def delete_messages_for_device_txn(txn): + def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: sql = ( "DELETE FROM device_inbox" " WHERE user_id = ? AND device_id = ?" @@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore): @trace async def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit - ) -> Tuple[List[dict], int]: + self, destination: str, last_stream_id: int, current_stream_id: int, limit: int + ) -> Tuple[List[JsonDict], int]: """ Args: - destination(str): The name of the remote server. - last_stream_id(int|long): The last position of the device message stream + destination: The name of the remote server. + last_stream_id: The last position of the device message stream that the server sent up to. - current_stream_id(int|long): The current position of the device - message stream. + current_stream_id: The current position of the device message stream. Returns: A list of messages for the device and where in the stream the messages got to. """ @@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): return [], last_stream_id @trace - def get_new_messages_for_remote_destination_txn(txn): + def get_new_messages_for_remote_destination_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: sql = ( "SELECT stream_id, messages_json FROM device_federation_outbox" " WHERE destination = ?" @@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): up_to_stream_id: Where to delete messages up to. """ - def delete_messages_for_remote_destination_txn(txn): + def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None: sql = ( "DELETE FROM device_federation_outbox" " WHERE destination = ?" @@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_new_device_messages_txn(txn): + def get_all_new_device_messages_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We limit like this as we might have multiple rows per stream_id, and # we want to make sure we always get all entries for any stream_id # we return. @@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): @trace async def add_messages_to_device_inbox( self, - local_messages_by_user_then_device: dict, - remote_messages_by_destination: dict, + local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], + remote_messages_by_destination: Dict[str, JsonDict], ) -> int: """Used to send messages from this server. @@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): assert self._can_write_to_device - def add_messages_txn(txn, now_ms, stream_id): + def add_messages_txn( + txn: LoggingTransaction, now_ms: int, stream_id: int + ) -> None: # Add the local messages directly to the local inbox. self._add_messages_to_local_device_inbox_txn( txn, stream_id, local_messages_by_user_then_device @@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): return self._device_inbox_id_gen.get_current_token() async def add_messages_from_remote_to_device_inbox( - self, origin: str, message_id: str, local_messages_by_user_then_device: dict + self, + origin: str, + message_id: str, + local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], ) -> int: assert self._can_write_to_device - def add_messages_txn(txn, now_ms, stream_id): + def add_messages_txn( + txn: LoggingTransaction, now_ms: int, stream_id: int + ) -> None: # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. @@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): return stream_id def _add_messages_to_local_device_inbox_txn( - self, txn, stream_id, messages_by_user_then_device - ): + self, + txn: LoggingTransaction, + stream_id: int, + messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], + ) -> None: assert self._can_write_to_device local_by_user_then_device = {} @@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self._remove_dead_devices_from_device_inbox, ) - async def _background_drop_index_device_inbox(self, progress, batch_size): - def reindex_txn(conn): + async def _background_drop_index_device_inbox( + self, progress: JsonDict, batch_size: int + ) -> int: + def reindex_txn(conn: LoggingDatabaseConnection) -> None: txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 483dd80406..2df4dd4ed4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -25,6 +25,7 @@ from typing import ( Optional, Set, Tuple, + cast, ) from synapse.api.errors import Codes, StoreError @@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore): Number of devices of this users. """ - def count_devices_by_users_txn(txn, user_ids): + def count_devices_by_users_txn( + txn: LoggingTransaction, user_ids: List[str] + ) -> int: sql = """ SELECT count(*) FROM devices @@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore): ) txn.execute(sql + clause, args) - return txn.fetchone()[0] + return cast(Tuple[int], txn.fetchone())[0] if not user_ids: return 0 @@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) - return list(txn) + return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall()) async def _get_device_update_edus_by_remote( self, @@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore): async def _get_last_device_update_for_remote_user( self, destination: str, user_id: str, from_stream_id: int ) -> int: - def f(txn): + def f(txn: LoggingTransaction) -> int: prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id FROM device_lists_outbound_last_success @@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore): if not user_ids_to_check: return set() - def _get_users_whose_devices_changed_txn(txn): + def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: changes = set() stream_id_where_clause = "stream_id > ?" @@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore): async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user.""" - def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + def _mark_remote_user_device_list_as_unsubscribed_txn( + txn: LoggingTransaction, + ) -> None: self.db_pool.simple_delete_txn( txn, table="device_lists_remote_extremeties", @@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore): ) def _store_dehydrated_device_txn( - self, txn, user_id: str, device_id: str, device_data: str + self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str ) -> Optional[str]: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, @@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore): """ yesterday = self._clock.time_msec() - prune_age - def _prune_txn(txn): + def _prune_txn(txn: LoggingTransaction) -> None: # look for (user, destination) pairs which have an update older than # the cutoff. # @@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): "drop_device_lists_outbound_last_success_non_unique_idx", ) - async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): - def f(conn): + async def _drop_device_list_streams_non_unique_indexes( + self, progress: JsonDict, batch_size: int + ) -> int: + def f(conn: LoggingDatabaseConnection) -> None: txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") @@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) return 1 - async def _remove_duplicate_outbound_pokes(self, progress, batch_size): + async def _remove_duplicate_outbound_pokes( + self, progress: JsonDict, batch_size: int + ) -> int: # for some reason, we have accumulated duplicate entries in # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less # efficient. @@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, ) - def _txn(txn): + def _txn(txn: LoggingTransaction) -> int: clause, args = make_tuple_comparison_clause( [(x, last_row[x]) for x in KEY_COLS] ) @@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - def add_device_changes_txn(txn, stream_ids): + def add_device_changes_txn( + txn: LoggingTransaction, stream_ids: List[int] + ) -> None: self._add_device_change_to_stream_txn( txn, user_id, @@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn: LoggingTransaction, user_id: str, device_ids: Collection[str], - stream_ids: List[str], - ): + stream_ids: List[int], + ) -> None: txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, @@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_ids: Iterable[str], room_ids: Collection[str], - stream_ids: List[str], + stream_ids: List[int], context: Dict[str, str], ) -> None: """Record the user in the room has updated their device.""" @@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): LIMIT ? """ - def get_uncoverted_outbound_room_pokes_txn(txn): + def get_uncoverted_outbound_room_pokes_txn( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: txn.execute(sql, (limit,)) return [ @@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Marks the associated row in `device_lists_changes_in_room` as handled. """ - def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + def add_device_list_outbound_pokes_txn( + txn: LoggingTransaction, stream_ids: List[int] + ) -> None: if hosts: self._add_device_outbound_poke_to_stream_txn( txn, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 2a1e567ce0..9a6c2fd47a 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -47,6 +47,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry +from synapse.storage.engines.postgres import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id @@ -364,6 +365,20 @@ class PersistEventsStore: min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + # We check that the room still exists for events we're trying to + # persist. This is to protect against races with deleting a room. + # + # Annoyingly SQLite doesn't support row level locking. + if isinstance(self.database_engine, PostgresEngine): + for room_id in {e.room_id for e, _ in events_and_contexts}: + txn.execute( + "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE", + (room_id,), + ) + row = txn.fetchone() + if row is None: + raise Exception(f"Room does not exist {room_id}") + # stream orderings should have been assigned by now assert min_stream_order assert max_stream_order diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c31fc00eaa..a4a604a499 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,6 +14,7 @@ import logging import threading +from enum import Enum, auto from typing import ( TYPE_CHECKING, Any, @@ -30,7 +31,6 @@ from typing import ( ) import attr -from constantly import NamedConstant, Names from prometheus_client import Gauge from typing_extensions import Literal @@ -150,14 +150,14 @@ class _EventRow: outlier: bool -class EventRedactBehaviour(Names): +class EventRedactBehaviour(Enum): """ What to do when retrieving a redacted event from the database. """ - AS_IS = NamedConstant() - REDACT = NamedConstant() - BLOCK = NamedConstant() + as_is = auto() + redact = auto() + block = auto() class EventsWorkerStore(SQLBaseStore): @@ -327,7 +327,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_event( self, event_id: str, - redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.redact, get_prev_content: bool = ..., allow_rejected: bool = ..., allow_none: Literal[False] = ..., @@ -339,7 +339,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_event( self, event_id: str, - redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.redact, get_prev_content: bool = ..., allow_rejected: bool = ..., allow_none: Literal[True] = ..., @@ -350,7 +350,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_event( self, event_id: str, - redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.redact, get_prev_content: bool = False, allow_rejected: bool = False, allow_none: bool = False, @@ -362,9 +362,9 @@ class EventsWorkerStore(SQLBaseStore): event_id: The event_id of the event to fetch redact_behaviour: Determine what to do with a redacted event. Possible values: - * AS_IS - Return the full event body with no redacted content - * REDACT - Return the event but with a redacted body - * DISALLOW - Do not return redacted events (behave as per allow_none + * as_is - Return the full event body with no redacted content + * redact - Return the event but with a redacted body + * block - Do not return redacted events (behave as per allow_none if the event is redacted) get_prev_content: If True and event is a state event, @@ -406,7 +406,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_events( self, event_ids: Collection[str], - redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.redact, get_prev_content: bool = False, allow_rejected: bool = False, ) -> Dict[str, EventBase]: @@ -417,9 +417,9 @@ class EventsWorkerStore(SQLBaseStore): redact_behaviour: Determine what to do with a redacted event. Possible values: - * AS_IS - Return the full event body with no redacted content - * REDACT - Return the event but with a redacted body - * DISALLOW - Do not return redacted events (omit them from the response) + * as_is - Return the full event body with no redacted content + * redact - Return the event but with a redacted body + * block - Do not return redacted events (omit them from the response) get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. @@ -442,7 +442,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_events_as_list( self, event_ids: Collection[str], - redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.redact, get_prev_content: bool = False, allow_rejected: bool = False, ) -> List[EventBase]: @@ -455,9 +455,9 @@ class EventsWorkerStore(SQLBaseStore): event_ids: The event_ids of the events to fetch redact_behaviour: Determine what to do with a redacted event. Possible values: - * AS_IS - Return the full event body with no redacted content - * REDACT - Return the event but with a redacted body - * DISALLOW - Do not return redacted events (omit them from the response) + * as_is - Return the full event body with no redacted content + * redact - Return the event but with a redacted body + * block - Do not return redacted events (omit them from the response) get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. @@ -568,10 +568,10 @@ class EventsWorkerStore(SQLBaseStore): event = entry.event if entry.redacted_event: - if redact_behaviour == EventRedactBehaviour.BLOCK: + if redact_behaviour == EventRedactBehaviour.block: # Skip this event continue - elif redact_behaviour == EventRedactBehaviour.REDACT: + elif redact_behaviour == EventRedactBehaviour.redact: event = entry.redacted_event events.append(event) @@ -1094,6 +1094,18 @@ class EventsWorkerStore(SQLBaseStore): original_ev.internal_metadata.stream_ordering = row.stream_ordering original_ev.internal_metadata.outlier = row.outlier + # Consistency check: if the content of the event has been modified in the + # database, then the calculated event ID will not match the event id in the + # database. + if original_ev.event_id != event_id: + # it's difficult to see what to do here. Pretty much all bets are off + # if Synapse cannot rely on the consistency of its database. + raise RuntimeError( + f"Database corruption: Event {event_id} in room {d['room_id']} " + f"from the database appears to have been modified (calculated " + f"event id {original_ev.event_id})" + ) + event_map[event_id] = original_ev # finally, we can decide whether each one needs redacting, and build diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 0aef121d83..04efad9e9a 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -522,7 +522,9 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_joined_groups", ) - async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: + async def get_all_groups_for_user( + self, user_id: str, now_token: int + ) -> List[JsonDict]: def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, type, membership, u.content diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 6990f3ed1d..0a19f607bd 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -15,11 +15,12 @@ import itertools import logging -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction from synapse.storage.keys import FetchKeyResult from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -35,7 +36,9 @@ class KeyStore(SQLBaseStore): """Persistence for signature verification keys""" @cached() - def _get_server_verify_key(self, server_name_and_key_id): + def _get_server_verify_key( + self, server_name_and_key_id: Tuple[str, str] + ) -> FetchKeyResult: raise NotImplementedError() @cachedList( @@ -179,19 +182,21 @@ class KeyStore(SQLBaseStore): async def get_server_keys_json( self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: """Retrieve the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. The JSON is returned as a byte array so that it can be efficiently used in an HTTP response. Args: - server_keys (list): List of (server_name, key_id, source) triplets. + server_keys: List of (server_name, key_id, source) triplets. Returns: A mapping from (server_name, key_id, source) triplets to a list of dicts """ - def _get_server_keys_json_txn(txn): + def _get_server_keys_json_txn( + txn: LoggingTransaction, + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: results = {} for server_name, key_id, from_server in server_keys: keyvalues = {"server_name": server_name} diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 322ed05390..40ac377ca9 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -388,7 +388,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) async def store_url_cache( - self, url, response_code, etag, expires_ts, og, media_id, download_ts + self, + url: str, + response_code: int, + etag: Optional[str], + expires_ts: int, + og: Optional[str], + media_id: str, + download_ts: int, ) -> None: await self.db_pool.simple_insert( "local_media_repository_url_cache", @@ -441,7 +448,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def get_cached_remote_media( - self, origin, media_id: str + self, origin: str, media_id: str ) -> Optional[Dict[str, Any]]: return await self.db_pool.simple_select_one( "remote_media_cache", @@ -608,7 +615,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def delete_remote_media(self, media_origin: str, media_id: str) -> None: - def delete_remote_media_txn(txn): + def delete_remote_media_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, "remote_media_cache", diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 4f1c22c71b..5beb8f1d4b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -232,10 +232,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): # is racy. # Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self._invalidate_all_cache_and_stream( # type: ignore[attr-defined] + self._invalidate_all_cache_and_stream( txn, self.user_last_seen_monthly_active ) - self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined] + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) reserved_users = await self.get_registered_reserved_users() await self.db_pool.runInteraction( @@ -363,7 +363,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = await self.is_guest(user_id) # type: ignore[attr-defined] + is_guest = await self.is_guest(user_id) if is_guest: return is_trial = await self.is_trial_user(user_id) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index d3c4611686..b47c511450 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -103,7 +103,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): prefilled_cache=presence_cache_prefill, ) - async def update_presence(self, presence_states) -> Tuple[int, int]: + async def update_presence( + self, presence_states: List[UserPresenceState] + ) -> Tuple[int, int]: assert self._can_persist_presence stream_ordering_manager = self._presence_id_gen.get_next_mult( @@ -121,7 +123,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): return stream_orderings[-1], self._presence_id_gen.get_current_token() def _update_presence_txn( - self, txn: LoggingTransaction, stream_orderings, presence_states + self, + txn: LoggingTransaction, + stream_orderings: List[int], + presence_states: List[UserPresenceState], ) -> None: for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( @@ -405,7 +410,13 @@ class PresenceStore(PresenceBackgroundUpdateStore): self._presence_on_startup = [] return active_on_startup - def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 2e3818e432..bfc85b3add 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -324,7 +324,12 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: - # First we fetch all the state groups that should be deleted, before + # We *immediately* delete the room from the rooms table. This ensures + # that we don't race when persisting events (as that transaction checks + # that the room exists). + txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,)) + + # Next, we fetch all the state groups that should be deleted, before # we delete that information. txn.execute( """ @@ -403,7 +408,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "room_stats_state", "room_stats_current", "room_stats_earliest_token", - "rooms", "stream_ordering_to_exterm", "users_in_public_rooms", "users_who_share_private_rooms", diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 92539f5d41..eb85bbd392 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -16,7 +16,7 @@ import abc import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union -from synapse.api.errors import NotFoundError, StoreError +from synapse.api.errors import StoreError from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -618,7 +618,7 @@ class PushRuleStore(PushRulesWorkerStore): are always stored in the database `push_rules` table). Raises: - NotFoundError if the rule does not exist. + RuleNotFoundException if the rule does not exist. """ async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() @@ -668,8 +668,7 @@ class PushRuleStore(PushRulesWorkerStore): ) txn.execute(sql, (user_id, rule_id)) if txn.fetchone() is None: - # needed to set NOT_FOUND code. - raise NotFoundError("Push rule does not exist.") + raise RuleNotFoundException("Push rule does not exist.") self.db_pool.simple_upsert_txn( txn, @@ -698,9 +697,6 @@ class PushRuleStore(PushRulesWorkerStore): """ Sets the `actions` state of a push rule. - Will throw NotFoundError if the rule does not exist; the Code for this - is NOT_FOUND. - Args: user_id: the user ID of the user who wishes to enable/disable the rule e.g. '@tina:example.org' @@ -712,6 +708,9 @@ class PushRuleStore(PushRulesWorkerStore): is_default_rule: True if and only if this is a server-default rule. This skips the check for existence (as only user-created rules are always stored in the database `push_rules` table). + + Raises: + RuleNotFoundException if the rule does not exist. """ actions_json = json_encoder.encode(actions) @@ -744,7 +743,7 @@ class PushRuleStore(PushRulesWorkerStore): except StoreError as serr: if serr.code == 404: # this sets the NOT_FOUND error Code - raise NotFoundError("Push rule does not exist") + raise RuleNotFoundException("Push rule does not exist") else: raise diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index cf64cd63a4..91286c9b65 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -14,11 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + cast, +) from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder @@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore): return self._decode_pushers_rows(ret) async def get_all_pushers(self) -> Iterator[PusherConfig]: - def get_pushers(txn): + def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: txn.execute("SELECT * FROM pushers") rows = self.db_pool.cursor_to_dict(txn) @@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_pushers_rows_txn(txn): + def get_all_updated_pushers_rows_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT id, user_name, app_id, pushkey FROM pushers @@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore): ORDER BY id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (user_name, app_id, pushkey, False)) - for stream_id, user_name, app_id, pushkey in txn - ] + updates = cast( + List[Tuple[int, tuple]], + [ + (stream_id, (user_name, app_id, pushkey, False)) + for stream_id, user_name, app_id, pushkey in txn + ], + ) sql = """ SELECT stream_id, user_id, app_id, pushkey @@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore): ) @cached(num_args=1, max_entries=15000) - async def get_if_user_has_pusher(self, user_id: str): + async def get_if_user_has_pusher(self, user_id: str) -> None: # This only exists for the cachedList decorator raise NotImplementedError() async def update_pusher_last_stream_ordering( - self, app_id, pushkey, user_id, last_stream_ordering + self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int ) -> None: await self.db_pool.simple_update_one( "pushers", @@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore): last_user = progress.get("last_user", "") - def _delete_pushers(txn) -> int: + def _delete_pushers(txn: LoggingTransaction) -> int: sql = """ SELECT name FROM users @@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) - def _delete_pushers(txn) -> int: + def _delete_pushers(txn: LoggingTransaction) -> int: sql = """ SELECT p.id, access_token FROM pushers AS p @@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) - def _delete_pushers(txn) -> int: + def _delete_pushers(txn: LoggingTransaction) -> int: sql = """ SELECT p.id, p.user_name, p.app_id, p.pushkey @@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore): async def delete_pusher_by_app_id_pushkey_user_id( self, app_id: str, pushkey: str, user_id: str ) -> None: - def delete_pusher_txn(txn, stream_id): + def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None: self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) @@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore): # account. pushers = list(await self.get_pushers_by_user_id(user_id)) - def delete_pushers_txn(txn, stream_ids): + def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None: self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 332e901dda..d035969a31 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -22,7 +22,6 @@ from typing import ( Iterable, List, Optional, - Set, Tuple, cast, ) @@ -117,54 +116,139 @@ class ReceiptsWorkerStore(SQLBaseStore): """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() - @cached() - async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]: - receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) - return {r["user_id"] for r in receipts} - - @cached(num_args=2) - async def get_receipts_for_room( - self, room_id: str, receipt_type: str - ) -> List[Dict[str, Any]]: - return await self.db_pool.simple_select_list( - table="receipts_linearized", - keyvalues={"room_id": room_id, "receipt_type": receipt_type}, - retcols=("user_id", "event_id"), - desc="get_receipts_for_room", - ) - - @cached(num_args=3) async def get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_type: str + self, user_id: str, room_id: str, receipt_types: Iterable[str] ) -> Optional[str]: - return await self.db_pool.simple_select_one_onecol( - table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, - retcol="event_id", - desc="get_own_receipt_for_user", - allow_none=True, - ) + """ + Fetch the event ID for the latest receipt in a room with one of the given receipt types. + + Args: + user_id: The user to fetch receipts for. + room_id: The room ID to fetch the receipt for. + receipt_type: The receipt types to fetch. Earlier receipt types + are given priority if multiple receipts point to the same event. + + Returns: + The latest receipt, if one exists. + """ + latest_event_id: Optional[str] = None + latest_stream_ordering = 0 + for receipt_type in receipt_types: + result = await self._get_last_receipt_event_id_for_user( + user_id, room_id, receipt_type + ) + if result is None: + continue + event_id, stream_ordering = result + + if latest_event_id is None or latest_stream_ordering < stream_ordering: + latest_event_id = event_id + latest_stream_ordering = stream_ordering + + return latest_event_id + + @cached() + async def _get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[Tuple[str, int]]: + """ + Fetch the event ID and stream ordering for the latest receipt. + + Args: + user_id: The user to fetch receipts for. + room_id: The room ID to fetch the receipt for. + receipt_type: The receipt type to fetch. + + Returns: + The event ID and stream ordering of the latest receipt, if one exists; + otherwise `None`. + """ + sql = """ + SELECT event_id, stream_ordering + FROM receipts_linearized + INNER JOIN events USING (room_id, event_id) + WHERE user_id = ? + AND room_id = ? + AND receipt_type = ? + """ + + def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]: + txn.execute(sql, (user_id, room_id, receipt_type)) + return cast(Optional[Tuple[str, int]], txn.fetchone()) + + return await self.db_pool.runInteraction("get_own_receipt_for_user", f) - @cached(num_args=2) async def get_receipts_for_user( - self, user_id: str, receipt_type: str + self, user_id: str, receipt_types: Iterable[str] ) -> Dict[str, str]: - rows = await self.db_pool.simple_select_list( - table="receipts_linearized", - keyvalues={"user_id": user_id, "receipt_type": receipt_type}, - retcols=("room_id", "event_id"), - desc="get_receipts_for_user", + """ + Fetch the event IDs for the latest receipts sent by the given user. + + Args: + user_id: The user to fetch receipts for. + receipt_types: The receipt types to check. + + Returns: + A map of room ID to the event ID of the latest receipt for that room. + + If the user has not sent a receipt to a room then it will not appear + in the returned dictionary. + """ + results = await self.get_receipts_for_user_with_orderings( + user_id, receipt_types ) - return {row["room_id"]: row["event_id"] for row in rows} + # Reduce the result to room ID -> event ID. + return { + room_id: room_result["event_id"] for room_id, room_result in results.items() + } async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_types: Iterable[str] + ) -> JsonDict: + """ + Fetch receipts for all rooms that the given user is joined to. + + Args: + user_id: The user to fetch receipts for. + receipt_types: The receipt types to fetch. Earlier receipt types + are given priority if multiple receipts point to the same event. + + Returns: + A map of room ID to the latest receipt (for the given types). + """ + results: JsonDict = {} + for receipt_type in receipt_types: + partial_result = await self._get_receipts_for_user_with_orderings( + user_id, receipt_type + ) + for room_id, room_result in partial_result.items(): + # If the room has not yet been seen, or the receipt is newer, + # use it. + if ( + room_id not in results + or results[room_id]["stream_ordering"] + < room_result["stream_ordering"] + ): + results[room_id] = room_result + + return results + + @cached() + async def _get_receipts_for_user_with_orderings( self, user_id: str, receipt_type: str ) -> JsonDict: + """ + Fetch receipts for all rooms that the given user is joined to. + + Args: + user_id: The user to fetch receipts for. + receipt_type: The receipt type to fetch. + + Returns: + A map of room ID to the latest receipt information. + """ + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," @@ -174,8 +258,9 @@ class ReceiptsWorkerStore(SQLBaseStore): " WHERE rl.room_id = e.room_id" " AND rl.event_id = e.event_id" " AND user_id = ?" + " AND receipt_type = ?" ) - txn.execute(sql, (user_id,)) + txn.execute(sql, (user_id, receipt_type)) return cast(List[Tuple[str, str, int, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( @@ -241,7 +326,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) - @cached(num_args=3, tree=True) + @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None ) -> List[JsonDict]: @@ -486,33 +571,14 @@ class ReceiptsWorkerStore(SQLBaseStore): "get_all_updated_receipts", get_all_updated_receipts_txn ) - def _invalidate_get_users_with_receipts_in_room( - self, room_id: str, receipt_type: str, user_id: str - ) -> None: - if receipt_type != ReceiptTypes.READ: - return - - res = self.get_users_with_read_receipts_in_room.cache.get_immediate( - room_id, None, update_metrics=False - ) - - if res and user_id in res: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - def invalidate_caches_for_receipt( self, room_id: str, receipt_type: str, user_id: str ) -> None: - self.get_receipts_for_user.invalidate((user_id, receipt_type)) + self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) - self.get_last_receipt_event_id_for_user.invalidate( + self._get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) - self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) - self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows( self, @@ -541,11 +607,11 @@ class ReceiptsWorkerStore(SQLBaseStore): data: JsonDict, stream_id: int, ) -> Optional[int]: - """Inserts a read-receipt into the database if it's newer than the current RR + """Inserts a receipt into the database if it's newer than the current one. Returns: - None if the RR is older than the current RR - otherwise, the rx timestamp of the event that the RR corresponds to + None if the receipt is older than the current receipt + otherwise, the rx timestamp of the event that the receipt corresponds to (or 0 if the event is unknown) """ assert self._can_write_to_receipts @@ -566,7 +632,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if stream_ordering is not None: sql = ( "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " INNER JOIN receipts_linearized AS r USING (event_id, room_id)" " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" ) txn.execute(sql, (room_id, receipt_type, user_id)) @@ -607,7 +673,10 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) - if receipt_type == ReceiptTypes.READ and stream_ordering is not None: + if ( + receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) + and stream_ordering is not None + ): self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -626,6 +695,10 @@ class ReceiptsWorkerStore(SQLBaseStore): Automatically does conversion between linearized and graph representations. + + Returns: + The new receipts stream ID and token, if the receipt is newer than + what was previously persisted. None, otherwise. """ assert self._can_write_to_receipts @@ -673,6 +746,7 @@ class ReceiptsWorkerStore(SQLBaseStore): stream_id=stream_id, ) + # If the receipt was older than the currently persisted one, nothing to do. if event_ts is None: return None @@ -721,14 +795,10 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -> None: assert self._can_write_to_receipts - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, + self._get_receipts_for_user_with_orderings.invalidate, + (user_id, receipt_type), ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index d43163c27c..4991360b70 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -215,7 +215,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first - N days of registration defined by `mau_trial_days` config + N days of registration defined by `mau_trial_days` config or the + `mau_appservice_trial_days` config. Args: user_id: The user to check for trial status. @@ -226,7 +227,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return False now = self._clock.time_msec() - trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000 + days = self.config.server.mau_appservice_trial_days.get( + info["appservice_id"], self.config.server.mau_trial_days + ) + trial_duration_ms = days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index a5c31f6787..484976ca6b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -445,8 +445,8 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") async def get_thread_summaries( self, event_ids: Collection[str] - ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: - """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. + ) -> Dict[str, Optional[Tuple[int, EventBase]]]: + """Get the number of threaded replies and the latest reply (if any) for the given events. Args: event_ids: Summarize the thread related to this event ID. @@ -458,7 +458,6 @@ class RelationsWorkerStore(SQLBaseStore): Each summary is a tuple of: The number of events in the thread. The most recent event in the thread. - The most recent edit to the most recent event in the thread, if applicable. """ def _get_thread_summaries_txn( @@ -544,9 +543,6 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] - # Check to see if any of those events are edited. - latest_edits = await self.get_applicable_edits(latest_event_ids.values()) - # Map to the event IDs to the thread summary. # # There might not be a summary due to there not being a thread or @@ -557,8 +553,7 @@ class RelationsWorkerStore(SQLBaseStore): summary = None if latest_event: - latest_edit = latest_edits.get(latest_event_id) - summary = (counts[parent_event_id], latest_event, latest_edit) + summary = (counts[parent_event_id], latest_event) summaries[parent_event_id] = summary return summaries diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index d4482c06db..3c49e7ec98 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -494,11 +494,11 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # We set redact_behaviour to block here to prevent redacted events being returned in # search results (which is a data leak) events = await self.get_events_as_list( # type: ignore[attr-defined] [r["event_id"] for r in results], - redact_behaviour=EventRedactBehaviour.BLOCK, + redact_behaviour=EventRedactBehaviour.block, ) event_map = {ev.event_id: ev for ev in events} @@ -652,11 +652,11 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # We set redact_behaviour to block here to prevent redacted events being returned in # search results (which is a data leak) events = await self.get_events_as_list( # type: ignore[attr-defined] [r["event_id"] for r in results], - redact_behaviour=EventRedactBehaviour.BLOCK, + redact_behaviour=EventRedactBehaviour.block, ) event_map = {ev.event_id: ev for ev in events} diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index 95148fd227..05da15074a 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -48,7 +48,7 @@ class SignatureWorkerStore(EventsWorkerStore): """ events = await self.get_events( event_ids, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, allow_rejected=True, ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index e653841fe5..18ae8aee29 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections.abc import logging from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple -from frozendict import frozendict - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -160,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): predecessor = create_event.content.get("predecessor", None) # Ensure the key is a dictionary - if not isinstance(predecessor, (dict, frozendict)): + if not isinstance(predecessor, collections.abc.Mapping): return None # The keys must be strings since the data is JSON. @@ -370,10 +369,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _update_state_for_partial_state_event_txn( self, - txn, + txn: LoggingTransaction, event: EventBase, context: EventContext, - ): + ) -> None: # we shouldn't have any outliers here assert not event.internal_metadata.is_outlier() diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 2d339b6008..f38bedbbcd 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -131,7 +131,7 @@ class UIAuthWorkerStore(SQLBaseStore): session_id: str, stage_type: str, result: Union[str, bool, JsonDict], - ): + ) -> None: """ Mark a session stage as completed. @@ -200,7 +200,9 @@ class UIAuthWorkerStore(SQLBaseStore): desc="set_ui_auth_client_dict", ) - async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): + async def set_ui_auth_session_data( + self, session_id: str, key: str, value: Any + ) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by @@ -223,7 +225,7 @@ class UIAuthWorkerStore(SQLBaseStore): def _set_ui_auth_session_data_txn( self, txn: LoggingTransaction, session_id: str, key: str, value: Any - ): + ) -> None: # Get the current value. result = cast( Dict[str, Any], @@ -275,7 +277,7 @@ class UIAuthWorkerStore(SQLBaseStore): session_id: str, user_agent: str, ip: str, - ): + ) -> None: """Add the given user agent / IP to the tracking table""" await self.db_pool.simple_upsert( table="ui_auth_sessions_ips", @@ -318,7 +320,7 @@ class UIAuthWorkerStore(SQLBaseStore): def _delete_old_ui_auth_sessions_txn( self, txn: LoggingTransaction, expiration_time: int - ): + ) -> None: # Get the expired sessions. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" txn.execute(sql, [expiration_time]) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index e496ba7bed..97118045a1 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -943,7 +943,7 @@ class EventsPersistenceStorage: dropped_events = await self.main_store.get_events( dropped_extrems, allow_rejected=True, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, ) new_senders = {get_domain_from_id(e.sender) for e, _ in events_context} @@ -974,7 +974,7 @@ class EventsPersistenceStorage: prev_events = await self.main_store.get_events( new_events, allow_rejected=True, - redact_behaviour=EventRedactBehaviour.AS_IS, + redact_behaviour=EventRedactBehaviour.as_is, ) events_to_check = prev_events.values() diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e3153d1a4a..546d6bae6e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -501,11 +501,11 @@ def _upgrade_existing_database( if hasattr(module, "run_create"): logger.info("Running %s:run_create", relative_path) - module.run_create(cur, database_engine) # type: ignore + module.run_create(cur, database_engine) if not is_empty and hasattr(module, "run_upgrade"): logger.info("Running %s:run_upgrade", relative_path) - module.run_upgrade(cur, database_engine, config=config) # type: ignore + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc" or file_name == "__pycache__": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 0b9ac26b69..f6b3ee31e4 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -107,7 +107,7 @@ class TTLCache(Generic[KT, VT]): self._metrics.inc_hits() return e.value, e.expiry_time, e.ttl - def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore + def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: """Remove a value from the cache If key is in the cache, remove it and return its value, else return default. diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 9c405eb4d7..7223af1a36 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections.abc from typing import Any from frozendict import frozendict @@ -35,7 +36,7 @@ def freeze(o: Any) -> Any: def unfreeze(o: Any) -> Any: - if isinstance(o, (dict, frozendict)): + if isinstance(o, collections.abc.Mapping): return {k: unfreeze(v) for k, v in o.items()} if isinstance(o, (bytes, str)): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 3e05789923..d547df8a64 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) @@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "192.168.10.10" + request.getClientAddress.return_value.host = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) @@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "131.111.8.42" + request.getClientAddress.return_value.host = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() f = self.get_failure( @@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_device = simple_async_mock({"hidden": False}) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] @@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_device = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] @@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.store.insert_client_ip = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) @@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.store.insert_client_ip = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index ec8864dafe..268a48d7ba 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase): ) # mock up the response, and have the agent return it - self._mock_agent.request.return_value = defer.succeed( + self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( _mock_response( { "pdus": [ diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 91f982518e..6b26353d5e 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # Send the server a device list EDU for the other user, this will cause # it to try and resync the device lists. self.hs.get_federation_transport_client().query_user_devices.return_value = ( - defer.succeed( + make_awaitable( { "stream_id": "1", "user_id": "@user2:host2", diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 8c72cf6b30..5b0cd1ab86 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -411,6 +411,88 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) + def test_sending_read_receipt_batches_to_application_services(self): + """Tests that a large batch of read receipts are sent correctly to + interested application services. + """ + # Register an application service that's interested in a certain user + # and room prefix + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + ApplicationService.NS_ROOMS: [ + { + "regex": "!fakeroom_.*", + "exclusive": True, + } + ], + }, + ) + + # "Complete" a transaction. + # All this really does for us is make an entry in the application_services_state + # database table, which tracks the current stream_token per stream ID per AS. + self.get_success( + self.hs.get_datastores().main.complete_appservice_txn( + 0, + interested_appservice, + ) + ) + + # Now, pretend that we receive a large burst of read receipts (300 total) that + # all come in at once. + for i in range(300): + self.get_success( + # Insert a fake read receipt into the database + self.hs.get_datastores().main.insert_receipt( + # We have to use unique room ID + user ID combinations here, as the db query + # is an upsert. + room_id=f"!fakeroom_{i}:test", + receipt_type="m.read", + user_id=self.local_user, + event_ids=[f"$eventid_{i}"], + data={}, + ) + ) + + # Now notify the appservice handler that 300 read receipts have all arrived + # at once. What will it do! + # note: stream tokens start at 2 + for stream_token in range(2, 303): + self.get_success( + self.hs.get_application_service_handler()._notify_interested_services_ephemeral( + services=[interested_appservice], + stream_key="receipt_key", + new_token=stream_token, + users=[self.exclusive_as_user], + ) + ) + + # Using our txn send mock, we can see what the AS received. After iterating over every + # transaction, we'd like to see all 300 read receipts accounted for. + # No more, no less. + all_ephemeral_events = [] + for call in self.send_mock.call_args_list: + ephemeral_events = call[0][2] + all_ephemeral_events += ephemeral_events + + # Ensure that no duplicate events were sent + self.assertEqual(len(all_ephemeral_events), 300) + + # Check that the ephemeral event is a read receipt with the expected structure + latest_read_receipt = all_ephemeral_events[-1] + self.assertEqual(latest_read_receipt["type"], "m.receipt") + + event_id = list(latest_read_receipt["content"].keys())[0] + self.assertEqual( + latest_read_receipt["content"][event_id]["m.read"], {self.local_user: {}} + ) + @unittest.override_config( {"experimental_features": {"msc2409_to_device_messages_enabled": True}} ) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index a54aa29cf1..2b21547d0f 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -201,4 +201,16 @@ class CasHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) + mock = Mock( + spec=[ + "finish", + "getClientAddress", + "getHeader", + "setHeader", + "setResponseCode", + "write", + ] + ) + # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. + mock._disconnected = False + return mock diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8c74ed1fcf..1e6ad4b663 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -19,7 +19,6 @@ from unittest import mock from parameterized import parameterized from signedjson import key as key, sign as sign -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms @@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" self.hs.get_federation_client().query_client_keys = mock.Mock( - return_value=defer.succeed( + return_value=make_awaitable( { "device_keys": {remote_user_id: {}}, "master_keys": { @@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # Pretend we're sharing a room with the user we're querying. If not, # `_query_devices_for_destination` will return early. self.store.get_rooms_for_user = mock.Mock( - return_value=defer.succeed({"some_room_id"}) + return_value=make_awaitable({"some_room_id"}) ) remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" self.hs.get_federation_client().query_user_devices = mock.Mock( - return_value=defer.succeed( + return_value=make_awaitable( { "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 9684120c70..1231aed944 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -1300,7 +1300,7 @@ def _build_callback_request( "getCookie", "cookies", "requestHeaders", - "getClientIP", + "getClientAddress", "getHeader", ] ) @@ -1310,5 +1310,5 @@ def _build_callback_request( request.args = {} request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] - request.getClientIP.return_value = ip_address + request.getClientAddress.return_value.host = ip_address return request diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index d401fda938..82b3bb3b73 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -17,8 +17,6 @@ from typing import Any, Type, Union from unittest.mock import Mock -from twisted.internet import defer - import synapse from synapse.api.constants import LoginType from synapse.api.errors import Codes @@ -32,11 +30,9 @@ from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.unittest import override_config -# (possibly experimental) login flows we expect to appear in the list after the normal -# ones +# Login flows we expect to appear in the list after the normal ones. ADDITIONAL_LOGIN_FLOWS = [ {"type": "m.login.application_service"}, - {"type": "uk.half-shot.msc2778.login.application_service"}, ] # a mock instance which the dummy auth providers delegate to, so we can see what's going @@ -190,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) @@ -226,13 +222,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.get_success(module_api.register_user("u")) # log in twice, to get two devices - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) tok1 = self.login("u", "p") self.login("u", "p", device_id="dev2") mock_password_provider.reset_mock() # have the auth provider deny the request to start with - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) # make the initial request which returns a 401 session = self._start_delete_device_session(tok1, "dev2") @@ -246,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # Finally, check the request goes through when we allow it - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @@ -260,7 +256,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, 403, channel.result) @@ -277,7 +273,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # have the auth provider deny the request - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) # log in twice, to get two devices tok1 = self.login("localuser", "localpass") @@ -320,7 +316,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -342,7 +338,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # allow login via the auth provider - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) # log in twice, to get two devices tok1 = self.login("localuser", "p") @@ -359,7 +355,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_not_called() # now try deleting with the local password - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._authed_delete_device( tok1, "dev2", session, "localuser", "localpass" ) @@ -413,7 +409,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) channel = self._send_login("test.login_type", "u", test_field="y") @@ -427,7 +423,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # try a weird username. Again, it's unclear what we *expect* to happen # in these cases, but at least we can guard against the API changing # unexpectedly - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") @@ -477,7 +473,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) body["auth"]["test_field"] = "foo" @@ -490,7 +486,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # and finally, succeed - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) @@ -508,9 +504,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.custom_auth_provider_callback_test_body() def custom_auth_provider_callback_test_body(self): - callback = Mock(return_value=defer.succeed(None)) + callback = Mock(return_value=make_awaitable(None)) - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") @@ -646,7 +642,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5081b97573..0482a1ea34 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,7 +15,7 @@ from typing import List -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReceiptTypes from synapse.types import JsonDict from tests import unittest @@ -25,20 +25,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.receipt - # In the first param of _test_filters_hidden we use "hidden" instead of - # ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking - # the data from the database which doesn't use the prefix - - def test_filters_out_hidden_receipt(self): - self._test_filters_hidden( + def test_filters_out_private_receipt(self): + self._test_filters_private( [ { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, - "hidden": True, } } } @@ -50,58 +45,23 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [], ) - def test_does_not_filter_out_our_hidden_receipt(self): - self._test_filters_hidden( - [ - { - "content": { - "$1435641916hfgh4394fHBLK:matrix.org": { - "m.read": { - "@me:server.org": { - "ts": 1436451550453, - "hidden": True, - }, - } - } - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - } - ], - [ - { - "content": { - "$1435641916hfgh4394fHBLK:matrix.org": { - "m.read": { - "@me:server.org": { - "ts": 1436451550453, - ReadReceiptEventFields.MSC2285_HIDDEN: True, - }, - } - } - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - } - ], - ) - - def test_filters_out_hidden_receipt_and_ignores_rest(self): - self._test_filters_hidden( + def test_filters_out_private_receipt_and_ignores_rest(self): + self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, - "hidden": True, }, + }, + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, }, - } - } + }, + }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", "type": "m.receipt", @@ -111,7 +71,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -124,21 +84,20 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_filters_out_event_with_only_hidden_receipts_and_ignores_the_rest(self): - self._test_filters_hidden( + def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self): + self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, - "hidden": True, }, } }, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -153,7 +112,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -167,13 +126,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ) def test_handles_missing_content_of_m_read(self): - self._test_filters_hidden( + self._test_filters_private( [ { "content": { - "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}}, + "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -187,9 +146,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [ { "content": { - "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}}, + "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -203,13 +162,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ) def test_handles_empty_event(self): - self._test_filters_hidden( + self._test_filters_private( [ { "content": { "$143564gdfg6114394fHBLK:matrix.org": {}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -223,9 +182,8 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [ { "content": { - "$143564gdfg6114394fHBLK:matrix.org": {}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -238,16 +196,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_filters_out_receipt_event_with_only_hidden_receipt_and_ignores_rest(self): - self._test_filters_hidden( + def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self): + self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, - "hidden": True, }, } }, @@ -258,7 +215,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -273,7 +230,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -292,12 +249,12 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): Context: https://github.com/matrix-org/synapse/issues/10603 """ - self._test_filters_hidden( + self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@rikj:jki.re": "string", } }, @@ -306,12 +263,78 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): "type": "m.receipt", }, ], - [], + [ + { + "content": { + "$14356419edgd14394fHBLK:matrix.org": { + ReceiptTypes.READ: { + "@rikj:jki.re": "string", + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + }, + ], + ) + + def test_leaves_our_private_and_their_public(self): + self._test_filters_private( + [ + { + "content": { + "$1dgdgrd5641916114394fHBLK:matrix.org": { + ReceiptTypes.READ_PRIVATE: { + "@me:server.org": { + "ts": 1436451550453, + }, + }, + ReceiptTypes.READ: { + "@rikj:jki.re": { + "ts": 1436451550453, + }, + }, + "a.receipt.type": { + "@rikj:jki.re": { + "ts": 1436451550453, + }, + }, + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + [ + { + "content": { + "$1dgdgrd5641916114394fHBLK:matrix.org": { + ReceiptTypes.READ_PRIVATE: { + "@me:server.org": { + "ts": 1436451550453, + }, + }, + ReceiptTypes.READ: { + "@rikj:jki.re": { + "ts": 1436451550453, + }, + }, + "a.receipt.type": { + "@rikj:jki.re": { + "ts": 1436451550453, + }, + }, + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], ) - def _test_filters_hidden( + def _test_filters_private( self, events: List[JsonDict], expected_output: List[JsonDict] ): - """Tests that the _filter_out_hidden returns the expected output""" - filtered_events = self.event_source.filter_out_hidden(events, "@me:server.org") + """Tests that the _filter_out_private returns the expected output""" + filtered_events = self.event_source.filter_out_private(events, "@me:server.org") self.assertEqual(filtered_events, expected_output) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 45fd30cf43..b6ba19c739 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -193,8 +193,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - # Type ignore: mypy doesn't like us assigning to methods. - self.store.count_monthly_users = Mock( # type: ignore[assignment] + self.store.count_monthly_users = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception @@ -202,8 +201,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - # Type ignore: mypy doesn't like us assigning to methods. - self.store.get_monthly_active_count = Mock( # type: ignore[assignment] + self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) self.get_failure( @@ -211,8 +209,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - # Type ignore: mypy doesn't like us assigning to methods. - self.store.get_monthly_active_count = Mock( # type: ignore[assignment] + self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index d37292ce13..e74eb71774 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -1092,3 +1092,29 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): ) result = self.get_success(self.handler.get_room_summary(user2, self.room)) self.assertEqual(result.get("room_id"), self.room) + + def test_fed(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + fed_room = "#fed_room:" + fed_hostname + + requested_room_entry = _RoomEntry( + fed_room, + {"room_id": fed_room, "world_readable": True}, + ) + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return requested_room_entry, {}, set() + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_summary( + self.user, fed_room, remote_room_hosts=[fed_hostname] + ) + ) + self.assertEqual(result.get("room_id"), fed_room) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 8d4404eda1..a0f84e2940 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -349,4 +349,16 @@ class SamlHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) + mock = Mock( + spec=[ + "finish", + "getClientAddress", + "getHeader", + "setHeader", + "setResponseCode", + "write", + ] + ) + # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. + mock._disconnected = False + return mock diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ffd5c4cb93..5f2e26a5fc 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) - mock_keyring.verify_json_for_server.return_value = defer.succeed(True) + mock_keyring.verify_json_for_server.return_value = make_awaitable(True) # we mock out the federation client too mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) @@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( - return_value=defer.succeed(None) + return_value=make_awaitable(None) ) self.datastore.get_device_updates_by_remote = Mock( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c6e501c7be..96e2e3039b 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -15,7 +15,6 @@ from typing import Tuple from unittest.mock import Mock, patch from urllib.parse import quote -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,6 +29,7 @@ from synapse.util import Clock from tests import unittest from tests.storage.test_user_directory import GetUserDirectoryTables +from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config @@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): @@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store.register_user(user_id=r_user_id, password_hash=None) ) - mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py index bec018d9e7..89009bea8c 100644 --- a/tests/module_api/test_account_data_manager.py +++ b/tests/module_api/test_account_data_manager.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import SynapseError from synapse.rest import admin +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -22,7 +26,9 @@ class ModuleApiTestCase(HomeserverTestCase): admin.register_servlets, ] - def prepare(self, reactor, clock, homeserver) -> None: + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self._store = homeserver.get_datastores().main self._module_api = homeserver.get_module_api() self._account_data_mgr = self._module_api.account_data_manager @@ -91,7 +97,7 @@ class ModuleApiTestCase(HomeserverTestCase): ) with self.assertRaises(TypeError): # This throws an exception because it's a frozen dict. - the_data["wombat"] = False + the_data["wombat"] = False # type: ignore[index] def test_put_global(self) -> None: """ @@ -143,15 +149,14 @@ class ModuleApiTestCase(HomeserverTestCase): with self.assertRaises(TypeError): # The account data type must be a string. self.get_success_or_raise( - self._module_api.account_data_manager.put_global( - self.user_id, 42, {} # type: ignore - ) + self._module_api.account_data_manager.put_global(self.user_id, 42, {}) # type: ignore[arg-type] ) with self.assertRaises(TypeError): # The account data dict must be a dict. + # noinspection PyTypeChecker self.get_success_or_raise( self._module_api.account_data_manager.put_global( - self.user_id, "test.data", 42 # type: ignore + self.user_id, "test.data", 42 # type: ignore[arg-type] ) ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9fd5d59c55..8bc84aaaca 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes from synapse.events import EventBase from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState +from synapse.handlers.push_rules import InvalidRuleException from synapse.rest import admin -from synapse.rest.client import login, presence, profile, room +from synapse.rest.client import login, notifications, presence, profile, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence @@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase): room.register_servlets, presence.register_servlets, profile.register_servlets, + notifications.register_servlets, ] def prepare(self, reactor, clock, homeserver): @@ -553,6 +555,86 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertEqual(state[("org.matrix.test", "")].state_key, "") self.assertEqual(state[("org.matrix.test", "")].content, {}) + def test_set_push_rules_action(self) -> None: + """Test that a module can change the actions of an existing push rule for a user.""" + + # Create a room with 2 users in it. Push rules must not match if the user is the + # event's sender, so we need one user to send messages and one user to receive + # notifications. + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + + room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok) + + user_id2 = self.register_user("user2", "password") + tok2 = self.login("user2", "password") + self.helper.join(room_id, user_id2, tok=tok2) + + # Register a 3rd user and join them to the room, so that we don't accidentally + # trigger 1:1 push rules. + user_id3 = self.register_user("user3", "password") + tok3 = self.login("user3", "password") + self.helper.join(room_id, user_id3, tok=tok3) + + # Send a message as the second user and check that it notifies. + res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2) + event_id = res["event_id"] + + channel = self.make_request( + "GET", + "/notifications", + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + self.assertEqual( + channel.json_body["notifications"][0]["event"]["event_id"], + event_id, + channel.json_body, + ) + + # Change the .m.rule.message actions to not notify on new messages. + self.get_success( + defer.ensureDeferred( + self.module_api.set_push_rule_action( + user_id=user_id, + scope="global", + kind="underride", + rule_id=".m.rule.message", + actions=["dont_notify"], + ) + ) + ) + + # Send another message as the second user and check that the number of + # notifications didn't change. + self.helper.send(room_id=room_id, body="here's another message", tok=tok2) + + channel = self.make_request( + "GET", + "/notifications?from=", + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + + def test_check_push_rules_actions(self) -> None: + """Test that modules can check whether a list of push rules actions are spec + compliant. + """ + with self.assertRaises(InvalidRuleException): + self.module_api.check_push_rule_actions(["foo"]) + + with self.assertRaises(InvalidRuleException): + self.module_api.check_push_rule_actions({"foo": "bar"}) + + self.module_api.check_push_rule_actions(["notify"]) + + self.module_api.check_push_rule_actions( + [{"set_tweak": "sound", "value": "default"}] + ) + class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): """For testing ModuleApi functionality in a multi-worker setup""" diff --git a/tests/replication/_base.py b/tests/replication/_base.py index a0589b6d6a..a7602b4c96 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(port, 8765) # Set up client side protocol - client_protocol = client_factory.buildProtocol(None) + client_address = IPv4Address("TCP", "127.0.0.1", 1234) + client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234)) # Set up the server side protocol - channel = self.site.buildProtocol(None) + server_address = IPv4Address("TCP", host, port) + channel = self.site.buildProtocol((host, port)) # hook into the channel's request factory so that we can keep a record # of the requests @@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Connect client to server and vice versa. client_to_server_transport = FakeTransport( - channel, self.reactor, client_protocol + channel, self.reactor, client_protocol, server_address, client_address ) client_protocol.makeConnection(client_to_server_transport) server_to_client_transport = FakeTransport( - client_protocol, self.reactor, channel + client_protocol, self.reactor, channel, client_address, server_address ) channel.makeConnection(server_to_client_transport) @@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(port, repl_port) # Set up client side protocol - client_protocol = client_factory.buildProtocol(None) + client_address = IPv4Address("TCP", "127.0.0.1", 1234) + client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234)) # Set up the server side protocol - channel = self._hs_to_site[hs].buildProtocol(None) + server_address = IPv4Address("TCP", host, port) + channel = self._hs_to_site[hs].buildProtocol((host, port)) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( - channel, self.reactor, client_protocol + channel, self.reactor, client_protocol, server_address, client_address ) client_protocol.makeConnection(client_to_server_transport) server_to_client_transport = FakeTransport( - client_protocol, self.reactor, channel + client_protocol, self.reactor, channel, client_address, server_address ) channel.makeConnection(server_to_client_transport) diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index f47d94f690..5bbbd5fbcb 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,23 +12,248 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.types import UserID, create_requester + +from tests.test_utils.event_injection import create_event from ._base import BaseSlavedStoreTestCase -USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" +OTHER_USER_ID = "@other:test" +OUR_USER_ID = "@our:test" class SlavedReceiptTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedReceiptsStore - def test_receipt(self): - self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) + def prepare(self, reactor, clock, homeserver): + super().prepare(reactor, clock, homeserver) + self.room_creator = homeserver.get_room_creation_handler() + self.persist_event_storage = self.hs.get_storage().persistence + + # Create a test user + self.ourUser = UserID.from_string(OUR_USER_ID) + self.ourRequester = create_requester(self.ourUser) + + # Create a second test user + self.otherUser = UserID.from_string(OTHER_USER_ID) + self.otherRequester = create_requester(self.otherUser) + + # Create a test room + info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) + self.room_id1 = info["room_id"] + + # Create a second test room + info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) + self.room_id2 = info["room_id"] + + # Join the second user to the first room + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id1, + type="m.room.member", + sender=self.otherRequester.user.to_string(), + state_key=self.otherRequester.user.to_string(), + content={"membership": "join"}, + ) + ) self.get_success( - self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}) + self.persist_event_storage.persist_event(memberEvent, memberEventContext) + ) + + # Join the second user to the second room + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id2, + type="m.room.member", + sender=self.otherRequester.user.to_string(), + state_key=self.otherRequester.user.to_string(), + content={"membership": "join"}, + ) + ) + self.get_success( + self.persist_event_storage.persist_event(memberEvent, memberEventContext) + ) + + def test_return_empty_with_no_data(self): + res = self.get_success( + self.master_store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {}) + + res = self.get_success( + self.master_store.get_receipts_for_user_with_orderings( + OUR_USER_ID, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) + ) + self.assertEqual(res, {}) + + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id1, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) + ) + self.assertEqual(res, None) + + def test_get_receipts_for_user(self): + # Send some events into the first room + event1_1_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + event1_2_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + + # Send public read receipt for the first event + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + ) + ) + # Send private read receipt for the second event + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + ) + ) + + # Test we get the latest event when we want both private and public receipts + res = self.get_success( + self.master_store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Test we get the older event when we want only public receipt + res = self.get_success( + self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + ) + self.assertEqual(res, {self.room_id1: event1_1_id}) + + # Test we get the latest event when we want only the public receipt + res = self.get_success( + self.master_store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Test receipt updating + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + ) + ) + res = self.get_success( + self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Send some events into the second room + event2_1_id = self.create_and_send_event( + self.room_id2, UserID.from_string(OTHER_USER_ID) + ) + + # Test new room is reflected in what the method returns + self.get_success( + self.master_store.insert_receipt( + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + ) + ) + res = self.get_success( + self.master_store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) + + def test_get_last_receipt_event_id_for_user(self): + # Send some events into the first room + event1_1_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + event1_2_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + + # Send public read receipt for the first event + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + ) + ) + # Send private read receipt for the second event + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + ) + ) + + # Test we get the latest event when we want both private and public receipts + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id1, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) + ) + self.assertEqual(res, event1_2_id) + + # Test we get the older event when we want only public receipt + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] + ) + ) + self.assertEqual(res, event1_1_id) + + # Test we get the latest event when we want only the private receipt + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, event1_2_id) + + # Test receipt updating + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + ) + ) + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] + ) + ) + self.assertEqual(res, event1_2_id) + + # Send some events into the second room + event2_1_id = self.create_and_send_event( + self.room_id2, UserID.from_string(OTHER_USER_ID) + ) + + # Test new room is reflected in what the method returns + self.get_success( + self.master_store.insert_receipt( + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + ) + ) + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id2, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) ) - self.replicate() - self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}) + self.assertEqual(res, event2_1_id) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index ba1a63c0d6..6104a55aa1 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -102,8 +102,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() # type: ignore[attr-defined] - mock_client2.reset_mock() # type: ignore[attr-defined] + mock_client1.reset_mock() + mock_client2.reset_mock() self.create_and_send_event(room, UserID.from_string(user)) self.replicate() @@ -167,8 +167,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() # type: ignore[attr-defined] - mock_client2.reset_mock() # type: ignore[attr-defined] + mock_client1.reset_mock() + mock_client2.reset_mock() self.get_success( typing_handler.started_typing( diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e00b5c171c..e0a11da97b 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -520,8 +520,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # MSC3069 entered spec in Matrix 1.2 but maintained compatibility - "org.matrix.msc3069.is_guest": False, "is_guest": False, }, ) @@ -540,8 +538,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # MSC3069 entered spec in Matrix 1.2 but maintained compatibility - "org.matrix.msc3069.is_guest": True, "is_guest": True, }, ) @@ -564,8 +560,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): whoami, { "user_id": user_id, - # MSC3069 entered spec in Matrix 1.2 but maintained compatibility - "org.matrix.msc3069.is_guest": False, "is_guest": False, }, ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 0a3d017dc9..4920468f7a 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -81,11 +81,9 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')] -# (possibly experimental) login flows we expect to appear in the list after the normal -# ones +# Login flows we expect to appear in the list after the normal ones. ADDITIONAL_LOGIN_FLOWS = [ {"type": "m.login.application_service"}, - {"type": "uk.half-shot.msc2778.login.application_service"}, ] diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 0abe378fe4..b3738a0304 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -14,7 +14,6 @@ from http import HTTPStatus from unittest.mock import Mock -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler @@ -24,6 +23,7 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class PresenceTestCase(unittest.HomeserverTestCase): @@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = defer.succeed(None) + presence_handler.set_state.return_value = make_awaitable(None) hs = self.setup_test_homeserver( "red", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 39667e3225..27dee8f697 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -620,6 +620,19 @@ class RelationsTestCase(BaseRelationsTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + # Directly requesting the edit should not have the edit to the edit applied. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{edit_event_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual("Wibble", channel.json_body["content"]["body"]) + self.assertIn("m.new_content", channel.json_body["content"]) + + # The relations information should not include the edit to the edit. + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") @@ -984,6 +997,24 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + def test_annotation_to_annotation(self) -> None: + """Any relation to an annotation should be ignored.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + event_id = channel.json_body["event_id"] + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id + ) + + # Fetch the initial annotation event to see if it has bundled aggregations. + channel = self.make_request( + "GET", + f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + # The first annotationt should not have any bundled aggregations. + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_reference(self) -> None: """ Test that references get correctly bundled. @@ -1029,7 +1060,106 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations.get("latest_event"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10) + + def test_thread_with_bundled_aggregations_for_latest(self) -> None: + """ + Bundled aggregations should get applied to the latest thread event. + """ + self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2 + ) + + def assert_thread(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertTrue(bundled_aggregations.get("current_user_participated")) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + bundled_aggregations.get("latest_event"), + ) + # Check the unsigned field on the latest event. + self.assert_dict( + { + "m.relations": { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 1}, + ] + }, + } + }, + bundled_aggregations["latest_event"].get("unsigned"), + ) + + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10) + + def test_nested_thread(self) -> None: + """ + Ensure that a nested thread gets ignored by bundled aggregations, as + those are forbidden. + """ + + # Start a thread. + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_event_id = channel.json_body["event_id"] + + # Disable the validation to pretend this came over federation, since it is + # not an event the Client-Server API will allow.. + with patch( + "synapse.handlers.message.EventCreationHandler._validate_event_relation", + new=lambda self, event: make_awaitable(None), + ): + # Create a sub-thread off the thread, which is not allowed. + self._send_relation( + RelationTypes.THREAD, "m.room.test", parent_id=reply_event_id + ) + + # Fetch the thread root, to get the bundled aggregation for the thread. + relations_from_event = self._get_bundled_aggregations() + + # Ensure that requesting the room messages also does not return the sub-thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + event = self._find_event_in_chunk(channel.json_body["chunk"]) + relations_from_messages = event["unsigned"]["m.relations"] + + # Check the bundled aggregations from each point. + for aggregations, desc in ( + (relations_from_event, "/event"), + (relations_from_messages, "/messages"), + ): + # The latest event should have bundled aggregations. + self.assertIn(RelationTypes.THREAD, aggregations, desc) + thread_summary = aggregations[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary, desc) + self.assertEqual( + thread_summary["latest_event"]["event_id"], reply_event_id, desc + ) + + # The latest event should not have any bundled aggregations (since the + # only relation to it is another thread, which is invalid). + self.assertNotIn( + "m.relations", thread_summary["latest_event"]["unsigned"], desc + ) def test_thread_edit_latest_event(self) -> None: """Test that editing the latest event in a thread works.""" @@ -1049,6 +1179,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=threaded_event_id, ) + edit_event_id = channel.json_body["event_id"] # Fetch the thread root, to get the bundled aggregation for the thread. relations_dict = self._get_bundled_aggregations() @@ -1061,6 +1192,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertIn("latest_event", thread_summary) latest_event_in_thread = thread_summary["latest_event"] self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") + # The latest event in the thread should have the edit appear under the + # bundled aggregations. + self.assertDictContainsSubset( + {"event_id": edit_event_id, "sender": "@alice:test"}, + latest_event_in_thread["unsigned"]["m.relations"][RelationTypes.REPLACE], + ) def test_aggregation_get_event_for_annotation(self) -> None: """Test that annotations do not get bundled aggregations included diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 6ff79b9e2e..9443daa056 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined] - {} - ) + self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] search_filter = {"generic_search_term": "foobar"} @@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(404, "Not Found", b""), - defer.succeed({}), + make_awaitable({}), ) search_filter = {"generic_search_term": "foobar"} diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 773c16a54c..0108337649 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -23,7 +23,7 @@ import synapse.rest.admin from synapse.api.constants import ( EventContentFields, EventTypes, - ReadReceiptEventFields, + ReceiptTypes, RelationTypes, ) from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync @@ -346,7 +346,7 @@ class SyncKnockTestCase( # Knock on a room channel = self.make_request( "POST", - "/_matrix/client/r0/knock/%s" % (self.room_id,), + f"/_matrix/client/r0/knock/{self.room_id}", b"{}", self.knocker_tok, ) @@ -407,22 +407,83 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self) -> None: + def test_private_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) - # Send a read receipt to tell the server the first user's message was read - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, + f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + {}, access_token=self.tok2, ) self.assertEqual(channel.code, 200) - # Test that the first user can't see the other user's hidden read receipt - self.assertEqual(self._get_read_receipt(), None) + # Test that the first user can't see the other user's private read receipt + self.assertIsNone(self._get_read_receipt()) + + @override_config({"experimental_features": {"msc2285_enabled": True}}) + def test_public_receipt_can_override_private(self) -> None: + """ + Sending a public read receipt to the same event which has a private read + receipt should cause that receipt to become public. + """ + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a private read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertIsNone(self._get_read_receipt()) + + # Send a public read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that we did override the private read receipt + self.assertNotEqual(self._get_read_receipt(), None) + + @override_config({"experimental_features": {"msc2285_enabled": True}}) + def test_private_receipt_cannot_override_public(self) -> None: + """ + Sending a private read receipt to the same event which has a public read + receipt should cause no change. + """ + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a public read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertNotEqual(self._get_read_receipt(), None) + + # Send a private read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that we didn't override the public read receipt + self.assertIsNone(self._get_read_receipt()) @parameterized.expand( [ @@ -454,7 +515,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a read receipt for this message with an empty body channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", access_token=self.tok2, custom_headers=[("User-Agent", user_agent)], ) @@ -478,6 +539,9 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Store the next batch for the next request. self.next_batch = channel.json_body["next_batch"] + if channel.json_body.get("rooms", None) is None: + return None + # Return the read receipt ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ "ephemeral" @@ -498,7 +562,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() - config["experimental_features"] = {"msc2654_enabled": True} + config["experimental_features"] = { + "msc2654_enabled": True, + "msc2285_enabled": True, + } return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -560,10 +627,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8") channel = self.make_request( "POST", - "/rooms/%s/read_markers" % self.room_id, + f"/rooms/{self.room_id}/read_markers", body, access_token=self.tok, ) @@ -572,16 +639,15 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Check that the unread counter is back to 0. self._check_unread_count(0) - # Check that hidden read receipts don't break unread counts + # Check that private read receipts don't break unread counts res = self.helper.send(self.room_id, "hello", tok=self.tok2) self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, + f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + {}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -643,13 +709,73 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(4) # Check that tombstone events changes increase the unread counter. - self.helper.send_state( + res1 = self.helper.send_state( self.room_id, EventTypes.Tombstone, {"replacement_room": "!someroom:test"}, tok=self.tok2, ) self._check_unread_count(5) + res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) + + # Make sure both m.read and org.matrix.msc2285.read.private advance + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(1) + + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(0) + + # We test for both receipt types that influence notification counts + @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]) + def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None: + # Join the new user + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + + # Send messages + res1 = self.helper.send(self.room_id, "hello", tok=self.tok2) + res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) + + # Read last event + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(0) + + # Make sure neither m.read nor org.matrix.msc2285.read.private make the + # read receipt go up to an older event + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(0) + + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(0) def _check_unread_count(self, expected_count: int) -> None: """Syncs and compares the unread count with the expected value.""" @@ -662,9 +788,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) - room_entry = channel.json_body["rooms"]["join"][self.room_id] + room_entry = ( + channel.json_body.get("rooms", {}).get("join", {}).get(self.room_id, {}) + ) self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], + room_entry.get("org.matrix.msc2654.unread_count", 0), expected_count, room_entry, ) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 8d8251b2ac..21a1ca2a68 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionC from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import MockClock @@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_executes_given_function(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" ) @@ -47,7 +48,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_deduplicates_based_on_key(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg", changing_args=i @@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cleans_up(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) diff --git a/tests/server.py b/tests/server.py index 16559d2588..8f30e250c8 100644 --- a/tests/server.py +++ b/tests/server.py @@ -181,7 +181,7 @@ class FakeChannel: self.resource_usage = _self.logcontext.get_resource_usage() def getPeer(self): - # We give an address so that getClientIP returns a non null entry, + # We give an address so that getClientAddress/getClientIP returns a non null entry, # causing us to record the MAU return address.IPv4Address("TCP", self._ip, 3423) @@ -562,7 +562,10 @@ class FakeTransport: """ _peer_address: Optional[IAddress] = attr.ib(default=None) - """The value to be returend by getPeer""" + """The value to be returned by getPeer""" + + _host_address: Optional[IAddress] = attr.ib(default=None) + """The value to be returned by getHost""" disconnecting = False disconnected = False @@ -571,11 +574,11 @@ class FakeTransport: producer = attr.ib(default=None) autoflush = attr.ib(default=True) - def getPeer(self): + def getPeer(self) -> Optional[IAddress]: return self._peer_address - def getHost(self): - return None + def getHost(self) -> Optional[IAddress]: + return self._host_address def loseConnection(self, reason=None): if not self.disconnecting: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 02b96c9e6e..9ee9509d3a 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -14,8 +14,6 @@ from unittest.mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin @@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( - return_value=defer.succeed(Mock()) + return_value=make_awaitable(Mock()) ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.user_id = "@user_id:test" self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - return_value=defer.succeed("!something:localhost") + return_value=make_awaitable("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) + self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @override_config({"hs_disabled": True}) @@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) @@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") + return_value=make_awaitable(None), + side_effect=ResourceLimitError(403, "foo"), ) mock_event = Mock( @@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user does not have blocked notice, but should have one """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") + return_value=make_awaitable(None), + side_effect=ResourceLimitError(403, "foo"), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=make_awaitable(None) ) @@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): an alert message is not sent into the room """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), @@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test that when a server is disabled, that MAU limit alerting is ignored. """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED ), @@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): is suppressed that the room is returned to an unblocked state. """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( - return_value=defer.succeed((True, [])) + return_value=make_awaitable((True, [])) ) mock_event = Mock( diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index bf6374f93d..c237a8c7e2 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -13,7 +13,7 @@ # limitations under the License. import json from contextlib import contextmanager -from typing import Generator, Tuple +from typing import Generator, List, Tuple from unittest import mock from twisted.enterprise.adbapi import ConnectionPool @@ -21,6 +21,7 @@ from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import EventFormatVersions, RoomVersions +from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room @@ -49,23 +50,28 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): ) ) - for idx, (rid, eid) in enumerate( + self.event_ids: List[str] = [] + for idx, rid in enumerate( ( - ("room1", "event10"), - ("room1", "event11"), - ("room1", "event12"), - ("room2", "event20"), + "room1", + "room1", + "room1", + "room2", ) ): + event_json = {"type": f"test {idx}", "room_id": rid} + event = make_event_from_dict(event_json, room_version=RoomVersions.V4) + event_id = event.event_id + self.get_success( self.store.db_pool.simple_insert( "events", { - "event_id": eid, + "event_id": event_id, "room_id": rid, "topological_ordering": idx, "stream_ordering": idx, - "type": "test", + "type": event.type, "processed": True, "outlier": False, }, @@ -75,21 +81,22 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.store.db_pool.simple_insert( "event_json", { - "event_id": eid, + "event_id": event_id, "room_id": rid, - "json": json.dumps({"type": "test", "room_id": rid}), + "json": json.dumps(event_json), "internal_metadata": "{}", "format_version": 3, }, ) ) + self.event_ids.append(event_id) def test_simple(self): with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", ["event10", "event19"]) + self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) ) - self.assertEqual(res, {"event10"}) + self.assertEqual(res, {self.event_ids[0]}) # that should result in a single db query self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) @@ -97,19 +104,21 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", ["event10", "event19"]) + self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) ) - self.assertEqual(res, {"event10"}) + self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) def test_query_via_event_cache(self): # fetch an event into the event cache - self.get_success(self.store.get_event("event10")) + self.get_success(self.store.get_event(self.event_ids[0])) # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: - res = self.get_success(self.store.have_seen_events("room1", ["event10"])) - self.assertEqual(res, {"event10"}) + res = self.get_success( + self.store.have_seen_events("room1", [self.event_ids[0]]) + ) + self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) @@ -167,7 +176,6 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" - self.event_ids = [f"event{i}" for i in range(20)] self._populate_events() @@ -190,8 +198,14 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): ) ) - self.event_ids = [f"event{i}" for i in range(20)] - for idx, event_id in enumerate(self.event_ids): + self.event_ids: List[str] = [] + for idx in range(20): + event_json = { + "type": f"test {idx}", + "room_id": self.room_id, + } + event = make_event_from_dict(event_json, room_version=RoomVersions.V4) + event_id = event.event_id self.get_success( self.store.db_pool.simple_upsert( "events", @@ -201,7 +215,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): "room_id": self.room_id, "topological_ordering": idx, "stream_ordering": idx, - "type": "test", + "type": event.type, "processed": True, "outlier": False, }, @@ -213,12 +227,13 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): {"event_id": event_id}, { "room_id": self.room_id, - "json": json.dumps({"type": "test", "room_id": self.room_id}), + "json": json.dumps(event_json), "internal_metadata": "{}", "format_version": EventFormatVersions.V3, }, ) ) + self.event_ids.append(event_id) @contextmanager def _outage(self) -> Generator[None, None, None]: diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 60c8d37594..0fbf465670 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -14,7 +14,6 @@ from typing import Any, Dict, List from unittest.mock import Mock -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes @@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(None) + return_value=make_awaitable(None) ) d = self.store.populate_monthly_active_users("user_id") self.get_success(d) @@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_not_update(self): self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) d = self.store.populate_monthly_active_users("user_id") diff --git a/tests/test_federation.py b/tests/test_federation.py index c39816de85..0cbef70bfa 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. federation_client = self.homeserver.get_federation_client() federation_client.query_user_devices = Mock( - return_value=succeed( + return_value=make_awaitable( { "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/test_mau.py b/tests/test_mau.py index 46bd3075de..5bbc361aa2 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -14,6 +14,8 @@ """Tests REST events for /rooms paths.""" +from typing import List + from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService @@ -229,6 +231,78 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) + @override_config( + { + "mau_trial_days": 3, + "mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2}, + } + ) + def test_as_trial_days(self): + user_tokens: List[str] = [] + + def advance_time_and_sync(): + self.reactor.advance(24 * 60 * 61) + for token in user_tokens: + self.do_sync_for_user(token) + + # Cheekily add an application service that we use to register a new user + # with. + as_token_1 = "foobartoken1" + self.store.services_cache.append( + ApplicationService( + token=as_token_1, + hostname=self.hs.hostname, + id="SomeASID", + sender="@as_sender_1:test", + namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]}, + ) + ) + + as_token_2 = "foobartoken2" + self.store.services_cache.append( + ApplicationService( + token=as_token_2, + hostname=self.hs.hostname, + id="AnotherASID", + sender="@as_sender_2:test", + namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]}, + ) + ) + + user_tokens.append(self.create_user("kermit1")) + user_tokens.append(self.create_user("kermit2")) + user_tokens.append( + self.create_user("as_1kermit3", token=as_token_1, appservice=True) + ) + user_tokens.append( + self.create_user("as_2kermit4", token=as_token_2, appservice=True) + ) + + # Advance time by 1 day to include the first appservice + advance_time_and_sync() + self.assertEqual( + self.get_success(self.store.get_monthly_active_count_by_service()), + {"SomeASID": 1}, + ) + + # Advance time by 1 day to include the next appservice + advance_time_and_sync() + self.assertEqual( + self.get_success(self.store.get_monthly_active_count_by_service()), + {"SomeASID": 1, "AnotherASID": 1}, + ) + + # Advance time by 1 day to include the native users + advance_time_and_sync() + self.assertEqual( + self.get_success(self.store.get_monthly_active_count_by_service()), + { + "SomeASID": 1, + "AnotherASID": 1, + "native": 2, + }, + ) + def create_user(self, localpart, token=None, appservice=False): request_data = { "username": localpart, diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index f05a373aa0..0d0d6faf0d 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -52,7 +52,7 @@ def make_awaitable(result: TV) -> Awaitable[TV]: This uses Futures as they can be awaited multiple times so can be returned to multiple callers. """ - future = Future() # type: ignore + future: Future[TV] = Future() future.set_result(result) return future @@ -69,7 +69,7 @@ def setup_awaitable_errors() -> Callable[[], None]: # State shared between unraisablehook and check_for_unraisable_exceptions. unraisable_exceptions = [] - orig_unraisablehook = sys.unraisablehook # type: ignore + orig_unraisablehook = sys.unraisablehook def unraisablehook(unraisable): unraisable_exceptions.append(unraisable.exc_value) @@ -78,11 +78,11 @@ def setup_awaitable_errors() -> Callable[[], None]: """ A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. """ - sys.unraisablehook = orig_unraisablehook # type: ignore + sys.unraisablehook = orig_unraisablehook if unraisable_exceptions: raise unraisable_exceptions.pop() - sys.unraisablehook = unraisablehook # type: ignore + sys.unraisablehook = unraisablehook return cleanup diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 51a197a8c6..9228454c9e 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") - self.tx_log.emit( # type: ignore + self.tx_log.emit( twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry ) |