diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh
index 3753f41a40..9270d55f04 100755
--- a/.buildkite/scripts/test_old_deps.sh
+++ b/.buildkite/scripts/test_old_deps.sh
@@ -1,6 +1,6 @@
#!/usr/bin/env bash
-# this script is run by buildkite in a plain `xenial` container; it installs the
+# this script is run by buildkite in a plain `bionic` container; it installs the
# minimal requirements for tox and hands over to the py3-old tox environment.
set -ex
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000000..12c82ac620
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,322 @@
+name: Tests
+
+on:
+ push:
+ branches: ["develop", "release-*"]
+ pull_request:
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ toxenv:
+ - "check-sampleconfig"
+ - "check_codestyle"
+ - "check_isort"
+ - "mypy"
+ - "packaging"
+
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/setup-python@v2
+ - run: pip install tox
+ - run: tox -e ${{ matrix.toxenv }}
+
+ lint-crlf:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - name: Check line endings
+ run: scripts-dev/check_line_terminators.sh
+
+ lint-newsfile:
+ if: ${{ github.base_ref == 'develop' || contains(github.base_ref, 'release-') }}
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/setup-python@v2
+ - run: pip install tox
+ - name: Patch Buildkite-specific test script
+ run: |
+ sed -i -e 's/\$BUILDKITE_PULL_REQUEST/${{ github.event.number }}/' \
+ scripts-dev/check-newsfragment
+ - run: scripts-dev/check-newsfragment
+
+ lint-sdist:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/setup-python@v2
+ with:
+ python-version: "3.x"
+ - run: pip install wheel
+ - run: python setup.py sdist bdist_wheel
+ - uses: actions/upload-artifact@v2
+ with:
+ name: Python Distributions
+ path: dist/*
+
+ # Dummy step to gate other tests on without repeating the whole list
+ linting-done:
+ if: ${{ always() }} # Run this even if prior jobs were skipped
+ needs: [lint, lint-crlf, lint-newsfile, lint-sdist]
+ runs-on: ubuntu-latest
+ steps:
+ - run: "true"
+
+ trial:
+ if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail
+ needs: linting-done
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.6", "3.7", "3.8", "3.9"]
+ database: ["sqlite"]
+ include:
+ # Newest Python without optional deps
+ - python-version: "3.9"
+ toxenv: "py-noextras,combine"
+
+ # Oldest Python with PostgreSQL
+ - python-version: "3.6"
+ database: "postgres"
+ postgres-version: "9.6"
+
+ # Newest Python with PostgreSQL
+ - python-version: "3.9"
+ database: "postgres"
+ postgres-version: "13"
+
+ steps:
+ - uses: actions/checkout@v2
+ - run: sudo apt-get -qq install xmlsec1
+ - name: Set up PostgreSQL ${{ matrix.postgres-version }}
+ if: ${{ matrix.postgres-version }}
+ run: |
+ docker run -d -p 5432:5432 \
+ -e POSTGRES_PASSWORD=postgres \
+ -e POSTGRES_INITDB_ARGS="--lc-collate C --lc-ctype C --encoding UTF8" \
+ postgres:${{ matrix.postgres-version }}
+ - uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - run: pip install tox
+ - name: Await PostgreSQL
+ if: ${{ matrix.postgres-version }}
+ timeout-minutes: 2
+ run: until pg_isready -h localhost; do sleep 1; done
+ - run: tox -e py,combine
+ env:
+ TRIAL_FLAGS: "--jobs=2"
+ SYNAPSE_POSTGRES: ${{ matrix.database == 'postgres' || '' }}
+ SYNAPSE_POSTGRES_HOST: localhost
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ - name: Dump logs
+ # Note: Dumps to workflow logs instead of using actions/upload-artifact
+ # This keeps logs colocated with failing jobs
+ # It also ignores find's exit code; this is a best effort affair
+ run: >-
+ find _trial_temp -name '*.log'
+ -exec echo "::group::{}" \;
+ -exec cat {} \;
+ -exec echo "::endgroup::" \;
+ || true
+
+ trial-olddeps:
+ if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail
+ needs: linting-done
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - name: Test with old deps
+ uses: docker://ubuntu:bionic # For old python and sqlite
+ with:
+ workdir: /github/workspace
+ entrypoint: .buildkite/scripts/test_old_deps.sh
+ env:
+ TRIAL_FLAGS: "--jobs=2"
+ - name: Dump logs
+ # Note: Dumps to workflow logs instead of using actions/upload-artifact
+ # This keeps logs colocated with failing jobs
+ # It also ignores find's exit code; this is a best effort affair
+ run: >-
+ find _trial_temp -name '*.log'
+ -exec echo "::group::{}" \;
+ -exec cat {} \;
+ -exec echo "::endgroup::" \;
+ || true
+
+ trial-pypy:
+ # Very slow; only run if the branch name includes 'pypy'
+ if: ${{ contains(github.ref, 'pypy') && !failure() }}
+ needs: linting-done
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["pypy-3.6"]
+
+ steps:
+ - uses: actions/checkout@v2
+ - run: sudo apt-get -qq install xmlsec1 libxml2-dev libxslt-dev
+ - uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - run: pip install tox
+ - run: tox -e py,combine
+ env:
+ TRIAL_FLAGS: "--jobs=2"
+ - name: Dump logs
+ # Note: Dumps to workflow logs instead of using actions/upload-artifact
+ # This keeps logs colocated with failing jobs
+ # It also ignores find's exit code; this is a best effort affair
+ run: >-
+ find _trial_temp -name '*.log'
+ -exec echo "::group::{}" \;
+ -exec cat {} \;
+ -exec echo "::endgroup::" \;
+ || true
+
+ sytest:
+ if: ${{ !failure() }}
+ needs: linting-done
+ runs-on: ubuntu-latest
+ container:
+ image: matrixdotorg/sytest-synapse:${{ matrix.sytest-tag }}
+ volumes:
+ - ${{ github.workspace }}:/src
+ env:
+ BUILDKITE_BRANCH: ${{ github.head_ref }}
+      POSTGRES: ${{ matrix.postgres && 1 }}
+      MULTI_POSTGRES: ${{ (matrix.postgres == 'multi-postgres') && 1 }}
+ WORKERS: ${{ matrix.workers && 1 }}
+ REDIS: ${{ matrix.redis && 1 }}
+ BLACKLIST: ${{ matrix.workers && 'synapse-blacklist-with-workers' }}
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - sytest-tag: bionic
+
+ - sytest-tag: bionic
+ postgres: postgres
+
+ - sytest-tag: testing
+ postgres: postgres
+
+ - sytest-tag: bionic
+ postgres: multi-postgres
+ workers: workers
+
+ - sytest-tag: buster
+ postgres: multi-postgres
+ workers: workers
+
+ - sytest-tag: buster
+ postgres: postgres
+ workers: workers
+ redis: redis
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Prepare test blacklist
+ run: cat sytest-blacklist .buildkite/worker-blacklist > synapse-blacklist-with-workers
+ - name: Run SyTest
+ run: /bootstrap.sh synapse
+ working-directory: /src
+ - name: Dump results.tap
+ if: ${{ always() }}
+ run: cat /logs/results.tap
+ - name: Upload SyTest logs
+ uses: actions/upload-artifact@v2
+ if: ${{ always() }}
+ with:
+ name: Sytest Logs - ${{ job.status }} - (${{ join(matrix.*, ', ') }})
+ path: |
+ /logs/results.tap
+ /logs/**/*.log*
+
+ portdb:
+ if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail
+ needs: linting-done
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ include:
+ - python-version: "3.6"
+ postgres-version: "9.6"
+
+ - python-version: "3.9"
+ postgres-version: "13"
+
+ services:
+ postgres:
+ image: postgres:${{ matrix.postgres-version }}
+ ports:
+ - 5432:5432
+ env:
+ POSTGRES_PASSWORD: "postgres"
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+
+ steps:
+ - uses: actions/checkout@v2
+ - run: sudo apt-get -qq install xmlsec1
+ - uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Patch Buildkite-specific test scripts
+ run: |
+ sed -i -e 's/host="postgres"/host="localhost"/' .buildkite/scripts/create_postgres_db.py
+ sed -i -e 's/host: postgres/host: localhost/' .buildkite/postgres-config.yaml
+ sed -i -e 's|/src/||' .buildkite/{sqlite,postgres}-config.yaml
+ sed -i -e 's/\$TOP/\$GITHUB_WORKSPACE/' .coveragerc
+ - run: .buildkite/scripts/test_synapse_port_db.sh
+
+ complement:
+ if: ${{ !failure() }}
+ needs: linting-done
+ runs-on: ubuntu-latest
+ container:
+ # https://github.com/matrix-org/complement/blob/master/dockerfiles/ComplementCIBuildkite.Dockerfile
+ image: matrixdotorg/complement:latest
+ env:
+ CI: true
+ ports:
+ - 8448:8448
+ volumes:
+ - /var/run/docker.sock:/var/run/docker.sock
+
+ steps:
+ - name: Run actions/checkout@v2 for synapse
+ uses: actions/checkout@v2
+ with:
+ path: synapse
+
+ - name: Run actions/checkout@v2 for complement
+ uses: actions/checkout@v2
+ with:
+ repository: "matrix-org/complement"
+ path: complement
+
+ # Build initial Synapse image
+ - run: docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
+ working-directory: synapse
+
+ # Build a ready-to-run Synapse image based on the initial image above.
+ # This new image includes a config file, keys for signing and TLS, and
+ # other settings to make it suitable for testing under Complement.
+ - run: docker build -t complement-synapse -f Synapse.Dockerfile .
+ working-directory: complement/dockerfiles
+
+ # Run Complement
+ - run: go test -v -tags synapse_blacklist ./tests
+ env:
+ COMPLEMENT_BASE_IMAGE: complement-synapse:latest
+ working-directory: complement
diff --git a/CHANGES.md b/CHANGES.md
index 27483532d0..482863c0e8 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,84 @@
+Synapse 1.32.0 (2021-04-20)
+===========================
+
+**Note:** This release requires Python 3.6+ and Postgres 9.6+ or SQLite 3.22+.
+
+This release removes the deprecated `GET /_synapse/admin/v1/users/<user_id>` admin API. Please use the [v2 API](https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/user_admin_api.rst#query-user-account) instead, which has improved capabilities.
+
+This release requires Application Services to use type `m.login.application_service` when registering users via the `/_matrix/client/r0/register` endpoint to comply with the spec. Please ensure your Application Services are up to date.
+
+Bugfixes
+--------
+
+- Fix the log lines of nested logging contexts. Broke in 1.32.0rc1. ([\#9829](https://github.com/matrix-org/synapse/issues/9829))
+
+
+Synapse 1.32.0rc1 (2021-04-13)
+==============================
+
+Features
+--------
+
+- Add a Synapse module for routing presence updates between users. ([\#9491](https://github.com/matrix-org/synapse/issues/9491))
+- Add an admin API to manage ratelimit for a specific user. ([\#9648](https://github.com/matrix-org/synapse/issues/9648))
+- Include request information in structured logging output. ([\#9654](https://github.com/matrix-org/synapse/issues/9654))
+- Add `order_by` to the admin API `GET /_synapse/admin/v2/users`. Contributed by @dklimpel. ([\#9691](https://github.com/matrix-org/synapse/issues/9691))
+- Replace the `room_invite_state_types` configuration setting with `room_prejoin_state`. ([\#9700](https://github.com/matrix-org/synapse/issues/9700))
+- Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. ([\#9717](https://github.com/matrix-org/synapse/issues/9717), [\#9735](https://github.com/matrix-org/synapse/issues/9735))
+- Update experimental support for Spaces: include `m.room.create` in the room state sent with room-invites. ([\#9710](https://github.com/matrix-org/synapse/issues/9710))
+- Synapse now requires Python 3.6 or later. It also requires Postgres 9.6 or later or SQLite 3.22 or later. ([\#9766](https://github.com/matrix-org/synapse/issues/9766))
+
+
+Bugfixes
+--------
+
+- Prevent `synapse_forward_extremities` and `synapse_excess_extremity_events` Prometheus metrics from initially reporting zero-values after startup. ([\#8926](https://github.com/matrix-org/synapse/issues/8926))
+- Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. ([\#9711](https://github.com/matrix-org/synapse/issues/9711))
+- Fix longstanding bug which caused `duplicate key value violates unique constraint "remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"` errors. ([\#9725](https://github.com/matrix-org/synapse/issues/9725))
+- Fix bug where sharded federation senders could get stuck repeatedly querying the DB in a loop, using lots of CPU. ([\#9770](https://github.com/matrix-org/synapse/issues/9770))
+- Fix duplicate logging of exceptions thrown during federation transaction processing. ([\#9780](https://github.com/matrix-org/synapse/issues/9780))
+
+
+Updates to the Docker image
+---------------------------
+
+- Move opencontainers labels to the final Docker image such that users can inspect them. ([\#9765](https://github.com/matrix-org/synapse/issues/9765))
+
+
+Improved Documentation
+----------------------
+
+- Make the `allowed_local_3pids` regex example in the sample config stricter. ([\#9719](https://github.com/matrix-org/synapse/issues/9719))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove old admin API `GET /_synapse/admin/v1/users/<user_id>`. ([\#9401](https://github.com/matrix-org/synapse/issues/9401))
+- Make `/_matrix/client/r0/register` expect a type of `m.login.application_service` when an Application Service registers a user, to align with [the relevant spec](https://spec.matrix.org/unstable/application-service-api/#server-admin-style-permissions). ([\#9548](https://github.com/matrix-org/synapse/issues/9548))
+
+
+Internal Changes
+----------------
+
+- Replace deprecated `imp` module with successor `importlib`. Contributed by Cristina Muñoz. ([\#9718](https://github.com/matrix-org/synapse/issues/9718))
+- Experiment with GitHub Actions for CI. ([\#9661](https://github.com/matrix-org/synapse/issues/9661))
+- Introduce flake8-bugbear to the test suite and fix some of its lint violations. ([\#9682](https://github.com/matrix-org/synapse/issues/9682))
+- Update `scripts-dev/complement.sh` to use a local checkout of Complement, allow running a subset of tests and have it use Synapse's Complement test blacklist. ([\#9685](https://github.com/matrix-org/synapse/issues/9685))
+- Improve Jaeger tracing for `to_device` messages. ([\#9686](https://github.com/matrix-org/synapse/issues/9686))
+- Add release helper script for automating part of the Synapse release process. ([\#9713](https://github.com/matrix-org/synapse/issues/9713))
+- Add type hints to expiring cache. ([\#9730](https://github.com/matrix-org/synapse/issues/9730))
+- Convert various testcases to `HomeserverTestCase`. ([\#9736](https://github.com/matrix-org/synapse/issues/9736))
+- Start linting mypy with `no_implicit_optional`. ([\#9742](https://github.com/matrix-org/synapse/issues/9742))
+- Add missing type hints to federation handler and server. ([\#9743](https://github.com/matrix-org/synapse/issues/9743))
+- Check that a `ConfigError` is raised, rather than simply `Exception`, when appropriate in homeserver config file generation tests. ([\#9753](https://github.com/matrix-org/synapse/issues/9753))
+- Fix incompatibility with `tox` 2.5. ([\#9769](https://github.com/matrix-org/synapse/issues/9769))
+- Enable Complement tests for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946): Spaces Summary API. ([\#9771](https://github.com/matrix-org/synapse/issues/9771))
+- Use mock from the standard library instead of a separate package. ([\#9772](https://github.com/matrix-org/synapse/issues/9772))
+- Update Black configuration to target Python 3.6. ([\#9781](https://github.com/matrix-org/synapse/issues/9781))
+- Add option to skip unit tests when building Debian packages. ([\#9793](https://github.com/matrix-org/synapse/issues/9793))
+
+
Synapse 1.31.0 (2021-04-06)
===========================
diff --git a/README.rst b/README.rst
index 655a2bf3be..1a5503572e 100644
--- a/README.rst
+++ b/README.rst
@@ -393,7 +393,12 @@ massive excess of outgoing federation requests (see `discussion
indicate that your server is also issuing far more outgoing federation
requests than can be accounted for by your users' activity, this is a
likely cause. The misbehavior can be worked around by setting
-``use_presence: false`` in the Synapse config file.
+the following in the Synapse config file:
+
+.. code-block:: yaml
+
+ presence:
+ enabled: false
People can't accept room invitations from me
--------------------------------------------
diff --git a/UPGRADE.rst b/UPGRADE.rst
index ba488e1041..7a9b869055 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -85,6 +85,37 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.32.0
+====================
+
+Dropping support for old Python, Postgres and SQLite versions
+-------------------------------------------------------------
+
+In line with our `deprecation policy <https://github.com/matrix-org/synapse/blob/release-v1.32.0/docs/deprecation_policy.md>`_,
+we've dropped support for Python 3.5 and PostgreSQL 9.5, as they are no longer supported upstream.
+
+This release of Synapse requires Python 3.6+ and PostgreSQL 9.6+ or SQLite 3.22+.
+
+Removal of old List Accounts Admin API
+--------------------------------------
+
+The deprecated v1 "list accounts" admin API (``GET /_synapse/admin/v1/users/<user_id>``) has been removed in this version.
+
+The `v2 list accounts API <https://github.com/matrix-org/synapse/blob/master/docs/admin_api/user_admin_api.rst#list-accounts>`_
+has been available since Synapse 1.7.0 (2019-12-13), and is accessible under ``GET /_synapse/admin/v2/users``.
+
+The deprecation of the old endpoint was announced with Synapse 1.28.0 (released on 2021-02-25).
+
+Application Services must use type ``m.login.application_service`` when registering users
+-----------------------------------------------------------------------------------------
+
+In compliance with the
+`Application Service spec <https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions>`_,
+Application Services are now required to use the ``m.login.application_service`` type when registering users via the
+``/_matrix/client/r0/register`` endpoint. This behaviour was deprecated in Synapse v1.30.0.
+
+Please ensure your Application Services are up to date.
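+
+For illustration, a registration request now needs to include the new type. The
+sketch below uses the third-party ``requests`` library (not something Synapse
+provides); the homeserver URL and the Application Service token are placeholders:
+
+.. code-block:: python
+
+    import requests
+
+    resp = requests.post(
+        "http://localhost:8008/_matrix/client/r0/register",
+        # Application Services authenticate with their as_token.
+        params={"access_token": "<as_token>"},
+        json={
+            "type": "m.login.application_service",
+            "username": "my_appservice_user",
+        },
+    )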
+
Upgrading to v1.29.0
====================
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index 67e032244e..856dd437db 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -24,6 +24,7 @@ import sys
import time
import urllib
from http import TwistedHttpClient
+from typing import Optional
import nacl.encoding
import nacl.signing
@@ -718,7 +719,7 @@ class SynapseCmd(cmd.Cmd):
method,
path,
data=None,
- query_params={"access_token": None},
+ query_params: Optional[dict] = None,
alt_text=None,
):
"""Runs an HTTP request and pretty prints the output.
@@ -729,6 +730,8 @@ class SynapseCmd(cmd.Cmd):
data: Raw JSON data if any
query_params: dict of query parameters to add to the url
"""
+ query_params = query_params or {"access_token": None}
+
url = self._url() + path
if "access_token" in query_params:
query_params["access_token"] = self._tok()
diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py
index 851e80c25b..1cf913756e 100644
--- a/contrib/cmdclient/http.py
+++ b/contrib/cmdclient/http.py
@@ -16,6 +16,7 @@
import json
import urllib
from pprint import pformat
+from typing import Optional
from twisted.internet import defer, reactor
from twisted.web.client import Agent, readBody
@@ -85,8 +86,9 @@ class TwistedHttpClient(HttpClient):
body = yield readBody(response)
defer.returnValue(json.loads(body))
- def _create_put_request(self, url, json_data, headers_dict={}):
+ def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None):
"""Wrapper of _create_request to issue a PUT request"""
+ headers_dict = headers_dict or {}
if "Content-Type" not in headers_dict:
raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
@@ -95,14 +97,22 @@ class TwistedHttpClient(HttpClient):
"PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict
)
- def _create_get_request(self, url, headers_dict={}):
+ def _create_get_request(self, url, headers_dict: Optional[dict] = None):
"""Wrapper of _create_request to issue a GET request"""
- return self._create_request("GET", url, headers_dict=headers_dict)
+ return self._create_request("GET", url, headers_dict=headers_dict or {})
@defer.inlineCallbacks
def do_request(
- self, method, url, data=None, qparams=None, jsonreq=True, headers={}
+ self,
+ method,
+ url,
+ data=None,
+ qparams=None,
+ jsonreq=True,
+ headers: Optional[dict] = None,
):
+ headers = headers or {}
+
if qparams:
url = "%s?%s" % (url, urllib.urlencode(qparams, True))
@@ -123,8 +133,12 @@ class TwistedHttpClient(HttpClient):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def _create_request(self, method, url, producer=None, headers_dict={}):
+ def _create_request(
+ self, method, url, producer=None, headers_dict: Optional[dict] = None
+ ):
"""Creates and sends a request to the given url"""
+ headers_dict = headers_dict or {}
+
headers_dict["User-Agent"] = ["Synapse Cmd Client"]
retries_left = 5
diff --git a/debian/build_virtualenv b/debian/build_virtualenv
index cad7d16883..21caad90cc 100755
--- a/debian/build_virtualenv
+++ b/debian/build_virtualenv
@@ -50,15 +50,24 @@ PACKAGE_BUILD_DIR="debian/matrix-synapse-py3"
VIRTUALENV_DIR="${PACKAGE_BUILD_DIR}${DH_VIRTUALENV_INSTALL_ROOT}/matrix-synapse"
TARGET_PYTHON="${VIRTUALENV_DIR}/bin/python"
-# we copy the tests to a temporary directory so that we can put them on the
-# PYTHONPATH without putting the uninstalled synapse on the pythonpath.
-tmpdir=`mktemp -d`
-trap "rm -r $tmpdir" EXIT
+case "$DEB_BUILD_OPTIONS" in
+ *nocheck*)
+ # Skip running tests if "nocheck" present in $DEB_BUILD_OPTIONS
+ ;;
+
+ *)
+ # Copy tests to a temporary directory so that we can put them on the
+ # PYTHONPATH without putting the uninstalled synapse on the pythonpath.
+ tmpdir=`mktemp -d`
+ trap "rm -r $tmpdir" EXIT
+
+ cp -r tests "$tmpdir"
-cp -r tests "$tmpdir"
+ PYTHONPATH="$tmpdir" \
+ "${TARGET_PYTHON}" -m twisted.trial --reporter=text -j2 tests
-PYTHONPATH="$tmpdir" \
- "${TARGET_PYTHON}" -m twisted.trial --reporter=text -j2 tests
+ ;;
+esac
# build the config file
"${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_config" \
diff --git a/debian/changelog b/debian/changelog
index 09602ff54b..83be4497ec 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,13 @@
+matrix-synapse-py3 (1.32.0) stable; urgency=medium
+
+ [ Dan Callahan ]
+ * Skip tests when DEB_BUILD_OPTIONS contains "nocheck".
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.32.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 20 Apr 2021 14:28:39 +0100
+
matrix-synapse-py3 (1.31.0) stable; urgency=medium
* New synapse release 1.31.0.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 5b7bf02776..4f5cd06d72 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -18,11 +18,6 @@ ARG PYTHON_VERSION=3.8
###
FROM docker.io/python:${PYTHON_VERSION}-slim as builder
-LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
-LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'
-LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
-LABEL org.opencontainers.image.licenses='Apache-2.0'
-
# install the OS build deps
RUN apt-get update && apt-get install -y \
build-essential \
@@ -66,6 +61,11 @@ RUN pip install --prefix="/install" --no-deps --no-warn-script-location /synapse
FROM docker.io/python:${PYTHON_VERSION}-slim
+LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
+LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'
+LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
+LABEL org.opencontainers.image.licenses='Apache-2.0'
+
RUN apt-get update && apt-get install -y \
curl \
gosu \
diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index 0dea62a87d..a792899540 100644
--- a/docker/conf/homeserver.yaml
+++ b/docker/conf/homeserver.yaml
@@ -173,18 +173,10 @@ report_stats: False
## API Configuration ##
-room_invite_state_types:
- - "m.room.join_rules"
- - "m.room.canonical_alias"
- - "m.room.avatar"
- - "m.room.name"
-
{% if SYNAPSE_APPSERVICES %}
app_service_config_files:
{% for appservice in SYNAPSE_APPSERVICES %} - "{{ appservice }}"
{% endfor %}
-{% else %}
-app_service_config_files: []
{% endif %}
macaroon_secret_key: "{{ SYNAPSE_MACAROON_SECRET_KEY }}"
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 8d4ec5a6f9..dbce9c90b6 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -111,35 +111,16 @@ List Accounts
=============
This API returns all local user accounts.
+By default, the response is ordered by ascending user ID.
-The api is::
+The API is::
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
-The parameter ``from`` is optional but used for pagination, denoting the
-offset in the returned results. This should be treated as an opaque value and
-not explicitly set to anything other than the return value of ``next_token``
-from a previous call.
-
-The parameter ``limit`` is optional but is used for pagination, denoting the
-maximum number of items to return in this call. Defaults to ``100``.
-
-The parameter ``user_id`` is optional and filters to only return users with user IDs
-that contain this value. This parameter is ignored when using the ``name`` parameter.
-
-The parameter ``name`` is optional and filters to only return users with user ID localparts
-**or** displaynames that contain this value.
-
-The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
-Defaults to ``true`` to include guest users.
-
-The parameter ``deactivated`` is optional and if ``true`` will **include** deactivated users.
-Defaults to ``false`` to exclude deactivated users.
-
-A JSON body is returned with the following shape:
+A response body like the following is returned:
.. code:: json
@@ -175,6 +156,66 @@ with ``from`` set to the value of ``next_token``. This will return a new page.
If the endpoint does not return a ``next_token`` then there are no more users
to paginate through.
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - Is optional and filters to only return users with user IDs
+ that contain this value. This parameter is ignored when using the ``name`` parameter.
+- ``name`` - Is optional and filters to only return users with user ID localparts
+ **or** displaynames that contain this value.
+- ``guests`` - string representing a bool - Is optional and if ``false`` will **exclude** guest users.
+ Defaults to ``true`` to include guest users.
+- ``deactivated`` - string representing a bool - Is optional and if ``true`` will **include** deactivated users.
+ Defaults to ``false`` to exclude deactivated users.
+- ``limit`` - string representing a positive integer - Is optional but is used for pagination,
+ denoting the maximum number of items to return in this call. Defaults to ``100``.
+- ``from`` - string representing a positive integer - Is optional but used for pagination,
+ denoting the offset in the returned results. This should be treated as an opaque value and
+ not explicitly set to anything other than the return value of ``next_token`` from a previous call.
+ Defaults to ``0``.
+- ``order_by`` - The method by which to sort the returned list of users.
+ If the ordered field has duplicates, the second order is always by ascending ``name``,
+ which guarantees a stable ordering. Valid values are:
+
+ - ``name`` - Users are ordered alphabetically by ``name``. This is the default.
+ - ``is_guest`` - Users are ordered by ``is_guest`` status.
+ - ``admin`` - Users are ordered by ``admin`` status.
+ - ``user_type`` - Users are ordered alphabetically by ``user_type``.
+ - ``deactivated`` - Users are ordered by ``deactivated`` status.
+ - ``shadow_banned`` - Users are ordered by ``shadow_banned`` status.
+ - ``displayname`` - Users are ordered alphabetically by ``displayname``.
+ - ``avatar_url`` - Users are ordered alphabetically by avatar URL.
+
+- ``dir`` - Direction of user order. Either ``f`` for forwards or ``b`` for backwards.
+  Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``.
+
+Caution: the database only has indexes on the columns ``name`` and ``created_ts``.
+If a different sort order is used (``is_guest``, ``admin``,
+``user_type``, ``deactivated``, ``shadow_banned``, ``avatar_url`` or ``displayname``),
+this can cause a large load on the database, especially for large environments.
+
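+For example, to fetch the first ten users ordered by display name, descending
+(a sketch using the third-party ``requests`` library; the homeserver URL and
+admin token are placeholders):
+
+.. code:: python
+
+    import requests
+
+    resp = requests.get(
+        "http://localhost:8008/_synapse/admin/v2/users",
+        headers={"Authorization": "Bearer <access_token>"},
+        params={"order_by": "displayname", "dir": "b", "limit": 10},
+    )
+    body = resp.json()
+    print(body["total"], [user["name"] for user in body["users"]])
+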
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``users`` - An array of objects, each containing information about a user.
+ User objects contain the following fields:
+
+  - ``name`` - string - Fully-qualified user ID (e.g. ``@user:server.com``).
+  - ``is_guest`` - bool - Whether the user is a guest account.
+  - ``admin`` - bool - Whether the user is a server administrator.
+  - ``user_type`` - string - Type of the user. Normal users are type ``None``.
+    This allows user-type-specific behaviour. There are also types ``support`` and ``bot``.
+  - ``deactivated`` - bool - Whether the user has been marked as deactivated.
+  - ``shadow_banned`` - bool - Whether the user has been marked as shadow banned.
+  - ``displayname`` - string - The user's display name, if they have set one.
+  - ``avatar_url`` - string - The user's avatar URL, if they have set one.
+
+- ``next_token`` - string representing a positive integer - Indication for pagination. See above.
+- ``total`` - integer - Total number of users.
+
+
Query current sessions for a user
=================================
@@ -823,3 +864,118 @@ The following parameters should be set in the URL:
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
be local.
+
+Override ratelimiting for users
+===============================
+
+This API allows you to override or disable ratelimiting for a specific user.
+There are specific APIs to set, get and delete a ratelimit.
+
+Get status of ratelimit
+-----------------------
+
+The API is::
+
+ GET /_synapse/admin/v1/users/<user_id>/override_ratelimit
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+A response body like the following is returned:
+
+.. code:: json
+
+ {
+ "messages_per_second": 0,
+ "burst_count": 0
+ }
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+ be local.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``messages_per_second`` - integer - The number of actions that can
+  be performed in a second. ``0`` means that ratelimiting is disabled for this user.
+- ``burst_count`` - integer - How many actions can be performed before
+  being limited.
+
+If **no** custom ratelimit is set, an empty JSON dict is returned.
+
+.. code:: json
+
+ {}
+
+Set ratelimit
+-------------
+
+The API is::
+
+ POST /_synapse/admin/v1/users/<user_id>/override_ratelimit
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+A response body like the following is returned:
+
+.. code:: json
+
+ {
+ "messages_per_second": 0,
+ "burst_count": 0
+ }
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+ be local.
+
+Body parameters:
+
+- ``messages_per_second`` - positive integer, optional. The number of actions that can
+ be performed in a second. Defaults to ``0``.
+- ``burst_count`` - positive integer, optional. How many actions can be performed
+  before being limited. Defaults to ``0``.
+
+To disable ratelimiting for a user, set both values to ``0``.
+
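+For example, ratelimiting could be disabled for ``@user:server.com`` with a
+request like the following (again a sketch with ``requests``; URL and token
+are placeholders):
+
+.. code:: python
+
+    import requests
+
+    resp = requests.post(
+        "http://localhost:8008/_synapse/admin/v1/users/@user:server.com/override_ratelimit",
+        headers={"Authorization": "Bearer <access_token>"},
+        json={"messages_per_second": 0, "burst_count": 0},
+    )
+    print(resp.json())  # {"messages_per_second": 0, "burst_count": 0}
+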
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``messages_per_second`` - integer - The number of actions that can
+ be performed in a second.
+- ``burst_count`` - integer - How many actions can be performed before
+  being limited.
+
+Delete ratelimit
+----------------
+
+The API is::
+
+ DELETE /_synapse/admin/v1/users/<user_id>/override_ratelimit
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+.. code:: json
+
+ {}
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+ be local.
+
diff --git a/docs/code_style.md b/docs/code_style.md
index 190f8ab2de..28fb7277c4 100644
--- a/docs/code_style.md
+++ b/docs/code_style.md
@@ -128,6 +128,9 @@ Some guidelines follow:
will be if no sub-options are enabled).
- Lines should be wrapped at 80 characters.
- Use two-space indents.
+- `true` and `false` are spelt thus (as opposed to `True`, etc.)
+- Use single quotes (`'`) rather than double-quotes (`"`) or backticks
+ (`` ` ``) to refer to configuration options.
Example:
diff --git a/docs/presence_router_module.md b/docs/presence_router_module.md
new file mode 100644
index 0000000000..d6566d978d
--- /dev/null
+++ b/docs/presence_router_module.md
@@ -0,0 +1,235 @@
+# Presence Router Module
+
+Synapse supports configuring a module that can specify additional users
+(local or remote) who should receive certain presence updates from local
+users.
+
+Note that routing presence via Application Service transactions is not
+currently supported.
+
+The presence routing module is implemented as a Python class, which will
+be imported by the running Synapse.
+
+## Python Presence Router Class
+
+The Python class is instantiated with two objects:
+
+* A configuration object of some type (see below).
+* An instance of `synapse.module_api.ModuleApi`.
+
+It then implements methods related to presence routing.
+
+Note that one method of `ModuleApi` that may be useful is:
+
+```python
+async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None
+```
+
+which can be given a list of local or remote MXIDs to broadcast known, online user
+presence to (for those users that the receiving user is considered interested in).
+It does not include state for users who are currently offline, and it can only be
+called on workers that support sending federation.
+
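+For example, a module might broadcast the presence of currently-online local
+users to a fixed set of interested users once at startup. The following is
+only a sketch: the method name and MXIDs are illustrative, and it assumes the
+module stored the `ModuleApi` instance as `self._module_api`.
+
+```python
+async def _broadcast_initial_presence(self) -> None:
+    # Send known online presence of local users to these (possibly remote)
+    # users. Offline users are not included.
+    await self._module_api.send_local_online_presence_to(
+        ["@alice:example.org", "@bob:other-server.org"]
+    )
+```
+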
+### Module structure
+
+Below is a list of possible methods that can be implemented, and whether they are
+required.
+
+#### `parse_config`
+
+```python
+def parse_config(config_dict: dict) -> Any
+```
+
+**Required.** A static method that is passed a dictionary of config options, and
+ should return a validated config object. This method is described further in
+ [Configuration](#configuration).
+
+#### `get_users_for_states`
+
+```python
+async def get_users_for_states(
+ self,
+ state_updates: Iterable[UserPresenceState],
+) -> Dict[str, Set[UserPresenceState]]:
+```
+
+**Required.** An asynchronous method that is passed an iterable of user presence
+state. This method can determine whether a given presence update should be sent to certain
+users. It does this by returning a dictionary with keys representing local or remote
+Matrix User IDs, and values being a Python set
+of `synapse.handlers.presence.UserPresenceState` instances.
+
+Synapse will then attempt to send the specified presence updates to each user when
+possible.
+
+#### `get_interested_users`
+
+```python
+async def get_interested_users(self, user_id: str) -> Union[Set[str], str]
+```
+
+**Required.** An asynchronous method that is passed a single Matrix User ID. This
+method is expected to return the users whose presence the passed-in user may be
+interested in. Returned users may be local or remote. The presence routed as a result of
+what this method returns is sent in addition to the updates already sent between users
+that share a room. Presence updates are deduplicated.
+
+This method should return a Python set of Matrix User IDs, or the object
+`synapse.events.presence_router.PresenceRouter.ALL_USERS` to indicate that the passed
+user should receive presence information for *all* known users.
+
+For clarity, if the user `@alice:example.org` is passed to this method, and the Set
+`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice
+should receive presence updates sent by Bob and Charlie, regardless of whether these
+users share a room.
+
+### Example
+
+Below is an example implementation of a presence router class.
+
+```python
+from typing import Dict, Iterable, List, Set, Union
+from synapse.events.presence_router import PresenceRouter
+from synapse.handlers.presence import UserPresenceState
+from synapse.module_api import ModuleApi
+
+class PresenceRouterConfig:
+ def __init__(self):
+ # Config options with their defaults
+ # A list of users to always send all user presence updates to
+ self.always_send_to_users = [] # type: List[str]
+
+ # A list of users to ignore presence updates for. Does not affect
+ # shared-room presence relationships
+ self.blacklisted_users = [] # type: List[str]
+
+class ExamplePresenceRouter:
+ """An example implementation of synapse.presence_router.PresenceRouter.
+ Supports routing all presence to a configured set of users, or a subset
+ of presence from certain users to members of certain rooms.
+
+ Args:
+ config: A configuration object.
+ module_api: An instance of Synapse's ModuleApi.
+ """
+ def __init__(self, config: PresenceRouterConfig, module_api: ModuleApi):
+ self._config = config
+ self._module_api = module_api
+
+ @staticmethod
+ def parse_config(config_dict: dict) -> PresenceRouterConfig:
+ """Parse a configuration dictionary from the homeserver config, do
+ some validation and return a typed PresenceRouterConfig.
+
+ Args:
+ config_dict: The configuration dictionary.
+
+ Returns:
+ A validated config object.
+ """
+ # Initialise a typed config object
+ config = PresenceRouterConfig()
+ always_send_to_users = config_dict.get("always_send_to_users")
+ blacklisted_users = config_dict.get("blacklisted_users")
+
+ # Do some validation of config options... otherwise raise a
+ # synapse.config.ConfigError.
+ config.always_send_to_users = always_send_to_users
+ config.blacklisted_users = blacklisted_users
+
+ return config
+
+ async def get_users_for_states(
+ self,
+ state_updates: Iterable[UserPresenceState],
+ ) -> Dict[str, Set[UserPresenceState]]:
+ """Given an iterable of user presence updates, determine where each one
+ needs to go. Returned results will not affect presence updates that are
+ sent between users who share a room.
+
+ Args:
+ state_updates: An iterable of user presence state updates.
+
+ Returns:
+ A dictionary of user_id -> set of UserPresenceState that the user should
+ receive.
+ """
+        destination_users = {}  # type: Dict[str, Set[UserPresenceState]]
+
+ # Ignore any updates for blacklisted users
+ desired_updates = set()
+ for update in state_updates:
+            if update.user_id not in self._config.blacklisted_users:
+ desired_updates.add(update)
+
+ # Send all presence updates to specific users
+ for user_id in self._config.always_send_to_users:
+ destination_users[user_id] = desired_updates
+
+ return destination_users
+
+ async def get_interested_users(
+ self,
+ user_id: str,
+    ) -> Union[Set[str], str]:
+ """
+ Retrieve a list of users that `user_id` is interested in receiving the
+ presence of. This will be in addition to those they share a room with.
+ Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
+ that this user should receive all incoming local and remote presence updates.
+
+ Note that this method will only be called for local users.
+
+ Args:
+ user_id: A user requesting presence updates.
+
+ Returns:
+ A set of user IDs to return additional presence updates for, or
+ PresenceRouter.ALL_USERS to return presence updates for all other users.
+ """
+ if user_id in self._config.always_send_to_users:
+ return PresenceRouter.ALL_USERS
+
+ return set()
+```
+
+#### A note on `get_users_for_states` and `get_interested_users`
+
+Both of these methods are effectively two sides of the same coin. The logic
+regarding which users should receive updates for other users should be the same
+between them.
+
+`get_users_for_states` is called when presence updates come in from either federation
+or local users, and is used to either direct local presence to remote users, or to
+wake up the sync streams of local users to collect remote presence.
+
+In contrast, `get_interested_users` is used to determine the users that presence should
+be fetched for when a local user is syncing. This presence is then retrieved, before
+being fed through `get_users_for_states` once again, with only the syncing user's
+routing information pulled from the resulting dictionary.
+
+Their routing logic should thus line up, else you may run into unintended behaviour.
+
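+One way to keep the two methods aligned is to route both through a single
+predicate that encodes the routing relation. The sketch below is illustrative
+only: the helper name is made up, and `self._config` is assumed to be the
+config object from the example above.
+
+```python
+# Methods on a router class such as ExamplePresenceRouter above:
+
+def _is_interested_in(self, observer_user_id: str, observed_user_id: str) -> bool:
+    # The one place where the routing relation is defined. This simple
+    # policy ignores observed_user_id.
+    return observer_user_id in self._config.always_send_to_users
+
+async def get_users_for_states(self, state_updates):
+    # The iterable may only be traversable once, so materialise it first.
+    updates = list(state_updates)
+    return {
+        observer: {
+            update
+            for update in updates
+            if self._is_interested_in(observer, update.user_id)
+        }
+        for observer in self._config.always_send_to_users
+    }
+
+async def get_interested_users(self, user_id: str):
+    if user_id in self._config.always_send_to_users:
+        return PresenceRouter.ALL_USERS
+    return set()
+```
+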
+## Configuration
+
+Once you've crafted your module and installed it into the same Python environment as
+Synapse, amend your homeserver config file with the following.
+
+```yaml
+presence:
+ routing_module:
+ module: my_module.ExamplePresenceRouter
+ config:
+      # Any configuration options for your module. The below is an example
+      # of setting options for ExamplePresenceRouter.
+ always_send_to_users: ["@presence_gobbler:example.org"]
+ blacklisted_users:
+ - "@alice:example.com"
+ - "@bob:example.com"
+ ...
+```
+
+The contents of `config` will be passed as a Python dictionary to the static
+`parse_config` method of your class. The object returned by this method will
+then be passed to the `__init__` method of your module as `config`.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index be5e84f0ad..6627fb2c7a 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -82,9 +82,28 @@ pid_file: DATADIR/homeserver.pid
#
#soft_file_limit: 0
-# Set to false to disable presence tracking on this homeserver.
+# Presence tracking allows users to see the state (e.g. online/offline)
+# of other local and remote users.
#
-#use_presence: false
+presence:
+ # Uncomment to disable presence tracking on this homeserver. This option
+ # replaces the previous top-level 'use_presence' option.
+ #
+ #enabled: false
+
+  # Presence routers are third-party modules that can specify additional logic
+  # controlling where presence updates from users are routed.
+ #
+ presence_router:
+ # The custom module's class. Uncomment to use a custom presence router module.
+ #
+ #module: "my_custom_router.PresenceRouter"
+
+ # Configuration options of the custom module. Refer to your module's
+ # documentation for available options.
+ #
+ #config:
+ # example_option: 'something'
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
@@ -1304,9 +1323,9 @@ url_preview_accept_language:
#
#allowed_local_3pids:
# - medium: email
-# pattern: '.*@matrix\.org'
+# pattern: '^[^@]+@matrix\.org$'
# - medium: email
-# pattern: '.*@vector\.im'
+# pattern: '^[^@]+@vector\.im$'
# - medium: msisdn
# pattern: '\+44'
@@ -1629,16 +1648,31 @@ metrics_flags:
## API Configuration ##
-# A list of event types from a room that will be given to users when they
-# are invited to a room. This allows clients to display information about the
-# room that they've been invited to, without actually being in the room yet.
+# Controls for the state that is shared with users who receive an invite
+# to a room
#
-#room_invite_state_types:
-# - "m.room.join_rules"
-# - "m.room.canonical_alias"
-# - "m.room.avatar"
-# - "m.room.encryption"
-# - "m.room.name"
+room_prejoin_state:
+ # By default, the following state event types are shared with users who
+ # receive invites to the room:
+ #
+ # - m.room.join_rules
+ # - m.room.canonical_alias
+ # - m.room.avatar
+ # - m.room.encryption
+ # - m.room.name
+ #
+ # Uncomment the following to disable these defaults (so that only the event
+ # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+ #
+ #disable_default_event_types: true
+
+ # Additional state event types to share with users when they are invited
+ # to a room.
+ #
+ # By default, this list is empty (so only the default event types are shared).
+ #
+ #additional_event_types:
+ # - org.example.custom.event.type
# A list of application service config files to use
diff --git a/mypy.ini b/mypy.ini
index 3ae5d45787..32e6197409 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -8,6 +8,7 @@ show_traceback = True
mypy_path = stubs
warn_unreachable = True
local_partial_types = True
+no_implicit_optional = True
# To find all folders that pass mypy you run:
#
diff --git a/pyproject.toml b/pyproject.toml
index cd880d4e39..8bca1fa4ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@
showcontent = true
[tool.black]
-target-version = ['py35']
+target-version = ['py36']
exclude = '''
(
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index d0685c8b35..3bb6e2c7ea 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -18,11 +18,9 @@ import threading
from concurrent.futures import ThreadPoolExecutor
DISTS = (
- "debian:stretch",
"debian:buster",
"debian:bullseye",
"debian:sid",
- "ubuntu:xenial",
"ubuntu:bionic",
"ubuntu:focal",
"ubuntu:groovy",
@@ -43,7 +41,7 @@ class Builder(object):
self._lock = threading.Lock()
self._failed = False
- def run_build(self, dist):
+ def run_build(self, dist, skip_tests=False):
"""Build deb for a single distribution"""
if self._failed:
@@ -51,13 +49,13 @@ class Builder(object):
raise Exception("failed")
try:
- self._inner_build(dist)
+ self._inner_build(dist, skip_tests)
except Exception as e:
print("build of %s failed: %s" % (dist, e), file=sys.stderr)
self._failed = True
raise
- def _inner_build(self, dist):
+ def _inner_build(self, dist, skip_tests=False):
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
os.chdir(projdir)
@@ -101,6 +99,7 @@ class Builder(object):
"--volume=" + debsdir + ":/debs",
"-e", "TARGET_USERID=%i" % (os.getuid(), ),
"-e", "TARGET_GROUPID=%i" % (os.getgid(), ),
+ "-e", "DEB_BUILD_OPTIONS=%s" % ("nocheck" if skip_tests else ""),
"dh-venv-builder:" + tag,
], stdout=stdout, stderr=subprocess.STDOUT)
@@ -124,7 +123,7 @@ class Builder(object):
self.active_containers.remove(c)
-def run_builds(dists, jobs=1):
+def run_builds(dists, jobs=1, skip_tests=False):
builder = Builder(redirect_stdout=(jobs > 1))
def sig(signum, _frame):
@@ -133,7 +132,7 @@ def run_builds(dists, jobs=1):
signal.signal(signal.SIGINT, sig)
with ThreadPoolExecutor(max_workers=jobs) as e:
- res = e.map(builder.run_build, dists)
+ res = e.map(lambda dist: builder.run_build(dist, skip_tests), dists)
# make sure we consume the iterable so that exceptions are raised.
for r in res:
@@ -149,8 +148,12 @@ if __name__ == '__main__':
help='specify the number of builds to run in parallel',
)
parser.add_argument(
+ '--no-check', action='store_true',
+ help='skip running tests after building',
+ )
+ parser.add_argument(
'dist', nargs='*', default=DISTS,
help='a list of distributions to build for. Default: %(default)s',
)
args = parser.parse_args()
- run_builds(dists=args.dist, jobs=args.jobs)
+ run_builds(dists=args.dist, jobs=args.jobs, skip_tests=args.no_check)
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 3cde53f5c0..1612ab522c 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -1,22 +1,49 @@
-#! /bin/bash -eu
+#!/usr/bin/env bash
# This script is designed for developers who want to test their code
# against Complement.
#
# It makes a Synapse image which represents the current checkout,
-# then downloads Complement and runs it with that image.
+# builds a synapse-complement image on top, then runs tests with it.
+#
+# By default the script will fetch the latest Complement master branch and
+# run tests with that. This can be overridden to use a custom Complement
+# checkout by setting the COMPLEMENT_DIR environment variable to the
+# filepath of a local Complement checkout.
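+#
+# For example, to use a checkout in a sibling directory (path is illustrative):
+#
+#   COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh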
+#
+# A regular expression of test method names can be supplied as the first
+# argument to the script. Complement will then only run those tests. If
+# no regex is supplied, all tests are run. For example:
+#
+# ./complement.sh "TestOutboundFederation(Profile|Send)"
+#
+
+# Exit if a command returns a non-zero exit code
+set -e
+# Change to the repository root
cd "$(dirname $0)/.."
+# Check for a user-specified Complement checkout
+if [[ -z "$COMPLEMENT_DIR" ]]; then
+ echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..."
+ wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz
+ tar -xzf master.tar.gz
+ COMPLEMENT_DIR=complement-master
+ echo "Checkout available at 'complement-master'"
+fi
+
# Build the base Synapse image from the local checkout
-docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
+docker build -t matrixdotorg/synapse -f docker/Dockerfile .
+# Build the Synapse monolith image from Complement, based on the above image we just built
+docker build -t complement-synapse -f "$COMPLEMENT_DIR/dockerfiles/Synapse.Dockerfile" "$COMPLEMENT_DIR/dockerfiles"
-# Download Complement
-wget -N https://github.com/matrix-org/complement/archive/master.tar.gz
-tar -xzf master.tar.gz
-cd complement-master
+cd "$COMPLEMENT_DIR"
-# Build the Synapse image from Complement, based on the above image we just built
-docker build -t complement-synapse -f dockerfiles/Synapse.Dockerfile ./dockerfiles
+EXTRA_COMPLEMENT_ARGS=""
+if [[ -n "$1" ]]; then
+ # A test name regex has been set, supply it to Complement
+ EXTRA_COMPLEMENT_ARGS+="-run $1 "
+fi
-# Run the tests on the resulting image!
-COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -count=1 ./tests
+# Run the tests!
+COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist,msc2946,msc3083 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
new file mode 100755
index 0000000000..1042fa48bc
--- /dev/null
+++ b/scripts-dev/release.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An interactive script for doing a release. See `run()` below.
+"""
+
+import subprocess
+import sys
+from typing import Optional
+
+import click
+import git
+from packaging import version
+from redbaron import RedBaron
+
+
+@click.command()
+def run():
+ """An interactive script to walk through the initial stages of creating a
+    release, including creating the release branch, updating the changelog and pushing to
+ GitHub.
+
+ Requires the dev dependencies be installed, which can be done via:
+
+ pip install -e .[dev]
+
+ """
+
+ # Make sure we're in a git repo.
+ try:
+ repo = git.Repo()
+ except git.InvalidGitRepositoryError:
+ raise click.ClickException("Not in Synapse repo.")
+
+ if repo.is_dirty():
+ raise click.ClickException("Uncommitted changes exist.")
+
+ click.secho("Updating git repo...")
+ repo.remote().fetch()
+
+ # Parse the AST and load the `__version__` node so that we can edit it
+ # later.
+ with open("synapse/__init__.py") as f:
+ red = RedBaron(f.read())
+
+ version_node = None
+ for node in red:
+ if node.type != "assignment":
+ continue
+
+ if node.target.type != "name":
+ continue
+
+ if node.target.value != "__version__":
+ continue
+
+ version_node = node
+ break
+
+ if not version_node:
+ print("Failed to find '__version__' definition in synapse/__init__.py")
+ sys.exit(1)
+
+ # Parse the current version.
+ current_version = version.parse(version_node.value.value.strip('"'))
+ assert isinstance(current_version, version.Version)
+
+    # Figure out what sort of release we're doing and calculate the new version.
+ rc = click.confirm("RC", default=True)
+ if current_version.pre:
+ # If the current version is an RC we don't need to bump any of the
+ # version numbers (other than the RC number).
+ base_version = "{}.{}.{}".format(
+ current_version.major,
+ current_version.minor,
+ current_version.micro,
+ )
+
+ if rc:
+ new_version = "{}.{}.{}rc{}".format(
+ current_version.major,
+ current_version.minor,
+ current_version.micro,
+ current_version.pre[1] + 1,
+ )
+ else:
+ new_version = base_version
+ else:
+        # If this is a new release cycle then we need to know if it's a major
+ # version bump or a hotfix.
+ release_type = click.prompt(
+ "Release type",
+ type=click.Choice(("major", "hotfix")),
+ show_choices=True,
+ default="major",
+ )
+
+ if release_type == "major":
+ base_version = new_version = "{}.{}.{}".format(
+ current_version.major,
+ current_version.minor + 1,
+ 0,
+ )
+ if rc:
+ new_version = "{}.{}.{}rc1".format(
+ current_version.major,
+ current_version.minor + 1,
+ 0,
+ )
+
+ else:
+ base_version = new_version = "{}.{}.{}".format(
+ current_version.major,
+ current_version.minor,
+ current_version.micro + 1,
+ )
+ if rc:
+ new_version = "{}.{}.{}rc1".format(
+ current_version.major,
+ current_version.minor,
+ current_version.micro + 1,
+ )
+
+ # Confirm the calculated version is OK.
+ if not click.confirm(f"Create new version: {new_version}?", default=True):
+ click.get_current_context().abort()
+
+ # Switch to the release branch.
+ release_branch_name = f"release-v{base_version}"
+ release_branch = find_ref(repo, release_branch_name)
+ if release_branch:
+ if release_branch.is_remote():
+ # If the release branch only exists on the remote we check it out
+ # locally.
+ repo.git.checkout(release_branch_name)
+ release_branch = repo.active_branch
+ else:
+        # If a branch doesn't exist we create one. We ask which branch it
+        # should be based on, defaulting to sensible values depending on the
+ # release type.
+ if current_version.is_prerelease:
+ default = release_branch_name
+ elif release_type == "major":
+ default = "develop"
+ else:
+ default = "master"
+
+ branch_name = click.prompt(
+ "Which branch should the release be based on?", default=default
+ )
+
+ base_branch = find_ref(repo, branch_name)
+ if not base_branch:
+ print(f"Could not find base branch {branch_name}!")
+ click.get_current_context().abort()
+
+ # Check out the base branch and ensure it's up to date
+ repo.head.reference = base_branch
+ repo.head.reset(index=True, working_tree=True)
+ if not base_branch.is_remote():
+ update_branch(repo)
+
+ # Create the new release branch
+ release_branch = repo.create_head(release_branch_name, commit=base_branch)
+
+    # Switch to the release branch and ensure it's up to date.
+ repo.git.checkout(release_branch_name)
+ update_branch(repo)
+
+ # Update the `__version__` variable and write it back to the file.
+ version_node.value = '"' + new_version + '"'
+ with open("synapse/__init__.py", "w") as f:
+ f.write(red.dumps())
+
+ # Generate changelogs
+ subprocess.run("python3 -m towncrier", shell=True)
+
+    # Generate debian changelogs if it's not an RC.
+ if not rc:
+ subprocess.run(
+ f'dch -M -v {new_version} "New synapse release {new_version}."', shell=True
+ )
+ subprocess.run('dch -M -r -D stable ""', shell=True)
+
+ # Show the user the changes and ask if they want to edit the change log.
+ repo.git.add("-u")
+ subprocess.run("git diff --cached", shell=True)
+
+ if click.confirm("Edit changelog?", default=False):
+ click.edit(filename="CHANGES.md")
+
+ # Commit the changes.
+ repo.git.add("-u")
+ repo.git.commit(f"-m {new_version}")
+
+ # We give the option to bail here in case the user wants to make sure things
+ # are OK before pushing.
+ if not click.confirm("Push branch to github?", default=True):
+ print("")
+ print("Run when ready to push:")
+ print("")
+ print(f"\tgit push -u {repo.remote().name} {repo.active_branch.name}")
+ print("")
+ sys.exit(0)
+
+ # Otherwise, push and open the changelog in the browser.
+ repo.git.push("-u", repo.remote().name, repo.active_branch.name)
+
+ click.launch(
+ f"https://github.com/matrix-org/synapse/blob/{repo.active_branch.name}/CHANGES.md"
+ )
+
+
+def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
+ """Find the branch/ref, looking first locally then in the remote."""
+ if ref_name in repo.refs:
+ return repo.refs[ref_name]
+ elif ref_name in repo.remote().refs:
+ return repo.remote().refs[ref_name]
+ else:
+ return None
+
+
+def update_branch(repo: git.Repo):
+ """Ensure branch is up to date if it has a remote"""
+ if repo.active_branch.tracking_branch():
+ repo.git.merge(repo.active_branch.tracking_branch().name)
+
+
+if __name__ == "__main__":
+ run()
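
The version arithmetic above is easiest to see end to end. Below is a minimal standalone sketch, assuming `packaging`-style version objects (which expose the same `.major`/`.minor`/`.micro`/`.pre` attributes the script relies on); the `next_version` helper is illustrative, not part of the script:

```python
from packaging import version


def next_version(current: str, rc: bool, release_type: str = "major") -> str:
    v = version.parse(current)
    if v.is_prerelease:
        # Same release cycle: drop the rc suffix, or bump the rc number.
        base = f"{v.major}.{v.minor}.{v.micro}"
        return f"{base}rc{v.pre[1] + 1}" if rc else base
    if release_type == "major":
        base = f"{v.major}.{v.minor + 1}.0"
    else:  # hotfix
        base = f"{v.major}.{v.minor}.{v.micro + 1}"
    return f"{base}rc1" if rc else base


assert next_version("1.31.0", rc=True) == "1.32.0rc1"
assert next_version("1.32.0rc1", rc=True) == "1.32.0rc2"
assert next_version("1.32.0rc2", rc=False) == "1.32.0"
assert next_version("1.32.0", rc=False, release_type="hotfix") == "1.32.1"
```
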
diff --git a/setup.cfg b/setup.cfg
index 7329eed213..33601b71d5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -18,16 +18,15 @@ ignore =
# E203: whitespace before ':' (which is contrary to pep8?)
# E731: do not assign a lambda expression, use a def
# E501: Line too long (black enforces this for us)
-# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes)
-ignore=W503,W504,E203,E731,E501,B006,B007,B008
+# B007: Subsection of the bugbear suite (TODO: add in remaining fixes)
+ignore=W503,W504,E203,E731,E501,B007
[isort]
line_length = 88
-sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
+sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
default_section=THIRDPARTY
known_first_party = synapse
known_tests=tests
-known_compat = mock
known_twisted=twisted,OpenSSL
multi_line_output=3
include_trailing_comma=true
diff --git a/setup.py b/setup.py
index 29e9971dc1..e2e488761d 100755
--- a/setup.py
+++ b/setup.py
@@ -103,6 +103,13 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8",
]
+CONDITIONAL_REQUIREMENTS["dev"] = CONDITIONAL_REQUIREMENTS["lint"] + [
+ # The following are used by the release script
+ "click==7.1.2",
+ "redbaron==0.9.2",
+ "GitPython==3.1.14",
+]
+
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"]
# Dependencies which are exclusively required by unit test code. This is
@@ -110,7 +117,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"]
# Tests assume that all optional dependencies are installed.
#
# parameterized_class decorator was introduced in parameterized 0.7.0
-CONDITIONAL_REQUIREMENTS["test"] = ["mock>=2.0", "parameterized>=0.7.0"]
+CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
setup(
name="matrix-synapse",
@@ -123,13 +130,12 @@ setup(
zip_safe=False,
long_description=long_description,
long_description_content_type="text/x-rst",
- python_requires="~=3.5",
+ python_requires="~=3.6",
classifiers=[
"Development Status :: 5 - Production/Stable",
"Topic :: Communications :: Chat",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 1d2883acf6..79232c4de1 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.31.0"
+__version__ = "1.32.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8f37d2cf3b..a8ae41de48 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -59,6 +59,8 @@ class JoinRules:
KNOCK = "knock"
INVITE = "invite"
PRIVATE = "private"
+ # As defined for MSC3083.
+ MSC3083_RESTRICTED = "restricted"
class LoginType:
@@ -71,6 +73,11 @@ class LoginType:
DUMMY = "m.login.dummy"
+# This is used in the `type` parameter for /register when called by
+# an appservice to register a new user.
+APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service"
+
+
class EventTypes:
Member = "m.room.member"
Create = "m.room.create"
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index c3f07bc1a3..2244b8a340 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
+from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
@@ -31,10 +32,13 @@ class Ratelimiter:
burst_count: How many actions that can be performed before being limited.
"""
- def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
+ def __init__(
+ self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
+ ):
self.clock = clock
self.rate_hz = rate_hz
self.burst_count = burst_count
+ self.store = store
# A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type
@@ -46,45 +50,10 @@ class Ratelimiter:
OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]]
- def can_requester_do_action(
- self,
- requester: Requester,
- rate_hz: Optional[float] = None,
- burst_count: Optional[int] = None,
- update: bool = True,
- _time_now_s: Optional[int] = None,
- ) -> Tuple[bool, float]:
- """Can the requester perform the action?
-
- Args:
- requester: The requester to key off when rate limiting. The user property
- will be used.
- rate_hz: The long term number of actions that can be performed in a second.
- Overrides the value set during instantiation if set.
- burst_count: How many actions that can be performed before being limited.
- Overrides the value set during instantiation if set.
- update: Whether to count this check as performing the action
- _time_now_s: The current time. Optional, defaults to the current time according
- to self.clock. Only used by tests.
-
- Returns:
- A tuple containing:
- * A bool indicating if they can perform the action now
- * The reactor timestamp for when the action can be performed next.
- -1 if rate_hz is less than or equal to zero
- """
- # Disable rate limiting of users belonging to any AS that is configured
- # not to be rate limited in its registration file (rate_limited: true|false).
- if requester.app_service and not requester.app_service.is_rate_limited():
- return True, -1.0
-
- return self.can_do_action(
- requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
- )
-
- def can_do_action(
+ async def can_do_action(
self,
- key: Hashable,
+ requester: Optional[Requester],
+ key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
@@ -92,9 +61,16 @@ class Ratelimiter:
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
+ Checks if the user has ratelimiting disabled in the database by looking
+ for null/zero values in the `ratelimit_override` table. (Non-zero
+ values aren't honoured, as they're specific to the event-sending
+ ratelimiter rather than all ratelimiters.)
+
Args:
- key: The key we should use when rate limiting. Can be a user ID
- (when sending events), an IP address, etc.
+ requester: The requester that is doing the action, if any. Used to check
+ if the user has ratelimits disabled in the database.
+ key: An arbitrary key used to classify an action. Defaults to the
+ requester's user ID.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
@@ -109,6 +85,30 @@ class Ratelimiter:
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
+ if key is None:
+ if not requester:
+ raise ValueError("Must supply at least one of `requester` or `key`")
+
+ key = requester.user.to_string()
+
+ if requester:
+ # Disable rate limiting of users belonging to any AS that is configured
+ # not to be rate limited in its registration file (rate_limited: true|false).
+ if requester.app_service and not requester.app_service.is_rate_limited():
+ return True, -1.0
+
+ # Check if ratelimiting has been disabled for the user.
+ #
+ # Note that we don't use the returned rate/burst count, as the table
+ # is specifically for the event sending ratelimiter. Instead, we
+ # only use it to (somewhat cheekily) infer whether the user should
+ # be subject to any rate limiting or not.
+ override = await self.store.get_ratelimit_for_user(
+ requester.authenticated_entity
+ )
+ if override and not override.messages_per_second:
+ return True, -1.0
+
# Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
@@ -175,9 +175,10 @@ class Ratelimiter:
else:
del self.actions[key]
- def ratelimit(
+ async def ratelimit(
self,
- key: Hashable,
+ requester: Optional[Requester],
+ key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
@@ -185,8 +186,16 @@ class Ratelimiter:
):
"""Checks if an action can be performed. If not, raises a LimitExceededError
+ Checks if the user has ratelimiting disabled in the database by looking
+ for null/zero values in the `ratelimit_override` table. (Non-zero
+ values aren't honoured, as they're specific to the event-sending
+ ratelimiter rather than all ratelimiters.)
+
Args:
- key: An arbitrary key used to classify an action
+ requester: The requester that is doing the action, if any. Used to check
+ if the user has ratelimits disabled in the database.
+ key: An arbitrary key used to classify an action. Defaults to the
+ requester's user ID.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
@@ -201,7 +210,8 @@ class Ratelimiter:
"""
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
- allowed, time_allowed = self.can_do_action(
+ allowed, time_allowed = await self.can_do_action(
+ requester,
key,
rate_hz=rate_hz,
burst_count=burst_count,
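
The reworked `Ratelimiter` is now async and is keyed off either a `Requester` (defaulting the key to the user ID, and consulting the `ratelimit_override` table via the store) or an explicit key with `requester=None`, as the failed-login limiter later in this change does. A rough usage sketch, assuming a Synapse checkout on the path; the stub store and clock are stand-ins for the homeserver's real ones:

```python
import asyncio
import time

from synapse.api.ratelimiting import Ratelimiter


class StubStore:
    async def get_ratelimit_for_user(self, user_id):
        # Only consulted for requester-based calls; None means "no
        # per-user override configured".
        return None


class StubClock:
    def time(self):
        return time.time()  # seconds, as Ratelimiter expects


async def main():
    limiter = Ratelimiter(
        store=StubStore(), clock=StubClock(), rate_hz=0.17, burst_count=3
    )

    # Key-based limiting with no requester, e.g. by IP address or by a
    # (medium, address) tuple as in the failed-login limiter:
    allowed, retry_at = await limiter.can_do_action(None, "1.2.3.4")
    assert allowed

    # Same key via ratelimit(), which raises LimitExceededError instead
    # of returning False once the burst of 3 is exhausted:
    await limiter.ratelimit(None, "1.2.3.4")


asyncio.run(main())
```
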
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 139fbf5524..4a4ad4c9bb 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -57,7 +57,7 @@ class RoomVersion:
state_res = attr.ib(type=int) # one of the StateResolutionVersions
enforce_key_validity = attr.ib(type=bool)
- # Before MSC2432, m.room.aliases had special auth rules and redaction rules
+ # Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool)
# Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@@ -69,8 +69,8 @@ class RoomVersion:
limit_notifications_power_levels = attr.ib(type=bool)
# MSC2174/MSC2176: Apply updated redaction rules algorithm.
msc2176_redaction_rules = attr.ib(type=bool)
+ # MSC3083: Support the 'restricted' join_rule.
+ msc3083_join_rules = attr.ib(type=bool)
# MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
# m.room.membership event with membership 'knock'.
allow_knocking = attr.ib(type=bool)
@@ -88,6 +88,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
allow_knocking=False,
+ msc3083_join_rules=False,
)
V2 = RoomVersion(
"2",
@@ -100,6 +101,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
allow_knocking=False,
+ msc3083_join_rules=False,
)
V3 = RoomVersion(
"3",
@@ -112,6 +114,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
allow_knocking=False,
+ msc3083_join_rules=False,
)
V4 = RoomVersion(
"4",
@@ -124,6 +127,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
allow_knocking=False,
+ msc3083_join_rules=False,
)
V5 = RoomVersion(
"5",
@@ -136,6 +140,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
allow_knocking=False,
+ msc3083_join_rules=False,
)
V6 = RoomVersion(
"6",
@@ -148,6 +153,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
allow_knocking=False,
+ msc3083_join_rules=False,
)
V7 = RoomVersion(
"7",
@@ -160,6 +166,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
allow_knocking=True,
+ msc3083_join_rules=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -171,6 +178,19 @@ class RoomVersions:
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=True,
+ msc3083_join_rules=False,
+ )
+ MSC3083 = RoomVersion(
+ "org.matrix.msc3083",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=True,
allow_knocking=False,
)
@@ -187,4 +207,5 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V7,
RoomVersions.MSC2176,
)
+ # Note that we do not include MSC3083 here unless it is enabled in the config.
} # type: Dict[str, RoomVersion]
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index b2d21acefd..70ce8a8988 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -281,6 +281,7 @@ class GenericWorkerPresence(BasePresenceHandler):
self.hs = hs
self.is_mine_id = hs.is_mine_id
+ self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence
# The number of ongoing syncs on this process, by user id.
@@ -395,7 +396,7 @@ class GenericWorkerPresence(BasePresenceHandler):
return _user_syncing()
async def notify_from_replication(self, states, stream_id):
- parties = await get_interested_parties(self.store, states)
+ parties = await get_interested_parties(self.store, self.presence_router, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 366c476f80..5203ffe90f 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -49,7 +49,7 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
-from typing import List
+from typing import List, Optional
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.events import EventBase
@@ -191,11 +191,11 @@ class _TransactionController:
self,
service: ApplicationService,
events: List[EventBase],
- ephemeral: List[JsonDict] = [],
+ ephemeral: Optional[List[JsonDict]] = None,
):
try:
txn = await self.store.create_appservice_txn(
- service=service, events=events, ephemeral=ephemeral
+ service=service, events=events, ephemeral=ephemeral or []
)
service_is_up = await self._is_service_up(service)
if service_is_up:
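
Several hunks in this change (`scheduler.py` here, plus `events/__init__.py`, `federation/units.py`, `config/ratelimiting.py`, and `handlers/appservice.py` below) replace mutable default arguments with `Optional[...] = None` and an `x or []` fallback. A toy demonstration of why the old pattern is dangerous:

```python
def buggy(item, acc=[]):  # the default list is created once and shared
    acc.append(item)
    return acc


def fixed(item, acc=None):  # the pattern this change switches to
    acc = acc or []
    acc.append(item)
    return acc


assert buggy(1) == [1]
assert buggy(2) == [1, 2]  # state leaked across unrelated calls
assert fixed(1) == [1]
assert fixed(2) == [2]
```
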
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 0638ed8d2e..55683200de 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -1,5 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,42 +12,131 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import EventTypes
+import logging
+from typing import Iterable
-from ._base import Config
+from synapse.api.constants import EventTypes
+from synapse.config._base import Config, ConfigError
+from synapse.config._util import validate_config
+from synapse.types import JsonDict
-# The default types of room state to send to users to are invited to or knock on a room.
-DEFAULT_ROOM_STATE_TYPES = [
- EventTypes.JoinRules,
- EventTypes.CanonicalAlias,
- EventTypes.RoomAvatar,
- EventTypes.RoomEncryption,
- EventTypes.Name,
-]
+logger = logging.getLogger(__name__)
class ApiConfig(Config):
section = "api"
- def read_config(self, config, **kwargs):
- self.room_invite_state_types = config.get(
- "room_invite_state_types", DEFAULT_ROOM_STATE_TYPES
+ def read_config(self, config: JsonDict, **kwargs):
+ validate_config(_MAIN_SCHEMA, config, ())
+ self.room_prejoin_state = list(self._get_prejoin_state_types(config))
+
+ def generate_config_section(cls, **kwargs) -> str:
+ formatted_default_state_types = "\n".join(
+ " # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
)
- def generate_config_section(cls, **kwargs):
return """\
## API Configuration ##
- # A list of event types from a room that will be given to users when they
- # are invited to a room. This allows clients to display information about the
- # room that they've been invited to, without actually being in the room yet.
+ # Controls for the state that is shared with users who receive an invite
+ # to a room
#
- #room_invite_state_types:
- # - "{JoinRules}"
- # - "{CanonicalAlias}"
- # - "{RoomAvatar}"
- # - "{RoomEncryption}"
- # - "{Name}"
- """.format(
- **vars(EventTypes)
- )
+ room_prejoin_state:
+ # By default, the following state event types are shared with users who
+ # receive invites to the room:
+ #
+%(formatted_default_state_types)s
+ #
+ # Uncomment the following to disable these defaults (so that only the event
+ # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+ #
+ #disable_default_event_types: true
+
+ # Additional state event types to share with users when they are invited
+ # to a room.
+ #
+ # By default, this list is empty (so only the default event types are shared).
+ #
+ #additional_event_types:
+ # - org.example.custom.event.type
+ """ % {
+ "formatted_default_state_types": formatted_default_state_types
+ }
+
+ def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
+ """Get the event types to include in the prejoin state
+
+ Parses the config and returns an iterable of the event types to be included.
+ """
+ room_prejoin_state_config = config.get("room_prejoin_state") or {}
+
+ # backwards-compatibility support for room_invite_state_types
+ if "room_invite_state_types" in config:
+ # if both "room_invite_state_types" and "room_prejoin_state" are set, then
+ # we don't really know what to do.
+ if room_prejoin_state_config:
+ raise ConfigError(
+ "Can't specify both 'room_invite_state_types' and 'room_prejoin_state' "
+ "in config"
+ )
+
+ logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
+
+ yield from config["room_invite_state_types"]
+ return
+
+ if not room_prejoin_state_config.get("disable_default_event_types"):
+ yield from _DEFAULT_PREJOIN_STATE_TYPES
+
+ if self.spaces_enabled:
+ # MSC1772 suggests adding m.room.create to the prejoin state
+ yield EventTypes.Create
+
+ yield from room_prejoin_state_config.get("additional_event_types", [])
+
+
+_ROOM_INVITE_STATE_TYPES_WARNING = """\
+WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
+and replaced with 'room_prejoin_state'. New features may not work correctly
+unless 'room_invite_state_types' is removed. See the sample configuration file for
+details of 'room_prejoin_state'.
+--------------------------------------------------------------------------------
+"""
+
+_DEFAULT_PREJOIN_STATE_TYPES = [
+ EventTypes.JoinRules,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomAvatar,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+]
+
+
+# room_prejoin_state can either be None (as it is in the default config), or
+# an object containing other config settings
+_ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "disable_default_event_types": {"type": "boolean"},
+ "additional_event_types": {
+ "type": "array",
+ "items": {"type": "string"},
+ },
+ },
+ },
+ {"type": "null"},
+ ]
+}
+
+# the legacy room_invite_state_types setting
+_ROOM_INVITE_STATE_TYPES_SCHEMA = {"type": "array", "items": {"type": "string"}}
+
+_MAIN_SCHEMA = {
+ "type": "object",
+ "properties": {
+ "room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA,
+ "room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA,
+ },
+}
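
The net effect of `_get_prejoin_state_types` for the common cases can be modelled standalone. This sketch ignores the legacy `room_invite_state_types` branch and the MSC1772 `m.room.create` addition, and spells the default event types out as literal strings:

```python
DEFAULT_TYPES = [
    "m.room.join_rules",
    "m.room.canonical_alias",
    "m.room.avatar",
    "m.room.encryption",
    "m.room.name",
]


def prejoin_state(config: dict) -> list:
    section = config.get("room_prejoin_state") or {}
    types = []
    if not section.get("disable_default_event_types"):
        types.extend(DEFAULT_TYPES)
    types.extend(section.get("additional_event_types", []))
    return types


assert prejoin_state({}) == DEFAULT_TYPES
assert prejoin_state(
    {
        "room_prejoin_state": {
            "disable_default_event_types": True,
            "additional_event_types": ["org.example.custom.event.type"],
        }
    }
) == ["org.example.custom.event.type"]
```
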
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 001bddc6f6..a207de63b8 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -34,7 +34,11 @@ class ExperimentalConfig(Config):
# MSC2858 (multiple SSO identity providers)
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
- # Spaces (MSC1772, MSC2946, etc)
+
+ # Spaces (MSC1772, MSC2946, MSC3083, etc)
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
+ if self.spaces_enabled:
+ KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
+
# MSC3026 (busy presence state)
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
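
Gating the room version on config keeps `org.matrix.msc3083` out of the versions advertised to clients and other servers unless spaces are explicitly enabled. A sketch of the effect, assuming a Synapse checkout (it mutates `KNOWN_ROOM_VERSIONS` directly, exactly as `read_config` does above):

```python
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions

assert "org.matrix.msc3083" not in KNOWN_ROOM_VERSIONS

# What read_config does when `spaces_enabled` is true:
KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083

assert KNOWN_ROOM_VERSIONS["org.matrix.msc3083"].msc3083_join_rules
```
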
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 19322372a9..8fdaa59326 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict
+from typing import Dict, Optional
from ._base import Config
@@ -21,8 +21,10 @@ class RateLimitConfig:
def __init__(
self,
config: Dict[str, float],
- defaults={"per_second": 0.17, "burst_count": 3.0},
+ defaults: Optional[Dict[str, float]] = None,
):
+ defaults = defaults or {"per_second": 0.17, "burst_count": 3.0}
+
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index b49e6609ce..1f441107b3 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -243,9 +243,9 @@ class RegistrationConfig(Config):
#
#allowed_local_3pids:
# - medium: email
- # pattern: '.*@matrix\\.org'
+ # pattern: '^[^@]+@matrix\\.org$'
# - medium: email
- # pattern: '.*@vector\\.im'
+ # pattern: '^[^@]+@vector\\.im$'
# - medium: msisdn
# pattern: '\\+44'
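
The sample `allowed_local_3pids` patterns gained anchors because an unanchored pattern also admits addresses that merely contain a matching prefix (assuming, as the sample suggests, that the patterns are applied with `re.match`, which anchors only at the start). A quick check with made-up addresses:

```python
import re

loose = r".*@matrix\.org"
strict = r"^[^@]+@matrix\.org$"

evil = "attacker@matrix.org.evil.example"

assert re.match(loose, evil)  # the prefix matches, so it slips through
assert not re.match(strict, evil)  # the anchored pattern rejects it
assert re.match(strict, "alice@matrix.org")
```
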
diff --git a/synapse/config/server.py b/synapse/config/server.py
index c8b1a25004..6de884eaf7 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -27,6 +27,7 @@ import yaml
from netaddr import AddrFormatError, IPNetwork, IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_server_name
from ._base import Config, ConfigError
@@ -238,7 +239,20 @@ class ServerConfig(Config):
self.public_baseurl = config.get("public_baseurl")
# Whether to enable user presence.
- self.use_presence = config.get("use_presence", True)
+ presence_config = config.get("presence") or {}
+ self.use_presence = presence_config.get("enabled")
+ if self.use_presence is None:
+ self.use_presence = config.get("use_presence", True)
+
+ # Custom presence router module
+ self.presence_router_module_class = None
+ self.presence_router_config = None
+ presence_router_config = presence_config.get("presence_router")
+ if presence_router_config:
+ (
+ self.presence_router_module_class,
+ self.presence_router_config,
+ ) = load_module(presence_router_config, ("presence", "presence_router"))
# Whether to update the user directory or not. This should be set to
# false only if we are updating the user directory in a worker
@@ -840,9 +854,28 @@ class ServerConfig(Config):
#
#soft_file_limit: 0
- # Set to false to disable presence tracking on this homeserver.
+ # Presence tracking allows users to see the state (e.g. online/offline)
+ # of other local and remote users.
#
- #use_presence: false
+ presence:
+ # Uncomment to disable presence tracking on this homeserver. This option
+ # replaces the previous top-level 'use_presence' option.
+ #
+ #enabled: false
+
+ # Presence routers are third-party modules that can specify additional logic
+ # for where presence updates from users are routed.
+ #
+ presence_router:
+ # The custom module's class. Uncomment to use a custom presence router module.
+ #
+ #module: "my_custom_router.PresenceRouter"
+
+ # Configuration options of the custom module. Refer to your module's
+ # documentation for available options.
+ #
+ #config:
+ # example_option: 'something'
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index ad37b93c02..85b5db4c40 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -270,7 +270,7 @@ class TlsConfig(Config):
tls_certificate_path,
tls_private_key_path,
acme_domain,
- **kwargs
+ **kwargs,
):
"""If the acme_domain is specified acme will be enabled.
If the TLS paths are not specified the default will be certs in the
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4e20851d7f..b2d8ba7849 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -163,7 +163,7 @@ def check(
# 5. If type is m.room.membership
if event.type == EventTypes.Member:
- _is_membership_change_allowed(event, auth_events)
+ _is_membership_change_allowed(room_version_obj, event, auth_events)
logger.debug("Allowing! %s", event)
return
@@ -221,8 +221,19 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def _is_membership_change_allowed(
- event: EventBase, auth_events: StateMap[EventBase]
+ room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
) -> None:
+ """
+ Confirms that the event which changes membership is an allowed change.
+
+ Args:
+ room_version: The version of the room.
+ event: The event to check.
+ auth_events: The current auth events of the room.
+
+ Raises:
+ AuthError if the event is not allowed.
+ """
membership = event.content["membership"]
# Check if this is the room creator joining:
@@ -320,14 +331,19 @@ def _is_membership_change_allowed(
if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users")
elif Membership.JOIN == membership:
- # Joins are valid iff caller == target and they were:
- # invited: They are accepting the invitation
- # joined: It's a NOOP
+ # Joins are valid iff caller == target and:
+ # * They are not banned.
+ # * They are accepting a previously sent invitation.
+ # * They are already joined (it's a NOOP).
+ # * The room is public or restricted.
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
- elif join_rule == JoinRules.PUBLIC:
+ elif join_rule == JoinRules.PUBLIC or (
+ room_version.msc3083_join_rules
+ and join_rule == JoinRules.MSC3083_RESTRICTED
+ ):
pass
elif join_rule in (JoinRules.INVITE, JoinRules.KNOCK):
if not caller_in_room and not caller_invited:
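
Condensed, the membership-auth change makes a `restricted` join look like a `public` join at this layer when the room version opts in; the actual restriction is enforced where the join event is generated. A toy model of the branch (not the real `event_auth` code):

```python
def join_allowed(
    join_rule: str,
    msc3083_join_rules: bool,
    banned: bool,
    invited: bool,
    in_room: bool,
) -> bool:
    if banned:
        return False
    if join_rule == "public":
        return True
    if join_rule == "restricted" and msc3083_join_rules:
        # In room versions that opt in, "restricted" is treated like
        # "public" here; the restriction applies when the join event
        # is created instead.
        return True
    if join_rule in ("invite", "knock"):
        return invited or in_room
    return False


assert join_allowed("restricted", True, banned=False, invited=False, in_room=False)
assert not join_allowed("restricted", False, banned=False, invited=False, in_room=False)
```
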
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8f6b955d17..f9032e3697 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -330,9 +330,11 @@ class FrozenEvent(EventBase):
self,
event_dict: JsonDict,
room_version: RoomVersion,
- internal_metadata_dict: JsonDict = {},
+ internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
):
+ internal_metadata_dict = internal_metadata_dict or {}
+
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@@ -386,9 +388,11 @@ class FrozenEventV2(EventBase):
self,
event_dict: JsonDict,
room_version: RoomVersion,
- internal_metadata_dict: JsonDict = {},
+ internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
):
+ internal_metadata_dict = internal_metadata_dict or {}
+
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@@ -507,9 +511,11 @@ def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
def make_event_from_dict(
event_dict: JsonDict,
room_version: RoomVersion = RoomVersions.V1,
- internal_metadata_dict: JsonDict = {},
+ internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
) -> EventBase:
"""Construct an EventBase from the given event dict"""
event_type = _event_type_from_format_version(room_version.event_format)
- return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason)
+ return event_type(
+ event_dict, room_version, internal_metadata_dict or {}, rejected_reason
+ )
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
new file mode 100644
index 0000000000..24cd389d80
--- /dev/null
+++ b/synapse/events/presence_router.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
+
+from synapse.api.presence import UserPresenceState
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class PresenceRouter:
+ """
+ A module that the homeserver will call upon to help route user presence updates to
+ additional destinations. If a custom presence router is configured, calls will be
+ passed to that instead.
+ """
+
+ ALL_USERS = "ALL"
+
+ def __init__(self, hs: "HomeServer"):
+ self.custom_presence_router = None
+
+ # Check whether a custom presence router module has been configured
+ if hs.config.presence_router_module_class:
+ # Initialise the module
+ self.custom_presence_router = hs.config.presence_router_module_class(
+ config=hs.config.presence_router_config, module_api=hs.get_module_api()
+ )
+
+ # Ensure the module has implemented the required methods
+ required_methods = ["get_users_for_states", "get_interested_users"]
+ for method_name in required_methods:
+ if not hasattr(self.custom_presence_router, method_name):
+ raise Exception(
+ "PresenceRouter module '%s' must implement all required methods: %s"
+ % (
+ hs.config.presence_router_module_class.__name__,
+ ", ".join(required_methods),
+ )
+ )
+
+ async def get_users_for_states(
+ self,
+ state_updates: Iterable[UserPresenceState],
+ ) -> Dict[str, Set[UserPresenceState]]:
+ """
+ Given an iterable of user presence updates, determine where each one
+ needs to go.
+
+ Args:
+ state_updates: An iterable of user presence state updates.
+
+ Returns:
+ A dictionary of user_id -> set of UserPresenceState, indicating which
+ presence updates each user should receive.
+ """
+ if self.custom_presence_router is not None:
+ # Ask the custom module
+ return await self.custom_presence_router.get_users_for_states(
+ state_updates=state_updates
+ )
+
+ # Don't include any extra destinations for presence updates
+ return {}
+
+ async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
+ """
+ Retrieve a list of users that `user_id` is interested in receiving the
+ presence of. This will be in addition to those they share a room with.
+ Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
+ that this user should receive all incoming local and remote presence updates.
+
+ Note that this method will only be called for local users, but can return users
+ that are local or remote.
+
+ Args:
+ user_id: A user requesting presence updates.
+
+ Returns:
+ A set of user IDs to return presence updates for, or ALL_USERS to return all
+ known updates.
+ """
+ if self.custom_presence_router is not None:
+ # Ask the custom module for interested users
+ return await self.custom_presence_router.get_interested_users(
+ user_id=user_id
+ )
+
+ # A custom presence router is not defined.
+ # Don't report any additional interested users
+ return set()
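
For reference, a minimal custom module satisfying the interface the constructor checks for above; the class name, option names, and routing policy are all illustrative:

```python
from typing import Dict, Iterable, Set, Union

from synapse.api.presence import UserPresenceState


class ExamplePresenceRouter:
    """Routes every presence update to a configured set of watchers."""

    def __init__(self, config: dict, module_api):
        self._watchers = set(config.get("watchers", []))

    async def get_users_for_states(
        self, state_updates: Iterable[UserPresenceState]
    ) -> Dict[str, Set[UserPresenceState]]:
        updates = set(state_updates)
        return {watcher: set(updates) for watcher in self._watchers}

    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
        # Watchers receive everything; everyone else gets no extras.
        if user_id in self._watchers:
            return "ALL"  # i.e. PresenceRouter.ALL_USERS
        return set()
```

Such a module would be wired up via the new `presence.presence_router` block added to the sample config in `synapse/config/server.py` below, e.g. `module: "example_module.ExamplePresenceRouter"` with a `config` section providing the (hypothetical) `watchers` list.
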
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 184096d165..10244ee0d2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -104,7 +104,7 @@ class FederationClient(FederationBase):
max_len=1000,
expiry_ms=120 * 1000,
reset_expiry_on_get=False,
- )
+ ) # type: ExpiringCache[str, EventBase]
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index cb48cc5722..794c1138a9 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -815,22 +815,20 @@ class FederationServer(FederationBase):
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
- def __str__(self):
+ def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
- ):
- ret = await self.handler.exchange_third_party_invite(
+ ) -> None:
+ await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
- return ret
- async def on_exchange_third_party_invite_request(self, event_dict: Dict):
- ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
- return ret
+ async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
+ await self.handler.on_exchange_third_party_invite_request(event_dict)
- async def check_server_matches_acl(self, server_name: str, room_id: str):
+ async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
"""Check if the given server is allowed by the server ACLs in the room
Args:
@@ -946,6 +944,7 @@ class FederationHandlerRegistry:
# A rate limiter for incoming room key requests per origin.
self._room_key_request_rate_limiter = Ratelimiter(
+ store=hs.get_datastore(),
clock=self.clock,
rate_hz=self.config.rc_key_requests.per_second,
burst_count=self.config.rc_key_requests.burst_count,
@@ -953,7 +952,7 @@ class FederationHandlerRegistry:
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
- ):
+ ) -> None:
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
@@ -972,7 +971,7 @@ class FederationHandlerRegistry:
def register_query_handler(
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
- ):
+ ) -> None:
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@@ -990,15 +989,17 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler
- def register_instance_for_edu(self, edu_type: str, instance_name: str):
+ def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
"""Register that the EDU handler is on a different instance than master."""
self._edu_type_to_instance[edu_type] = [instance_name]
- def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+ def register_instances_for_edu(
+ self, edu_type: str, instance_names: List[str]
+ ) -> None:
"""Register that the EDU handler is on multiple instances."""
self._edu_type_to_instance[edu_type] = instance_names
- async def on_edu(self, edu_type: str, origin: str, content: dict):
+ async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.use_presence and edu_type == EduTypes.Presence:
return
@@ -1006,7 +1007,9 @@ class FederationHandlerRegistry:
# the limit, drop them.
if (
edu_type == EduTypes.RoomKeyRequest
- and not self._room_key_request_rate_limiter.can_do_action(origin)
+ and not await self._room_key_request_rate_limiter.can_do_action(
+ None, origin
+ )
):
return
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 8babb1ebbe..d821dcbf6a 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -44,6 +44,7 @@ from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure, measure_func
if TYPE_CHECKING:
+ from synapse.events.presence_router import PresenceRouter
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -162,6 +163,7 @@ class FederationSender(AbstractFederationSender):
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
+ self._presence_router = None # type: Optional[PresenceRouter]
self._transaction_manager = TransactionManager(hs)
self._instance_name = hs.get_instance_name()
@@ -584,7 +586,22 @@ class FederationSender(AbstractFederationSender):
"""Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination
"""
- hosts_and_states = await get_interested_remotes(self.store, states, self.state)
+ # We pull the presence router here instead of __init__
+ # to prevent a dependency cycle:
+ #
+ # AuthHandler -> Notifier -> FederationSender
+ # -> PresenceRouter -> ModuleApi -> AuthHandler
+ if self._presence_router is None:
+ self._presence_router = self.hs.get_presence_router()
+
+ assert self._presence_router is not None
+
+ hosts_and_states = await get_interested_remotes(
+ self.store,
+ self._presence_router,
+ states,
+ self.state,
+ )
for destinations, states in hosts_and_states:
for destination in destinations:
@@ -717,16 +734,18 @@ class FederationSender(AbstractFederationSender):
self._catchup_after_startup_timer = None
break
+ last_processed = destinations_to_wake[-1]
+
destinations_to_wake = [
d
for d in destinations_to_wake
if self._federation_shard_config.should_handle(self._instance_name, d)
]
- for last_processed in destinations_to_wake:
+ for destination in destinations_to_wake:
logger.info(
"Destination %s has outstanding catch-up, waking up.",
last_processed,
)
- self.wake_destination(last_processed)
+ self.wake_destination(destination)
await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)
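
The catch-up fix above is subtle: `last_processed`, the resume token for the next page fetched from the database, must be taken from the unfiltered page. A sharded sender may filter out every destination in a page; resuming from the filtered list would re-fetch the same page forever. A toy model of the paging loop (`should_handle` stands in for the shard check):

```python
def page_through(destinations, should_handle):
    token = None
    woken = []
    while True:
        page = [d for d in destinations if (token is None or d > token)][:2]
        if not page:
            return woken
        # Resume from the end of the *unfiltered* page: if the token only
        # advanced past destinations this shard handles, a page made up
        # entirely of other shards' destinations would never move it.
        token = page[-1]
        woken.extend(d for d in page if should_handle(d))


assert page_through(["a", "b", "c", "d"], lambda d: d in ("a", "d")) == ["a", "d"]
```
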
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 89df9a619b..e9c8a9f20a 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -29,6 +29,7 @@ from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state
+from synapse.logging.opentracing import SynapseTags, set_tag
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import ReadReceipt
@@ -557,6 +558,13 @@ class PerDestinationQueue:
contents, stream_id = await self._store.get_new_device_msgs_for_remote(
self._destination, last_device_stream_id, to_device_stream_id, limit
)
+ for content in contents:
+ message_id = content.get("message_id")
+ if not message_id:
+ continue
+
+ set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+
edus = [
Edu(
origin=self._server_name,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 294031e2a0..c616de5f22 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -427,13 +427,9 @@ class FederationSendServlet(BaseFederationServlet):
logger.exception(e)
return 400, {"error": "Invalid transaction"}
- try:
- code, response = await self.handler.on_incoming_transaction(
- origin, transaction_data
- )
- except Exception:
- logger.exception("on_incoming_transaction failed")
- raise
+ code, response = await self.handler.on_incoming_transaction(
+ origin, transaction_data
+ )
return code, response
@@ -650,8 +646,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
- content = await self.handler.on_exchange_third_party_invite_request(content)
- return 200, content
+ await self.handler.on_exchange_third_party_invite_request(content)
+ return 200, {}
class FederationClientKeysQueryServlet(BaseFederationServlet):
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index b662c42621..0f8bf000ac 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -18,6 +18,7 @@ server protocol.
"""
import logging
+from typing import Optional
import attr
@@ -98,7 +99,7 @@ class Transaction(JsonEncodedObject):
"pdus",
]
- def __init__(self, transaction_id=None, pdus=[], **kwargs):
+ def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
"""If we include a list of pdus then we decode then as PDU's
automatically.
"""
@@ -107,7 +108,7 @@ class Transaction(JsonEncodedObject):
if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"]
- super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)
+ super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
@staticmethod
def create_new(pdus, **kwargs):
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index aade2c4a3a..fb899aa90d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -49,7 +49,7 @@ class BaseHandler:
# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(
- clock=self.clock, rate_hz=0, burst_count=0
+ store=self.store, clock=self.clock, rate_hz=0, burst_count=0
)
self._rc_message = self.hs.config.rc_message
@@ -57,6 +57,7 @@ class BaseHandler:
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
@@ -91,11 +92,6 @@ class BaseHandler:
if app_service is not None:
return # do not ratelimit app service senders
- # Disable rate limiting of users belonging to any AS that is configured
- # not to be rate limited in its registration file (rate_limited: true|false).
- if requester.app_service and not requester.app_service.is_rate_limited():
- return
-
messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count
@@ -113,11 +109,11 @@ class BaseHandler:
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_ids clash
- self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
+ await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
else:
# Override rate and burst count per-user
- self.request_ratelimiter.ratelimit(
- user_id,
+ await self.request_ratelimiter.ratelimit(
+ requester,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 996f9e5deb..9fb7ee335d 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -182,7 +182,7 @@ class ApplicationServicesHandler:
self,
stream_key: str,
new_token: Optional[int],
- users: Collection[Union[str, UserID]] = [],
+ users: Optional[Collection[Union[str, UserID]]] = None,
):
"""This is called by the notifier in the background
when a ephemeral event handled by the homeserver.
@@ -215,7 +215,7 @@ class ApplicationServicesHandler:
# We only start a new background process if necessary rather than
# optimistically (to cut down on overhead).
self._notify_interested_services_ephemeral(
- services, stream_key, new_token, users
+ services, stream_key, new_token, users or []
)
@wrap_as_background_process("notify_interested_services_ephemeral")
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d537ea8137..08e413bc98 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -238,6 +238,7 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -248,6 +249,7 @@ class AuthHandler(BaseHandler):
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
+ store=self.store,
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -352,7 +354,7 @@ class AuthHandler(BaseHandler):
requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
- self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
+ await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)
# build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -373,7 +375,9 @@ class AuthHandler(BaseHandler):
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
- self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
+ await self._failed_uia_attempts_ratelimiter.can_do_action(
+ requester,
+ )
raise
# find the completed login type
@@ -982,8 +986,8 @@ class AuthHandler(BaseHandler):
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
if ratelimit:
- self._failed_login_attempts_ratelimiter.ratelimit(
- (medium, address), update=False
+ await self._failed_login_attempts_ratelimiter.ratelimit(
+ None, (medium, address), update=False
)
# Check for login providers that support 3pid login types
@@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler):
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
if ratelimit:
- self._failed_login_attempts_ratelimiter.can_do_action(
- (medium, address)
+ await self._failed_login_attempts_ratelimiter.can_do_action(
+ None, (medium, address)
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler):
# Check if we've hit the failed ratelimit (but don't update it)
if ratelimit:
- self._failed_login_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), update=False
+ await self._failed_login_attempts_ratelimiter.ratelimit(
+ None, qualified_user_id.lower(), update=False
)
try:
@@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler):
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
if ratelimit:
- self._failed_login_attempts_ratelimiter.can_do_action(
- qualified_user_id.lower()
+ await self._failed_login_attempts_ratelimiter.can_do_action(
+ None, qualified_user_id.lower()
)
raise
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 54293d0b9c..7e76db3e2a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -631,7 +631,7 @@ class DeviceListUpdater:
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
- )
+ ) # type: ExpiringCache[str, Set[str]]
# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
@@ -760,7 +760,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
- seen_updates = self._seen_updates.get(user_id, set())
+ seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index eb547743be..c971eeb4d2 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -21,10 +21,10 @@ from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
+ SynapseTags,
get_active_span_text_map,
log_kv,
set_tag,
- start_active_span,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
@@ -81,6 +81,7 @@ class DeviceMessageHandler:
)
self._ratelimiter = Ratelimiter(
+ store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.rc_key_requests.per_second,
burst_count=hs.config.rc_key_requests.burst_count,
@@ -182,7 +183,10 @@ class DeviceMessageHandler:
) -> None:
sender_user_id = requester.user.to_string()
- set_tag("number_of_messages", len(messages))
+ message_id = random_string(16)
+ set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+
+ log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
@@ -191,8 +195,8 @@ class DeviceMessageHandler:
if (
message_type == EduTypes.RoomKeyRequest
and user_id != sender_user_id
- and self._ratelimiter.can_do_action(
- (sender_user_id, requester.device_id)
+ and await self._ratelimiter.can_do_action(
+ requester, (sender_user_id, requester.device_id)
)
):
continue
@@ -204,32 +208,35 @@ class DeviceMessageHandler:
"content": message_content,
"type": message_type,
"sender": sender_user_id,
+ "message_id": message_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
+ log_kv(
+ {
+ "user_id": user_id,
+ "device_id": list(messages_by_device),
+ }
+ )
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
- message_id = random_string(16)
-
context = get_active_span_text_map()
remote_edu_contents = {}
for destination, messages in remote_messages.items():
- with start_active_span("to_device_for_user"):
- set_tag("destination", destination)
- remote_edu_contents[destination] = {
- "messages": messages,
- "sender": sender_user_id,
- "type": message_type,
- "message_id": message_id,
- "org.matrix.opentracing_context": json_encoder.encode(context),
- }
+ log_kv({"destination": destination})
+ remote_edu_contents[destination] = {
+ "messages": messages,
+ "sender": sender_user_id,
+ "type": message_type,
+ "message_id": message_id,
+ "org.matrix.opentracing_context": json_encoder.encode(context),
+ }
- log_kv({"local_messages": local_messages})
stream_id = await self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
@@ -238,7 +245,6 @@ class DeviceMessageHandler:
"to_device_key", stream_id, users=local_messages.keys()
)
- log_kv({"remote_messages": remote_messages})
if self.federation_sender:
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
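
The handler fans each batch of to-device messages out into local deliveries (written straight to the device inbox) and one EDU per remote destination. A simplified standalone sketch of the split, with domain parsing reduced to a string split:

```python
def split_messages(messages: dict, local_domain: str = "example.org"):
    local, remote = {}, {}
    for user_id, by_device in messages.items():
        domain = user_id.split(":", 1)[1]
        if domain == local_domain:
            local[user_id] = by_device
        else:
            remote.setdefault(domain, {})[user_id] = by_device
    return local, remote


local, remote = split_messages(
    {
        "@alice:example.org": {"DEV1": {"body": "hi"}},
        "@bob:other.example": {"DEV2": {"body": "hi"}},
    }
)
assert "@alice:example.org" in local
assert remote == {"other.example": {"@bob:other.example": {"DEV2": {"body": "hi"}}}}
```
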
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2ad9b6d930..92b18378fc 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -38,7 +38,6 @@ from synapse.types import (
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@@ -1008,7 +1007,7 @@ class E2eKeysHandler:
return signature_list, failures
async def _get_e2e_cross_signing_verify_key(
- self, user_id: str, key_type: str, from_user_id: str = None
+ self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
@@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
- # Recently seen stream ids. We don't bother keeping these in the DB,
- # but they're useful to have them about to reduce the number of spurious
- # resyncs.
- self._seen_updates = ExpiringCache(
- cache_name="signing_key_update_edu",
- clock=self.clock,
- max_len=10000,
- expiry_ms=30 * 60 * 1000,
- iterable=True,
- )
-
async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
) -> None:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ec2ce679c2..97e347c30b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -22,7 +22,17 @@ import itertools
import logging
from collections.abc import Container
from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
import attr
from signedjson.key import decode_verify_key_bytes
@@ -172,15 +182,17 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
+ async def on_receive_pdu(
+ self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
+ ) -> None:
"""Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
Args:
- origin (str): server which initiated the /send/ transaction. Will
+ origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state.
- pdu (FrozenEvent): received PDU
- sent_to_us_directly (bool): True if this event was pushed to us; False if
+ pdu: received PDU
+ sent_to_us_directly: True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
"""
@@ -413,13 +425,15 @@ class FederationHandler(BaseHandler):
await self._process_received_pdu(origin, pdu, state=state)
- async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(
+ self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+ ) -> None:
"""
Args:
- origin (str): Origin of the pdu. Will be called to get the missing events
+ origin: Origin of the pdu. Will be called to get the missing events
pdu: received pdu
- prevs (set(str)): List of event ids which we are missing
- min_depth (int): Minimum depth of events to return.
+ prevs: List of event ids which we are missing
+ min_depth: Minimum depth of events to return.
"""
room_id = pdu.room_id
@@ -780,7 +794,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
- ):
+ ) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@@ -889,7 +903,9 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to resync device for %s", sender)
@log_function
- async def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(
+ self, dest: str, room_id: str, limit: int, extremities: List[str]
+ ) -> List[EventBase]:
"""Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -1144,16 +1160,15 @@ class FederationHandler(BaseHandler):
curr_state = await self.state_handler.get_current_state(room_id)
- def get_domains_from_state(state):
+ def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state
Args:
- state (dict[tuple, FrozenEvent]): State map from type/state
- key to event.
+ state: State map from type/state key to event.
Returns:
- list[tuple[str, int]]: Returns a list of servers with the
- lowest depth of their joins. Sorted by lowest depth first.
+ Returns a list of servers with the lowest depth of their joins.
+ Sorted by lowest depth first.
"""
joined_users = [
(state_key, int(event.depth))
@@ -1181,7 +1196,7 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- async def try_backfill(domains):
+ async def try_backfill(domains: List[str]) -> bool:
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
@@ -1260,21 +1275,25 @@ class FederationHandler(BaseHandler):
}
for e_id, _ in sorted_extremeties_tuple:
- likely_domains = get_domains_from_state(states[e_id])
+ likely_extremities_domains = get_domains_from_state(states[e_id])
success = await try_backfill(
- [dom for dom, _ in likely_domains if dom not in tried_domains]
+ [
+ dom
+ for dom, _ in likely_extremities_domains
+ if dom not in tried_domains
+ ]
)
if success:
return True
- tried_domains.update(dom for dom, _ in likely_domains)
+ tried_domains.update(dom for dom, _ in likely_extremities_domains)
return False
async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str]
- ):
+ ) -> None:
"""Fetch the given events from a server, and persist them as outliers.
This function *does not* recursively get missing auth events of the
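The backfill flow above picks candidate servers from current room state, preferring domains whose members joined earliest (lowest join depth), and walks them in order until one succeeds. A minimal, self-contained sketch of that ordering step, using plain user-ID/depth pairs rather than Synapse's EventBase (the helper name and input shape here are illustrative):

```python
from typing import Dict, List, Tuple

def domains_by_join_depth(join_depths: Dict[str, int]) -> List[Tuple[str, int]]:
    """Collapse user_id -> join depth into (domain, lowest depth) pairs,
    sorted so the longest-established domains are tried first."""
    lowest: Dict[str, int] = {}
    for user_id, depth in join_depths.items():
        # "@alice:one.example" -> "one.example"
        domain = user_id.split(":", 1)[1]
        lowest[domain] = min(depth, lowest.get(domain, depth))
    return sorted(lowest.items(), key=lambda pair: pair[1])

assert domains_by_join_depth(
    {"@alice:one.example": 3, "@bob:two.example": 1, "@carol:one.example": 7}
) == [("two.example", 1), ("one.example", 3)]
```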
@@ -1350,7 +1369,7 @@ class FederationHandler(BaseHandler):
event_infos,
)
- def _sanity_check_event(self, ev):
+ def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
@@ -1359,9 +1378,7 @@ class FederationHandler(BaseHandler):
or cascade of event fetches.
Args:
- ev (synapse.events.EventBase): event to be checked
-
- Returns: None
+ ev: event to be checked
Raises:
SynapseError if the event does not pass muster
@@ -1382,7 +1399,7 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
- async def send_invite(self, target_host, event):
+ async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
@@ -1601,12 +1618,13 @@ class FederationHandler(BaseHandler):
)
return event.event_id, stream_id
- async def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(
+ self, room_queue: List[Tuple[EventBase, str]]
+ ) -> None:
"""Process PDUs which got queued up while we were busy send_joining.
Args:
- room_queue (list[FrozenEvent, str]): list of PDUs to be processed
- and the servers that sent them
+ room_queue: list of PDUs to be processed and the servers that sent them
"""
for p, origin in room_queue:
try:
@@ -1685,7 +1703,7 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_join_request(self, origin, pdu):
+ async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
"""We have received a join event for a room. Fully process it and
respond with the current state and auth chains.
"""
@@ -1741,7 +1759,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
- ):
+ ) -> EventBase:
"""We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event.
@@ -1791,7 +1809,7 @@ class FederationHandler(BaseHandler):
member_handler = self.hs.get_room_member_handler()
# We don't rate limit based on room ID, as that should be done by
# sending server.
- member_handler.ratelimit_invite(None, event.state_key)
+ await member_handler.ratelimit_invite(None, None, event.state_key)
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
@@ -1852,7 +1870,7 @@ class FederationHandler(BaseHandler):
room_id: str,
user_id: str,
membership: str,
- content: JsonDict = {},
+ content: JsonDict,
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]:
(
@@ -1921,7 +1939,7 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_leave_request(self, origin, pdu):
+ async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
""" We have received a leave event for a room. Fully process it."""
event = pdu
@@ -2163,12 +2181,17 @@ class FederationHandler(BaseHandler):
else:
return None
- async def get_min_depth_for_context(self, context):
+ async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context)
async def _handle_new_event(
- self, origin, event, state=None, auth_events=None, backfilled=False
- ):
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]] = None,
+ auth_events: Optional[MutableStateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> EventContext:
context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
@@ -2474,40 +2497,14 @@ class FederationHandler(BaseHandler):
logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
- async def on_query_auth(
- self, origin, event_id, room_id, remote_auth_chain, rejects, missing
- ):
- in_room = await self.auth.check_host_in_room(room_id, origin)
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
- event = await self.store.get_event(event_id, check_room_id=room_id)
-
- # Just go through and process each event in `remote_auth_chain`. We
- # don't want to fall into the trap of `missing` being wrong.
- for e in remote_auth_chain:
- try:
- await self._handle_new_event(origin, e)
- except AuthError:
- pass
-
- # Now get the current auth_chain for the event.
- local_auth_chain = await self.store.get_auth_chain(
- room_id, list(event.auth_event_ids()), include_given=True
- )
-
- # TODO: Check if we would now reject event_id. If so we need to tell
- # everyone.
-
- ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
-
- logger.debug("on_query_auth returning: %s", ret)
-
- return ret
-
async def on_get_missing_events(
- self, origin, room_id, earliest_events, latest_events, limit
- ):
+ self,
+ origin: str,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2811,8 +2808,8 @@ class FederationHandler(BaseHandler):
assumes that we have already processed all events in remote_auth
Params:
- local_auth (list)
- remote_auth (list)
+ local_auth
+ remote_auth
Returns:
dict
@@ -2936,8 +2933,8 @@ class FederationHandler(BaseHandler):
@log_function
async def exchange_third_party_invite(
- self, sender_user_id, target_user_id, room_id, signed
- ):
+ self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
+ ) -> None:
third_party_invite = {"signed": signed}
event_dict = {
@@ -3029,8 +3026,12 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context)
async def add_display_name_to_third_party_invite(
- self, room_version, event_dict, event, context
- ):
+ self,
+ room_version: str,
+ event_dict: JsonDict,
+ event: EventBase,
+ context: EventContext,
+ ) -> Tuple[EventBase, EventContext]:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
@@ -3066,13 +3067,13 @@ class FederationHandler(BaseHandler):
EventValidator().validate_new(event, self.config)
return (event, context)
- async def _check_signature(self, event, context):
+ async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
Checks that the signature in the event is consistent with its invite.
Args:
- event (Event): The m.room.member event to check
- context (EventContext):
+ event: The m.room.member event to check
+ context: The event context.
Raises:
AuthError: if signature didn't match any keys, or key has been
@@ -3158,13 +3159,13 @@ class FederationHandler(BaseHandler):
raise last_exception
- async def _check_key_revocation(self, public_key, url):
+ async def _check_key_revocation(self, public_key: str, url: str) -> None:
"""
Checks whether public_key has been revoked.
Args:
- public_key (str): base-64 encoded public key.
- url (str): Key revocation URL.
+ public_key: base-64 encoded public key.
+ url: Key revocation URL.
Raises:
AuthError: if the key has been revoked.
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index f4c1265b43..88100bc7da 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -64,17 +64,19 @@ class IdentityHandler(BaseHandler):
# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
+ store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
+ store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
- def ratelimit_request_token_requests(
+ async def ratelimit_request_token_requests(
self,
request: SynapseRequest,
medium: str,
@@ -88,8 +90,12 @@ class IdentityHandler(BaseHandler):
address: The actual threepid ID, e.g. the phone number or email address
"""
- self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
- self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+ await self._3pid_validation_ratelimiter_ip.ratelimit(
+ None, (medium, request.getClientIP())
+ )
+ await self._3pid_validation_ratelimiter_address.ratelimit(
+ None, (medium, address)
+ )
async def threepid_from_creds(
self, id_server_url: str, creds: Dict[str, str]
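Several call sites in this diff move from a synchronous `ratelimit(key)` to an awaitable `ratelimit(requester, key)` on a `Ratelimiter` built with `store`, `clock`, `rate_hz`, and `burst_count`. A toy token-bucket sketch of that calling shape follows; the internals are a simplified stand-in, not Synapse's implementation (the real limiter also consults the store so that application services can be exempted):

```python
import time
from typing import Dict, Hashable, Optional, Tuple

class MiniRatelimiter:
    """Leaky-bucket limiter mirroring the awaitable API used in this diff."""

    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz          # how quickly "used" actions drain away
        self.burst_count = burst_count  # how many actions may burst at once
        self._buckets: Dict[Hashable, Tuple[float, float]] = {}

    async def ratelimit(self, requester: Optional[object], key: Hashable) -> None:
        allowed, retry_after = self._can_do_action(key)
        if not allowed:
            raise RuntimeError("Too many requests; retry in %.1fs" % retry_after)

    def _can_do_action(self, key: Hashable) -> Tuple[bool, float]:
        now = time.monotonic()
        used, last = self._buckets.get(key, (0.0, now))
        used = max(0.0, used - (now - last) * self.rate_hz)  # drain over time
        if used >= self.burst_count:
            return False, (used - self.burst_count + 1) / self.rate_hz
        self._buckets[key] = (used + 1.0, now)
        return True, 0.0
```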
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 67a8410276..bfff43c802 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -139,7 +139,7 @@ class MessageHandler:
self,
user_id: str,
room_id: str,
- state_filter: StateFilter = StateFilter.all(),
+ state_filter: Optional[StateFilter] = None,
at_token: Optional[StreamToken] = None,
is_guest: bool = False,
) -> List[dict]:
@@ -166,6 +166,8 @@ class MessageHandler:
AuthError (403) if the user doesn't have permission to view
members of this room.
"""
+ state_filter = state_filter or StateFilter.all()
+
if at_token:
# FIXME this claims to get the state at a stream position, but
# get_recent_events_for_room operates by topo ordering. This therefore
@@ -387,7 +389,7 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
- self.room_invite_state_types = self.hs.config.room_invite_state_types
+ self.room_invite_state_types = self.hs.config.api.room_prejoin_state
self.membership_types_to_include_profile_data_in = (
{Membership.JOIN, Membership.INVITE, Membership.KNOCK}
@@ -876,7 +878,7 @@ class EventCreationHandler:
event: EventBase,
context: EventContext,
ratelimit: bool = True,
- extra_users: List[UserID] = [],
+ extra_users: Optional[List[UserID]] = None,
ignore_shadow_ban: bool = False,
) -> EventBase:
"""Processes a new event.
@@ -904,6 +906,7 @@ class EventCreationHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
+ extra_users = extra_users or []
# we don't apply shadow-banning to membership events here. Invites are blocked
# higher up the stack, and we allow shadow-banned users to send join and leave
@@ -1073,7 +1076,7 @@ class EventCreationHandler:
event: EventBase,
context: EventContext,
ratelimit: bool = True,
- extra_users: List[UserID] = [],
+ extra_users: Optional[List[UserID]] = None,
) -> EventBase:
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@@ -1085,6 +1088,8 @@ class EventCreationHandler:
it was de-duplicated (e.g. because we had already persisted an
event with the same transaction ID.)
"""
+ extra_users = extra_users or []
+
assert self.storage.persistence is not None
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
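A recurring change in this diff replaces mutable (or call-time) defaults such as `extra_users: List[UserID] = []`, `content: JsonDict = {}`, and `state_filter: StateFilter = StateFilter.all()` with `Optional[...] = None` plus an `x = x or ...` line in the body. The reason is that Python evaluates default values once, at function definition time, so a mutable default is shared by every call:

```python
from typing import List, Optional

def broken(extra_users: List[str] = []) -> List[str]:
    # This [] is created once and shared across every call.
    extra_users.append("@alice:example.org")
    return extra_users

def fixed(extra_users: Optional[List[str]] = None) -> List[str]:
    extra_users = extra_users or []  # fresh list on each call
    extra_users.append("@alice:example.org")
    return extra_users

broken()
assert len(broken()) == 2  # the shared default keeps growing
fixed()
assert len(fixed()) == 1   # each call starts clean
```

`StateFilter.all()` is immutable, so it is not dangerous in the same way, but normalising every default to `None` keeps signatures consistent and avoids work at definition time.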
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index da92feacc9..0047907cd9 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,17 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
from prometheus_client import Counter
from typing_extensions import ContextManager
@@ -34,6 +44,7 @@ import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
+from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
@@ -42,7 +53,7 @@ from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler):
self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler()
+ self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence
federation_registry = hs.get_federation_registry()
@@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler):
"""
stream_id, max_token = await self.store.update_presence(states)
- parties = await get_interested_parties(self.store, states)
+ parties = await get_interested_parties(self.store, self.presence_router, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@@ -1041,7 +1053,12 @@ class PresenceEventSource:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
#
+ # Same with get_module_api, get_presence_router
+ #
+ # AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
self.get_presence_handler = hs.get_presence_handler
+ self.get_module_api = hs.get_module_api
+ self.get_presence_router = hs.get_presence_router
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -1054,8 +1071,8 @@ class PresenceEventSource:
room_ids=None,
include_offline=True,
explicit_room_id=None,
- **kwargs
- ):
+ **kwargs,
+ ) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are:
# 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms.
@@ -1068,7 +1085,17 @@ class PresenceEventSource:
# We don't try and limit the presence updates by the current token, as
# sending down the rare duplicate is not a concern.
+ user_id = user.to_string()
+ stream_change_cache = self.store.presence_stream_cache
+
with Measure(self.clock, "presence.get_new_events"):
+ if user_id in self.get_module_api()._send_full_presence_to_local_users:
+ # This user has been specified by a module to receive all current, online
+ # user presence. Clearing from_key and setting include_offline to False
+ # effectively does this.
+ from_key = None
+ include_offline = False
+
if from_key is not None:
from_key = int(from_key)
@@ -1091,59 +1118,209 @@ class PresenceEventSource:
# doesn't return. C.f. #5503.
return [], max_token
- presence = self.get_presence_handler()
- stream_change_cache = self.store.presence_stream_cache
-
+ # Figure out which other users this user should receive updates for
users_interested_in = await self._get_interested_in(user, explicit_room_id)
- user_ids_changed = set() # type: Collection[str]
- changed = None
- if from_key:
- changed = stream_change_cache.get_all_entities_changed(from_key)
+ # We have a set of users that we're interested in the presence of. We want to
+ # cross-reference that with the users that have actually changed their presence.
- if changed is not None and len(changed) < 500:
- assert isinstance(user_ids_changed, set)
+ # Check whether this user should see all user updates
- # For small deltas, its quicker to get all changes and then
- # work out if we share a room or they're in our presence list
- get_updates_counter.labels("stream").inc()
- for other_user_id in changed:
- if other_user_id in users_interested_in:
- user_ids_changed.add(other_user_id)
- else:
- # Too many possible updates. Find all users we can see and check
- # if any of them have changed.
- get_updates_counter.labels("full").inc()
+ if users_interested_in == PresenceRouter.ALL_USERS:
+ # Provide presence state for all users
+ presence_updates = await self._filter_all_presence_updates_for_user(
+ user_id, include_offline, from_key
+ )
- if from_key:
- user_ids_changed = stream_change_cache.get_entities_changed(
- users_interested_in, from_key
+ # Remove the user from the list of users to receive all presence
+ if user_id in self.get_module_api()._send_full_presence_to_local_users:
+ self.get_module_api()._send_full_presence_to_local_users.remove(
+ user_id
)
+
+ return presence_updates, max_token
+
+ # Make mypy happy. users_interested_in should now be a set
+ assert not isinstance(users_interested_in, str)
+
+ # The set of users that we're interested in and that have had a presence update.
+ # We'll actually pull the presence updates for these users at the end.
+ interested_and_updated_users = (
+ set()
+ ) # type: Union[Set[str], FrozenSet[str]]
+
+ if from_key:
+ # First get all users that have had a presence update
+ updated_users = stream_change_cache.get_all_entities_changed(from_key)
+
+ # Cross-reference users we're interested in with those that have had updates.
+ # Use a slightly-optimised method for processing smaller sets of updates.
+ if updated_users is not None and len(updated_users) < 500:
+ # For small deltas, it's quicker to get all changes and then
+ # cross-reference with the users we're interested in
+ get_updates_counter.labels("stream").inc()
+ for other_user_id in updated_users:
+ if other_user_id in users_interested_in:
+ # mypy thinks this variable could be a FrozenSet as it's possibly set
+ # to one in the `get_entities_changed` call below, and `add()` is not a
+ # method on a FrozenSet. That doesn't affect us here though, as
+ # `interested_and_updated_users` is clearly a set() above.
+ interested_and_updated_users.add(other_user_id) # type: ignore
else:
- user_ids_changed = users_interested_in
+ # Too many possible updates. Find all users we can see and check
+ # if any of them have changed.
+ get_updates_counter.labels("full").inc()
- updates = await presence.current_state_for_users(user_ids_changed)
+ interested_and_updated_users = (
+ stream_change_cache.get_entities_changed(
+ users_interested_in, from_key
+ )
+ )
+ else:
+ # No from_key has been specified. Return the presence for all users
+ # this user is interested in
+ interested_and_updated_users = users_interested_in
+
+ # Retrieve the current presence state for each user
+ users_to_state = await self.get_presence_handler().current_state_for_users(
+ interested_and_updated_users
+ )
+ presence_updates = list(users_to_state.values())
- if include_offline:
- return (list(updates.values()), max_token)
+ # Remove the user from the list of users to receive all presence
+ if user_id in self.get_module_api()._send_full_presence_to_local_users:
+ self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
+
+ if not include_offline:
+ # Filter out offline presence states
+ presence_updates = self._filter_offline_presence_state(presence_updates)
+
+ return presence_updates, max_token
+
+ async def _filter_all_presence_updates_for_user(
+ self,
+ user_id: str,
+ include_offline: bool,
+ from_key: Optional[int] = None,
+ ) -> List[UserPresenceState]:
+ """
+ Computes the presence updates a user should receive.
+
+ First pulls presence updates from the database. Then consults PresenceRouter
+ for whether any updates should be excluded by user ID.
+
+ Args:
+ user_id: The User ID of the user to compute presence updates for.
+ include_offline: Whether to include offline presence states in the results.
+ from_key: The minimum stream ID of updates to pull from the database
+ before filtering.
+
+ Returns:
+ A list of presence states for the given user to receive.
+ """
+ if from_key:
+ # Only return updates since the last sync
+ updated_users = self.store.presence_stream_cache.get_all_entities_changed(
+ from_key
+ )
+ if not updated_users:
+ updated_users = []
+
+ # Get the actual presence update for each change
+ users_to_state = await self.get_presence_handler().current_state_for_users(
+ updated_users
+ )
+ presence_updates = list(users_to_state.values())
+
+ if not include_offline:
+ # Filter out offline states
+ presence_updates = self._filter_offline_presence_state(presence_updates)
else:
- return (
- [s for s in updates.values() if s.state != PresenceState.OFFLINE],
- max_token,
+ users_to_state = await self.store.get_presence_for_all_users(
+ include_offline=include_offline
)
+ presence_updates = list(users_to_state.values())
+
+ # TODO: This feels wildly inefficient, and it's unfortunate we need to ask the
+ # module for information on a number of users when we then only take the info
+ # for a single user
+
+ # Filter through the presence router
+ users_to_state_set = await self.get_presence_router().get_users_for_states(
+ presence_updates
+ )
+
+ # We only want the mapping for the syncing user
+ presence_updates = list(users_to_state_set[user_id])
+
+ # Return presence information for all users
+ return presence_updates
+
+ def _filter_offline_presence_state(
+ self, presence_updates: Iterable[UserPresenceState]
+ ) -> List[UserPresenceState]:
+ """Given an iterable containing user presence updates, return a list with any offline
+ presence states removed.
+
+ Args:
+ presence_updates: Presence states to filter
+
+ Returns:
+ A new list with any offline presence states removed.
+ """
+ return [
+ update
+ for update in presence_updates
+ if update.state != PresenceState.OFFLINE
+ ]
+
def get_current_key(self):
return self.store.get_current_presence_token()
@cached(num_args=2, cache_context=True)
- async def _get_interested_in(self, user, explicit_room_id, cache_context):
+ async def _get_interested_in(
+ self,
+ user: UserID,
+ explicit_room_id: Optional[str] = None,
+ cache_context: Optional[_CacheContext] = None,
+ ) -> Union[Set[str], str]:
"""Returns the set of users that the given user should see presence
- updates for
+ updates for.
+
+ Args:
+ user: The user to retrieve presence updates for.
+ explicit_room_id: If provided, users in this room are also included.
+
+ Returns:
+ A set of user IDs to return presence updates for, or "ALL" to return all
+ known updates.
"""
user_id = user.to_string()
users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
+ # cache_context isn't likely to ever be None due to the @cached decorator,
+ # but we can't have a non-optional argument after the optional argument
+ # explicit_room_id either. Assert cache_context is not None so we can use it
+ # without mypy complaining.
+ assert cache_context
+
+ # Check with the presence router whether we should poll additional users for
+ # their presence information
+ additional_users = await self.get_presence_router().get_interested_users(
+ user.to_string()
+ )
+ if additional_users == PresenceRouter.ALL_USERS:
+ # If the module requested that this user see the presence updates of *all*
+ # users, then simply return that instead of calculating what rooms this
+ # user shares
+ return PresenceRouter.ALL_USERS
+
+ # Add the additional users from the router
+ users_interested_in.update(additional_users)
+
+ # Find the users who share a room with this user
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate
)
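`_get_interested_in` and `get_interested_parties` now consult a `PresenceRouter` loaded from a module, which can add recipients per state update (`get_users_for_states`) and extra subscriptions per user (`get_interested_users`, possibly returning the `ALL_USERS` sentinel). A minimal sketch of such a module; the method names mirror the hooks called above, but treat the base-class contract in `synapse/events/presence_router.py` as the authority:

```python
from typing import Dict, Iterable, Set, Union

ALL_USERS = "ALL"  # stand-in for PresenceRouter.ALL_USERS

class MonitorPresenceRouter:
    """Route every presence update to a monitoring user, and let that
    user subscribe to everyone."""

    def __init__(self, monitor: str = "@monitor:example.org"):
        self._monitor = monitor

    async def get_users_for_states(
        self, state_updates: Iterable[object]
    ) -> Dict[str, Set[object]]:
        # Extra recipients for these updates, on top of normal routing.
        return {self._monitor: set(state_updates)}

    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
        # The monitor sees all users; everyone else gets no extras.
        return ALL_USERS if user_id == self._monitor else set()
```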
@@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
async def get_interested_parties(
- store: DataStore, states: List[UserPresenceState]
+ store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState]
) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
"""Given a list of states return which entities (rooms, users)
are interested in the given states.
Args:
- store
- states
+ store: The homeserver's data store.
+ presence_router: A module for augmenting the destinations for presence updates.
+ states: A list of incoming user presence updates.
Returns:
A 2-tuple of `(room_ids_to_states, users_to_states)`,
@@ -1337,11 +1515,22 @@ async def get_interested_parties(
# Always notify self
users_to_states.setdefault(state.user_id, []).append(state)
+ # Ask a presence routing module for any additional parties if one
+ # is loaded.
+ router_users_to_states = await presence_router.get_users_for_states(states)
+
+ # Update the dictionaries with additional destinations and state to send
+ for user_id, user_states in router_users_to_states.items():
+ users_to_states.setdefault(user_id, []).extend(user_states)
+
return room_ids_to_states, users_to_states
async def get_interested_remotes(
- store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
+ store: DataStore,
+ presence_router: PresenceRouter,
+ states: List[UserPresenceState],
+ state_handler: StateHandler,
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -1349,9 +1538,10 @@ async def get_interested_remotes(
All the presence states should be for local users only.
Args:
- store
- states
- state_handler
+ store: The homeserver's data store.
+ presence_router: A module for augmenting the destinations for presence updates.
+ states: A list of incoming user presence updates.
+ state_handler:
Returns:
A list of 2-tuples of destinations and states, where for
@@ -1363,7 +1553,9 @@ async def get_interested_remotes(
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
# hosts in those rooms.
- room_ids_to_states, users_to_states = await get_interested_parties(store, states)
+ room_ids_to_states, users_to_states = await get_interested_parties(
+ store, presence_router, states
+ )
for room_id, states in room_ids_to_states.items():
hosts = await state_handler.get_current_hosts_in_room(room_id)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 1b2a515ee9..8d4a4612a3 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -184,7 +184,7 @@ class RegistrationHandler(BaseHandler):
user_type: Optional[str] = None,
default_display_name: Optional[str] = None,
address: Optional[str] = None,
- bind_emails: Iterable[str] = [],
+ bind_emails: Optional[Iterable[str]] = None,
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None,
@@ -219,7 +219,9 @@ class RegistrationHandler(BaseHandler):
Raises:
SynapseError if there was a problem registering.
"""
- self.check_registration_ratelimit(address)
+ bind_emails = bind_emails or []
+
+ await self.check_registration_ratelimit(address)
result = await self.spam_checker.check_registration_for_spam(
threepid,
@@ -663,7 +665,7 @@ class RegistrationHandler(BaseHandler):
},
)
- def check_registration_ratelimit(self, address: Optional[str]) -> None:
+ async def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
@@ -677,7 +679,7 @@ class RegistrationHandler(BaseHandler):
if not address:
return
- self.ratelimiter.ratelimit(address)
+ await self.ratelimiter.ratelimit(None, address)
async def register_with_store(
self,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 10af3782f4..89e95c8ae9 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -20,7 +20,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -29,6 +29,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import (
@@ -83,22 +84,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._join_rate_limiter_local = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
)
self._join_rate_limiter_remote = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
self._invites_per_room_limiter = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
)
self._invites_per_user_limiter = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
@@ -206,15 +211,76 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async def forget(self, user: UserID, room_id: str) -> None:
raise NotImplementedError()
- def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
+ async def ratelimit_invite(
+ self,
+ requester: Optional[Requester],
+ room_id: Optional[str],
+ invitee_user_id: str,
+ ):
"""Ratelimit invites by room and by target user.
If room ID is missing then we just rate limit by target user.
"""
if room_id:
- self._invites_per_room_limiter.ratelimit(room_id)
+ await self._invites_per_room_limiter.ratelimit(requester, room_id)
- self._invites_per_user_limiter.ratelimit(invitee_user_id)
+ await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
+
+ async def _can_join_without_invite(
+ self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
+ ) -> bool:
+ """
+ Check whether a user can join a room without an invite.
+
+ When joining a room with restricted join rules (as defined in MSC3083),
+ the membership of spaces must be checked during join.
+
+ Args:
+ state_ids: The state of the room as it currently is.
+ room_version: The room version of the room being joined.
+ user_id: The user joining the room.
+
+ Returns:
+ True if the user can join the room, False otherwise.
+ """
+ # This only applies to room versions which support the new join rule.
+ if not room_version.msc3083_join_rules:
+ return True
+
+ # If there's no join rule, then it defaults to public (so this doesn't apply).
+ join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+ if not join_rules_event_id:
+ return True
+
+ # If the join rule is not restricted, this doesn't apply.
+ join_rules_event = await self.store.get_event(join_rules_event_id)
+ if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
+ return True
+
+ # If allowed is of the wrong form, then only allow invited users.
+ allowed_spaces = join_rules_event.content.get("allow", [])
+ if not isinstance(allowed_spaces, list):
+ return False
+
+ # Get the list of joined rooms and see if there's an overlap.
+ joined_rooms = await self.store.get_rooms_for_user(user_id)
+
+ # Pull out the other room IDs; invalid entries are filtered out.
+ for space in allowed_spaces:
+ if not isinstance(space, dict):
+ continue
+
+ space_id = space.get("space")
+ if not isinstance(space_id, str):
+ continue
+
+ # The user was joined to one of the spaces specified, they can join
+ # this room!
+ if space_id in joined_rooms:
+ return True
+
+ # The user was not in any of the required spaces.
+ return False
async def _local_membership_update(
self,
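For reference, an illustrative `m.room.join_rules` content body that `_can_join_without_invite` would accept; the room IDs and the `via` key are placeholders based on MSC3083, not values from this diff:

```python
# Hypothetical join-rules content for a restricted room:
join_rules_content = {
    "join_rule": "restricted",  # JoinRules.MSC3083_RESTRICTED
    "allow": [
        {"space": "!space-a:example.org", "via": ["example.org"]},
        {"space": "!space-b:example.org"},
        "not-a-dict",  # malformed entries are skipped by the loop above
    ],
}
# A user already joined to !space-a or !space-b may join without an invite;
# anyone else gets the 403 AuthError raised in _local_membership_update.
```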
@@ -273,9 +339,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if event.membership == Membership.JOIN:
newly_joined = True
+ user_is_invited = False
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
+ user_is_invited = prev_member_event.membership == Membership.INVITE
+
+ # If the member is not already in the room and is not accepting an invite,
+ # check if they should be allowed access via membership in a space.
+ if (
+ newly_joined
+ and not user_is_invited
+ and not await self._can_join_without_invite(
+ prev_state_ids, event.room_version, user_id
+ )
+ ):
+ raise AuthError(
+ 403,
+ "You do not belong to any of the required spaces to join this room.",
+ )
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
@@ -284,7 +366,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
(
allowed,
time_allowed,
- ) = self._join_rate_limiter_local.can_requester_do_action(requester)
+ ) = await self._join_rate_limiter_local.can_do_action(requester)
if not allowed:
raise LimitExceededError(
@@ -471,9 +553,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if effective_membership_state == Membership.INVITE:
target_id = target.to_string()
if ratelimit:
- # Don't ratelimit application services.
- if not requester.app_service or requester.app_service.is_rate_limited():
- self.ratelimit_invite(room_id, target_id)
+ await self.ratelimit_invite(requester, room_id, target_id)
# block any attempts to invite the server notices mxid
if target_id == self._server_notices_mxid:
@@ -610,7 +690,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
(
allowed,
time_allowed,
- ) = self._join_rate_limiter_remote.can_requester_do_action(
+ ) = await self._join_rate_limiter_remote.can_do_action(
requester,
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 17277619ad..6d6d6ed0df 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -24,6 +24,7 @@ from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.events import EventBase
from synapse.logging.context import current_context
+from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
@@ -265,13 +266,13 @@ class SyncHandler:
self.storage = hs.get_storage()
self.state_store = self.storage.state
- # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
+ # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
"lazy_loaded_members_cache",
self.clock,
max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
- )
+ ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
async def wait_for_sync_for_user(
self,
@@ -354,7 +355,14 @@ class SyncHandler:
full_state: bool = False,
) -> SyncResult:
"""Get the sync for client needed to match what the server has now."""
- return await self.generate_sync_result(sync_config, since_token, full_state)
+ with start_active_span("current_sync_for_user"):
+ log_kv({"since_token": since_token})
+ sync_result = await self.generate_sync_result(
+ sync_config, since_token, full_state
+ )
+
+ set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
+ return sync_result
async def push_rules_for_user(self, user: UserID) -> JsonDict:
user_id = user.to_string()
@@ -554,7 +562,7 @@ class SyncHandler:
)
async def get_state_after_event(
- self, event: EventBase, state_filter: StateFilter = StateFilter.all()
+ self, event: EventBase, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the room state after the given event
@@ -564,7 +572,7 @@ class SyncHandler:
state_filter: The state filter used to fetch state from the database.
"""
state_ids = await self.state_store.get_state_ids_for_event(
- event.event_id, state_filter=state_filter
+ event.event_id, state_filter=state_filter or StateFilter.all()
)
if event.is_state():
state_ids = dict(state_ids)
@@ -575,7 +583,7 @@ class SyncHandler:
self,
room_id: str,
stream_position: StreamToken,
- state_filter: StateFilter = StateFilter.all(),
+ state_filter: Optional[StateFilter] = None,
) -> StateMap[str]:
"""Get the room state at a particular stream position
@@ -595,7 +603,7 @@ class SyncHandler:
if last_events:
last_event = last_events[-1]
state = await self.get_state_after_event(
- last_event, state_filter=state_filter
+ last_event, state_filter=state_filter or StateFilter.all()
)
else:
@@ -739,8 +747,10 @@ class SyncHandler:
def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]]
- ) -> LruCache:
- cache = self.lazy_loaded_members_cache.get(cache_key)
+ ) -> LruCache[str, str]:
+ cache = self.lazy_loaded_members_cache.get(
+ cache_key
+ ) # type: Optional[LruCache[str, str]]
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
@@ -978,6 +988,7 @@ class SyncHandler:
# to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder`
now_token = self.event_sources.get_current_token()
+ log_kv({"now_token": now_token})
logger.debug(
"Calculating sync response for %r between %s and %s",
@@ -1245,6 +1256,13 @@ class SyncHandler:
user_id, device_id, since_stream_id, now_token.to_device_key
)
+ for message in messages:
+ # We pop here as we shouldn't be sending the message ID down
+ # `/sync`
+ message_id = message.pop("message_id", None)
+ if message_id:
+ set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+
logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)",
len(messages),
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 096d199f4c..bb35af099d 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -19,7 +19,10 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.replication.tcp.streams import TypingStream
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -86,6 +89,7 @@ class FollowerTypingHandler:
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
+ @wrap_as_background_process("typing._handle_timeouts")
def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
diff --git a/synapse/http/client.py b/synapse/http/client.py
index a0caba84e4..f7a07f0466 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -297,7 +297,7 @@ class SimpleHttpClient:
def __init__(
self,
hs: "HomeServer",
- treq_args: Dict[str, Any] = {},
+ treq_args: Optional[Dict[str, Any]] = None,
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
use_proxy: bool = False,
@@ -317,7 +317,7 @@ class SimpleHttpClient:
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- self._extra_treq_args = treq_args
+ self._extra_treq_args = treq_args or {}
self.user_agent = hs.version_string
self.clock = hs.get_clock()
@@ -590,7 +590,7 @@ class SimpleHttpClient:
uri: str,
json_body: Any,
args: Optional[QueryParams] = None,
- headers: RawHeaders = None,
+ headers: Optional[RawHeaders] = None,
) -> Any:
"""Puts some json to the given URI.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 5f01ebd3d4..ab47dec8f2 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -272,7 +272,7 @@ class MatrixFederationHttpClient:
self,
request: MatrixFederationRequest,
try_trailing_slash_on_400: bool = False,
- **send_request_args
+ **send_request_args,
) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 16ec850064..ea5ad14cb0 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -27,7 +27,7 @@ from twisted.python.failure import Failure
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent
+from twisted.web.iweb import IAgent, IPolicyForHTTPS
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
@@ -88,12 +88,14 @@ class ProxyAgent(_AgentBase):
self,
reactor,
proxy_reactor=None,
- contextFactory=BrowserLikePolicyForHTTPS(),
+ contextFactory: Optional[IPolicyForHTTPS] = None,
connectTimeout=None,
bindAddress=None,
pool=None,
use_proxy=False,
):
+ contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
+
_AgentBase.__init__(self, reactor, pool)
if proxy_reactor is None:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 47754aff43..32b5e19c09 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
import contextlib
import logging
import time
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union
import attr
from zope.interface import implementer
@@ -26,7 +26,11 @@ from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
from synapse.http import get_request_user_agent, redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+ ContextRequest,
+ LoggingContext,
+ PreserveLoggingContext,
+)
from synapse.types import Requester
logger = logging.getLogger(__name__)
@@ -63,7 +67,7 @@ class SynapseRequest(Request):
# The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object.
- self.requester = None # type: Optional[Union[Requester, str]]
+ self._requester = None # type: Optional[Union[Requester, str]]
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext]
@@ -93,6 +97,31 @@ class SynapseRequest(Request):
self.site.site_tag,
)
+ @property
+ def requester(self) -> Optional[Union[Requester, str]]:
+ return self._requester
+
+ @requester.setter
+ def requester(self, value: Union[Requester, str]) -> None:
+ # Store the requester, and update some properties based on it.
+
+ # This should only be called once.
+ assert self._requester is None
+
+ self._requester = value
+
+ # A logging context should exist by now (and have a ContextRequest).
+ assert self.logcontext is not None
+ assert self.logcontext.request is not None
+
+ (
+ requester,
+ authenticated_entity,
+ ) = self.get_authenticated_entity()
+ self.logcontext.request.requester = requester
+ # If there's no authenticated entity, it was the requester.
+ self.logcontext.request.authenticated_entity = authenticated_entity or requester
+
def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq)
@@ -126,13 +155,60 @@ class SynapseRequest(Request):
return self.method.decode("ascii")
return method
+ def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:
+ """
+ Get the "authenticated" entity of the request, which might be the user
+ performing the action, or a user being puppeted by a server admin.
+
+ Returns:
+ A tuple:
+ The first item is a string representing the user making the request.
+
+ The second item is a string or None representing the user who
+ authenticated when making this request. See
+ Requester.authenticated_entity.
+ """
+ # Convert the requester into a string that we can log
+ if isinstance(self._requester, str):
+ return self._requester, None
+ elif isinstance(self._requester, Requester):
+ requester = self._requester.user.to_string()
+ authenticated_entity = self._requester.authenticated_entity
+
+ # If this is a request where the target user doesn't match the user who
+ # authenticated (e.g. an admin is puppeting a user) then we return both.
+ if self._requester.user.to_string() != authenticated_entity:
+ return requester, authenticated_entity
+
+ return requester, None
+ elif self._requester is not None:
+ # This shouldn't happen, but we log it so we don't lose information
+ # and can see that we're doing something wrong.
+ return repr(self._requester), None # type: ignore[unreachable]
+
+ return None, None
+
def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
# create a LogContext for this request
request_id = self.get_request_id()
- self.logcontext = LoggingContext(request_id, request=request_id)
+ self.logcontext = LoggingContext(
+ request_id,
+ request=ContextRequest(
+ request_id=request_id,
+ ip_address=self.getClientIP(),
+ site_tag=self.site.site_tag,
+ # The requester is going to be unknown at this point.
+ requester=None,
+ authenticated_entity=None,
+ method=self.get_method(),
+ url=self.get_redacted_uri(),
+ protocol=self.clientproto.decode("ascii", errors="replace"),
+ user_agent=get_request_user_agent(self),
+ ),
+ )
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
@@ -277,25 +353,6 @@ class SynapseRequest(Request):
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
- # Convert the requester into a string that we can log
- authenticated_entity = None
- if isinstance(self.requester, str):
- authenticated_entity = self.requester
- elif isinstance(self.requester, Requester):
- authenticated_entity = self.requester.authenticated_entity
-
- # If this is a request where the target user doesn't match the user who
- # authenticated (e.g. and admin is puppetting a user) then we log both.
- if self.requester.user.to_string() != authenticated_entity:
- authenticated_entity = "{},{}".format(
- authenticated_entity,
- self.requester.user.to_string(),
- )
- elif self.requester is not None:
- # This shouldn't happen, but we log it so we don't lose information
- # and can see that we're doing something wrong.
- authenticated_entity = repr(self.requester) # type: ignore[unreachable]
-
user_agent = get_request_user_agent(self, "-")
code = str(self.code)
@@ -305,6 +362,13 @@ class SynapseRequest(Request):
code += "!"
log_level = logging.INFO if self._should_log_request() else logging.DEBUG
+
+ # If this is a request where the target user doesn't match the user who
+ # authenticated (e.g. an admin is puppeting a user) then we log both.
+ requester, authenticated_entity = self.get_authenticated_entity()
+ if authenticated_entity:
+ requester = "{}.{}".format(authenticated_entity, requester)
+
self.site.access_logger.log(
log_level,
"%s - %s - {%s}"
@@ -312,7 +376,7 @@ class SynapseRequest(Request):
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.site.site_tag,
- authenticated_entity,
+ requester,
processing_time,
response_send_time,
usage.ru_utime,
@@ -433,7 +497,7 @@ class SynapseSite(Site):
resource,
server_version_string,
*args,
- **kwargs
+ **kwargs,
):
Site.__init__(self, resource, *args, **kwargs)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 03cf3c2b8e..dbd7d3a33a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -22,7 +22,6 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
-
import inspect
import logging
import threading
@@ -30,6 +29,7 @@ import types
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+import attr
from typing_extensions import Literal
from twisted.internet import defer, threads
@@ -181,6 +181,29 @@ class ContextResourceUsage:
return res
+@attr.s(slots=True)
+class ContextRequest:
+ """
+ A bundle of attributes from the SynapseRequest object.
+
+ This exists to:
+
+ * Avoid a cycle between LoggingContext and SynapseRequest.
+ * Be a single variable that can be passed from parent LoggingContexts to
+ their children.
+ """
+
+ request_id = attr.ib(type=str)
+ ip_address = attr.ib(type=str)
+ site_tag = attr.ib(type=str)
+ requester = attr.ib(type=Optional[str])
+ authenticated_entity = attr.ib(type=Optional[str])
+ method = attr.ib(type=str)
+ url = attr.ib(type=str)
+ protocol = attr.ib(type=str)
+ user_agent = attr.ib(type=str)
+
+
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
@@ -254,9 +277,9 @@ class LoggingContext:
def __init__(
self,
- name: Optional[str] = None,
+ name: str,
parent_context: "Optional[LoggingContext]" = None,
- request: Optional[str] = None,
+ request: Optional[ContextRequest] = None,
) -> None:
self.previous_context = current_context()
self.name = name
@@ -281,16 +304,18 @@ class LoggingContext:
self.parent_context = parent_context
if self.parent_context is not None:
- self.parent_context.copy_to(self)
+ # we track the current request_id
+ self.request = self.parent_context.request
+
+ # we also track the current scope:
+ self.scope = self.parent_context.scope
if request is not None:
# the request param overrides the request from the parent context
self.request = request
def __str__(self) -> str:
- if self.request:
- return str(self.request)
- return "%s@%x" % (self.name, id(self))
+ return self.name
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
@@ -556,8 +581,23 @@ class LoggingContextFilter(logging.Filter):
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
- # Logging is interested in the request.
- record.request = context.request # type: ignore
+ # Logging is interested in the request ID. Note that for backwards
+ # compatibility this is stored as the "request" on the record.
+ record.request = str(context) # type: ignore
+
+ # Add some data from the HTTP request.
+ request = context.request
+ if request is None:
+ return True
+
+ record.ip_address = request.ip_address # type: ignore
+ record.site_tag = request.site_tag # type: ignore
+ record.requester = request.requester # type: ignore
+ record.authenticated_entity = request.authenticated_entity # type: ignore
+ record.method = request.method # type: ignore
+ record.url = request.url # type: ignore
+ record.protocol = request.protocol # type: ignore
+ record.user_agent = request.user_agent # type: ignore
return True
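Because `LoggingContextFilter` now copies the `ContextRequest` fields onto each log record, a logging formatter can reference them directly. A sketch of wiring that up, assuming the filter's default constructor; note that records emitted outside a request context will lack these attributes, so a real config should guard for that (e.g. with a `Formatter` subclass that supplies defaults):

```python
import logging

from synapse.logging.context import LoggingContextFilter

handler = logging.StreamHandler()
handler.addFilter(LoggingContextFilter())  # populates request, ip_address, ...
handler.setFormatter(logging.Formatter(
    "%(asctime)s - %(request)s - %(ip_address)s - "
    "%(requester)s as %(authenticated_entity)s - %(message)s"
))
logging.getLogger("synapse").addHandler(handler)
```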
@@ -630,8 +670,8 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
def nested_logging_context(suffix: str) -> LoggingContext:
"""Creates a new logging context as a child of another.
- The nested logging context will have a 'request' made up of the parent context's
- request, plus the given suffix.
+ The nested logging context will have a 'name' made up of the parent context's
+ name, plus the given suffix.
CPU/db usage stats will be added to the parent context's on exit.
@@ -641,7 +681,7 @@ def nested_logging_context(suffix: str) -> LoggingContext:
# ... do stuff
Args:
- suffix: suffix to add to the parent context's 'request'.
+ suffix: suffix to add to the parent context's 'name'.
Returns:
LoggingContext: new logging context.
@@ -652,12 +692,14 @@ def nested_logging_context(suffix: str) -> LoggingContext:
"Starting nested logging context from sentinel context: metrics will be lost"
)
parent_context = None
- prefix = ""
else:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
- prefix = str(parent_context.request)
- return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
+ prefix = str(curr_context)
+ return LoggingContext(
+ prefix + "-" + suffix,
+ parent_context=parent_context,
+ )
def preserve_fn(f):
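With `request` now a structured `ContextRequest` and `__str__` returning the context's name, nesting builds on names instead of request strings. A quick sketch of the resulting names (the `GET-123` request ID is illustrative):

```python
from synapse.logging.context import LoggingContext, nested_logging_context

with LoggingContext("GET-123"):
    with nested_logging_context("persist") as ctx:
        assert str(ctx) == "GET-123-persist"
```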
@@ -847,7 +889,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
parent_context = curr_context
def g():
- with LoggingContext(parent_context=parent_context):
+ with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs)
return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index aa146e8bb8..bfe9136fd8 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -259,6 +259,14 @@ except ImportError:
logger = logging.getLogger(__name__)
+class SynapseTags:
+ # The message ID of any to_device message processed
+ TO_DEVICE_MESSAGE_ID = "to_device.message_id"
+
+ # Whether the sync response has new data to be returned to the client.
+ SYNC_RESULT = "sync.new_data"
+
+
# Block everything by default
# A regex which matches the server_names to expose traces for.
# None means 'block everything'.
@@ -478,7 +486,7 @@ def start_active_span_from_request(
def start_active_span_from_edu(
edu_content,
operation_name,
- references=[],
+ references: Optional[list] = None,
tags=None,
start_time=None,
ignore_active_span=False,
@@ -493,6 +501,7 @@ def start_active_span_from_edu(
For the other args see opentracing.tracer
"""
+ references = references or []
if opentracing is None:
return noop_context_manager()
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 3b499efc07..13a5bc4558 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -214,7 +214,12 @@ class GaugeBucketCollector:
Prometheus, and optimise for that case.
"""
- __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric")
+ __slots__ = (
+ "_name",
+ "_documentation",
+ "_bucket_bounds",
+ "_metric",
+ )
def __init__(
self,
@@ -242,11 +247,16 @@ class GaugeBucketCollector:
if self._bucket_bounds[-1] != float("inf"):
self._bucket_bounds.append(float("inf"))
- self._metric = self._values_to_metric([])
+ # We initially set this to None. We won't report metrics until
+ # this has been initialised after a successful data update
+ self._metric = None # type: Optional[GaugeHistogramMetricFamily]
+
registry.register(self)
def collect(self):
- yield self._metric
+ # Don't report metrics unless we've already collected some data
+ if self._metric is not None:
+ yield self._metric
def update_data(self, values: Iterable[float]):
"""Update the data to be reported by the metric
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index b56986d8e7..78e9cfbc26 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -199,11 +199,11 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
- with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
+ with BackgroundProcessLoggingContext("%s-%s" % (desc, count)) as context:
try:
ctx = noop_context_manager()
if bg_start_span:
- ctx = start_active_span(desc, tags={"request_id": context.request})
+ ctx = start_active_span(desc, tags={"request_id": str(context)})
with ctx:
return await maybe_awaitable(func(*args, **kwargs))
except Exception:
@@ -244,9 +244,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"]
- def __init__(self, name: str, request: Optional[str] = None):
- super().__init__(name, request=request)
-
+ def __init__(self, name: str):
+ super().__init__(name)
self._proc = _BackgroundProcess(name, self)
def start(self, rusage: "Optional[resource._RUsage]"):
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2e38150eac..189f7e52ba 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
from twisted.internet import defer
@@ -50,11 +50,20 @@ class ModuleApi:
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
+ self._presence_stream = hs.get_event_sources().sources["presence"]
# We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
self._public_room_list_manager = PublicRoomListManager(hs)
+ # The next time these users sync, they will receive the current presence
+ # state of all local users. Users are added by send_local_online_presence_to,
+ # and removed after a successful sync.
+ #
+ # We make this a private variable to deter modules from accessing it directly,
+ # though other classes in Synapse will still do so.
+ self._send_full_presence_to_local_users = set()
+
@property
def http_client(self):
"""Allows making outbound HTTP requests to remote resources.
@@ -118,7 +127,7 @@ class ModuleApi:
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
@defer.inlineCallbacks
- def register(self, localpart, displayname=None, emails=[]):
+ def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
"""Registers a new user with given localpart and optional displayname, emails.
Also returns an access token for the new user.
@@ -138,11 +147,13 @@ class ModuleApi:
logger.warning(
"Using deprecated ModuleApi.register which creates a dummy user device."
)
- user_id = yield self.register_user(localpart, displayname, emails)
+ user_id = yield self.register_user(localpart, displayname, emails or [])
_, access_token = yield self.register_device(user_id)
return user_id, access_token
- def register_user(self, localpart, displayname=None, emails=[]):
+ def register_user(
+ self, localpart, displayname=None, emails: Optional[List[str]] = None
+ ):
"""Registers a new user with given localpart and optional displayname, emails.
Args:
@@ -161,7 +172,7 @@ class ModuleApi:
self._hs.get_registration_handler().register_user(
localpart=localpart,
default_display_name=displayname,
- bind_emails=emails,
+ bind_emails=emails or [],
)
)
@@ -385,6 +396,47 @@ class ModuleApi:
return event
+ async def send_local_online_presence_to(self, users: Iterable[str]) -> None:
+ """
+ Forces the equivalent of a presence initial_sync for a set of local or remote
+ users. The users will receive presence for all currently online users that they
+ are considered interested in.
+
+ Updates to remote users will be sent immediately, whereas local users will receive
+ them on their next sync attempt.
+
+ Note that this method can only be run on the main or federation_sender worker
+ processes.
+ """
+ if not self._hs.should_send_federation():
+ raise Exception(
+ "send_local_online_presence_to can only be run "
+ "on processes that send federation",
+ )
+
+ for user in users:
+ if self._hs.is_mine_id(user):
+ # SyncHandler._generate_sync_entry_for_presence checks this set and, for
+ # any user found in it, calls presence_source.get_new_events with an
+ # empty `from_key`, so that user receives all presence state on their
+ # next incremental sync.
+
+ # Force a presence initial_sync for this user next time
+ self._send_full_presence_to_local_users.add(user)
+ else:
+ # Retrieve presence state for currently online users that this user
+ # is considered interested in
+ presence_events, _ = await self._presence_stream.get_new_events(
+ UserID.from_string(user), from_key=None, include_offline=False
+ )
+
+ # Send to remote destinations
+ await make_deferred_yieldable(
+ # We pull the federation sender here as we can only do so on workers
+ # that support sending presence
+ self._hs.get_federation_sender().send_presence(presence_events)
+ )
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1374aae490..7ce34380af 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -39,6 +39,7 @@ from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import log_kv, start_active_span
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.streams.config import PaginationConfig
@@ -136,6 +137,15 @@ class _NotifierUserStream:
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
+ log_kv(
+ {
+ "notify": self.user_id,
+ "stream": stream_key,
+ "stream_id": stream_id,
+ "listeners": self.count_listeners(),
+ }
+ )
+
users_woken_by_stream_counter.labels(stream_key).inc()
with PreserveLoggingContext():
@@ -266,7 +276,7 @@ class Notifier:
event: EventBase,
event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken,
- extra_users: Collection[UserID] = [],
+ extra_users: Optional[Collection[UserID]] = None,
):
"""Unwraps event and calls `on_new_room_event_args`."""
self.on_new_room_event_args(
@@ -276,7 +286,7 @@ class Notifier:
state_key=event.get("state_key"),
membership=event.content.get("membership"),
max_room_stream_token=max_room_stream_token,
- extra_users=extra_users,
+ extra_users=extra_users or [],
)
def on_new_room_event_args(
@@ -287,7 +297,7 @@ class Notifier:
membership: Optional[str],
event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken,
- extra_users: Collection[UserID] = [],
+ extra_users: Optional[Collection[UserID]] = None,
):
"""Used by handlers to inform the notifier something has happened
in the room, room event wise.
@@ -303,7 +313,7 @@ class Notifier:
self.pending_new_room_events.append(
_PendingRoomEventEntry(
event_pos=event_pos,
- extra_users=extra_users,
+ extra_users=extra_users or [],
room_id=room_id,
type=event_type,
state_key=state_key,
@@ -372,14 +382,14 @@ class Notifier:
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
- users: Collection[Union[str, UserID]] = [],
+ users: Optional[Collection[Union[str, UserID]]] = None,
):
try:
stream_token = None
if isinstance(new_token, int):
stream_token = new_token
self.appservice_handler.notify_interested_services_ephemeral(
- stream_key, stream_token, users
+ stream_key, stream_token, users or []
)
except Exception:
logger.exception("Error notifying application services of event")
@@ -394,16 +404,26 @@ class Notifier:
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
- users: Collection[Union[str, UserID]] = [],
- rooms: Collection[str] = [],
+ users: Optional[Collection[Union[str, UserID]]] = None,
+ rooms: Optional[Collection[str]] = None,
):
"""Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms.
"""
+ users = users or []
+ rooms = rooms or []
+
with Measure(self.clock, "on_new_event"):
user_streams = set()
+ log_kv(
+ {
+ "waking_up_explicit_users": len(users),
+ "waking_up_explicit_rooms": len(rooms),
+ }
+ )
+
for user in users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
@@ -476,12 +496,34 @@ class Notifier:
(end_time - now) / 1000.0,
self.hs.get_reactor(),
)
- with PreserveLoggingContext():
- await listener.deferred
+
+ with start_active_span("wait_for_events.deferred"):
+ log_kv(
+ {
+ "wait_for_events": "sleep",
+ "token": prev_token,
+ }
+ )
+
+ with PreserveLoggingContext():
+ await listener.deferred
+
+ log_kv(
+ {
+ "wait_for_events": "woken",
+ "token": user_stream.current_token,
+ }
+ )
current_token = user_stream.current_token
result = await callback(prev_token, current_token)
+ log_kv(
+ {
+ "wait_for_events": "result",
+ "result": bool(result),
+ }
+ )
if result:
break
@@ -489,8 +531,10 @@ class Notifier:
# has happened between the old prev_token and the current_token
prev_token = current_token
except defer.TimeoutError:
+ log_kv({"wait_for_events": "timeout"})
break
except defer.CancelledError:
+ log_kv({"wait_for_events": "cancelled"})
break
if result is None:
@@ -507,7 +551,7 @@ class Notifier:
pagination_config: PaginationConfig,
timeout: int,
is_guest: bool = False,
- explicit_room_id: str = None,
+ explicit_room_id: Optional[str] = None,
) -> EventStreamResult:
"""For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
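The notifier instrumentation above follows one pattern throughout: wrap the blocking wait in a span and attach structured `log_kv` events at each state change. Its skeleton, using the imports added in this file:

    from synapse.logging.opentracing import log_kv, start_active_span

    async def traced_wait(listener, prev_token):
        with start_active_span("wait_for_events.deferred"):
            log_kv({"wait_for_events": "sleep", "token": prev_token})
            await listener.deferred
            log_kv({"wait_for_events": "woken"})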
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index d005f38767..73d7477854 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
- self.registration_handler.check_registration_ratelimit(content["address"])
+ await self.registration_handler.check_registration_ratelimit(content["address"])
await self.registration_handler.register_with_store(
user_id=user_id,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e829add257..ba753318bd 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -184,8 +184,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
- ctx_name = "replication-conn-%s" % self.conn_id
- self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
+ self._logging_context = BackgroundProcessLoggingContext(
+ "replication-conn-%s" % (self.conn_id,)
+ )
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 2f4d407f94..98bdeb0ec6 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -60,7 +60,7 @@ class ConstantProperty(Generic[T, V]):
constant = attr.ib() # type: V
- def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V:
+ def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
return self.constant
def __set__(self, obj: Optional[T], value: V):
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 8457db1e22..2dec818a5f 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -54,6 +54,7 @@ from synapse.rest.admin.users import (
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
PushersRestServlet,
+ RateLimitRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
ShadowBanRestServlet,
@@ -62,7 +63,6 @@ from synapse.rest.admin.users import (
UserMembershipRestServlet,
UserRegisterServlet,
UserRestServletV2,
- UsersRestServlet,
UsersRestServletV2,
UserTokenRestServlet,
WhoisRestServlet,
@@ -240,6 +240,7 @@ def register_servlets(hs, http_server):
ShadowBanRestServlet(hs).register(http_server)
ForwardExtremitiesRestServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
+ RateLimitRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
@@ -248,7 +249,6 @@ def register_servlets_for_client_rest_resource(hs, http_server):
PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
- UsersRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 309bd2771b..04990c71fb 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -36,6 +36,7 @@ from synapse.rest.admin._base import (
)
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.databases.main.media_repository import MediaSortOrder
+from synapse.storage.databases.main.stats import UserSortOrder
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
@@ -44,29 +45,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class UsersRestServlet(RestServlet):
- PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
-
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.store = hs.get_datastore()
- self.auth = hs.get_auth()
- self.admin_handler = hs.get_admin_handler()
-
- async def on_GET(
- self, request: SynapseRequest, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
- target_user = UserID.from_string(user_id)
- await assert_requester_is_admin(self.auth, request)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- ret = await self.store.get_users()
-
- return 200, ret
-
-
class UsersRestServletV2(RestServlet):
PATTERNS = admin_patterns("/users$", "v2")
@@ -117,8 +95,26 @@ class UsersRestServletV2(RestServlet):
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
+ order_by = parse_string(
+ request,
+ "order_by",
+ default=UserSortOrder.NAME.value,
+ allowed_values=(
+ UserSortOrder.NAME.value,
+ UserSortOrder.DISPLAYNAME.value,
+ UserSortOrder.GUEST.value,
+ UserSortOrder.ADMIN.value,
+ UserSortOrder.DEACTIVATED.value,
+ UserSortOrder.USER_TYPE.value,
+ UserSortOrder.AVATAR_URL.value,
+ UserSortOrder.SHADOW_BANNED.value,
+ ),
+ )
+
+ direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+
users, total = await self.store.get_users_paginate(
- start, limit, user_id, name, guests, deactivated
+ start, limit, user_id, name, guests, deactivated, order_by, direction
)
ret = {"users": users, "total": total}
if (start + limit) < total:
@@ -985,3 +981,114 @@ class ShadowBanRestServlet(RestServlet):
await self.store.set_shadow_banned(UserID.from_string(user_id), True)
return 200, {}
+
+
+class RateLimitRestServlet(RestServlet):
+ """An admin API to override ratelimiting for an user.
+
+ Example:
+ POST /_synapse/admin/v1/users/@test:example.com/override_ratelimit
+ {
+ "messages_per_second": 0,
+ "burst_count": 0
+ }
+ 200 OK
+ {
+ "messages_per_second": 0,
+ "burst_count": 0
+ }
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Can only lookup local users")
+
+ if not await self.store.get_user_by_id(user_id):
+ raise NotFoundError("User not found")
+
+ ratelimit = await self.store.get_ratelimit_for_user(user_id)
+
+ if ratelimit:
+ # Convert `null` to `0` for consistency;
+ # both values behave the same in the ratelimit handler.
+ ret = {
+ "messages_per_second": 0
+ if ratelimit.messages_per_second is None
+ else ratelimit.messages_per_second,
+ "burst_count": 0
+ if ratelimit.burst_count is None
+ else ratelimit.burst_count,
+ }
+ else:
+ ret = {}
+
+ return 200, ret
+
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be ratelimited")
+
+ if not await self.store.get_user_by_id(user_id):
+ raise NotFoundError("User not found")
+
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+
+ messages_per_second = body.get("messages_per_second", 0)
+ burst_count = body.get("burst_count", 0)
+
+ if not isinstance(messages_per_second, int) or messages_per_second < 0:
+ raise SynapseError(
+ 400,
+ "%r parameter must be a positive int" % (messages_per_second,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if not isinstance(burst_count, int) or burst_count < 0:
+ raise SynapseError(
+ 400,
+ "%r parameter must be a positive int" % (burst_count,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ await self.store.set_ratelimit_for_user(
+ user_id, messages_per_second, burst_count
+ )
+ ratelimit = await self.store.get_ratelimit_for_user(user_id)
+ assert ratelimit is not None
+
+ ret = {
+ "messages_per_second": ratelimit.messages_per_second,
+ "burst_count": ratelimit.burst_count,
+ }
+
+ return 200, ret
+
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be ratelimited")
+
+ if not await self.store.get_user_by_id(user_id):
+ raise NotFoundError("User not found")
+
+ await self.store.delete_ratelimit_for_user(user_id)
+
+ return 200, {}
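End to end, the new ratelimit endpoints can be exercised as below (a sketch using `requests`; host and token are placeholders):

    import requests

    BASE = "https://synapse.example.com/_synapse/admin/v1"
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}
    USER = "@test:example.com"

    # Exempt the user from message ratelimiting (0 disables the limits)
    resp = requests.post(
        f"{BASE}/users/{USER}/override_ratelimit",
        headers=HEADERS,
        json={"messages_per_second": 0, "burst_count": 0},
    )
    print(resp.json())  # {"messages_per_second": 0, "burst_count": 0}

    # Inspect, then remove, the override
    requests.get(f"{BASE}/users/{USER}/override_ratelimit", headers=HEADERS)
    requests.delete(f"{BASE}/users/{USER}/override_ratelimit", headers=HEADERS)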
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index e4c352f572..3151e72d4f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet):
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
+ store=hs.get_datastore(),
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
)
self._account_ratelimiter = Ratelimiter(
+ store=hs.get_datastore(),
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
@@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet):
appservice = self.auth.get_appservice_by_req(request)
if appservice.is_rate_limited():
- self._address_ratelimiter.ratelimit(request.getClientIP())
+ await self._address_ratelimiter.ratelimit(
+ None, request.getClientIP()
+ )
result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
- self._address_ratelimiter.ratelimit(request.getClientIP())
+ await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- self._address_ratelimiter.ratelimit(request.getClientIP())
+ await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_token_login(login_submission)
else:
- self._address_ratelimiter.ratelimit(request.getClientIP())
+ await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet):
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
if ratelimit:
- self._account_ratelimiter.ratelimit(user_id.lower())
+ await self._account_ratelimiter.ratelimit(None, user_id.lower())
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
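Every ratelimit check in the login servlet is now awaited because the `Ratelimiter` takes a `store` and may consult the database for per-user overrides. Schematically (a sketch of the shape, not Synapse's actual class):

    class StoreBackedRatelimiter:
        def __init__(self, store, clock, rate_hz, burst_count):
            self._store = store
            self.rate_hz = rate_hz
            self.burst_count = burst_count

        async def ratelimit(self, requester, key=None, update=True):
            rate_hz, burst_count = self.rate_hz, self.burst_count
            if requester is not None:
                # Per-user overrides live in the database, hence the await
                override = await self._store.get_ratelimit_for_user(
                    requester.user.to_string()
                )
                if override:
                    rate_hz = override.messages_per_second
                    burst_count = override.burst_count
            ...  # token-bucket accounting against (rate_hz, burst_count)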
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 80ee0d2d8e..3076706571 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -105,7 +105,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
- self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+ await self.identity_handler.ratelimit_request_token_requests(
+ request, "email", email
+ )
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
@@ -415,7 +417,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+ await self.identity_handler.ratelimit_request_token_requests(
+ request, "email", email
+ )
if next_link:
# Raise if the provided next_link value isn't valid
@@ -496,7 +500,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- self.identity_handler.ratelimit_request_token_requests(
+ await self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a7aea914e9..beca08ab5d 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import hmac
import logging
import random
@@ -24,7 +23,7 @@ from typing import List, Union
import synapse
import synapse.api.auth
import synapse.types
-from synapse.api.constants import LoginType
+from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import (
Codes,
InteractiveAuthIncompleteError,
@@ -128,7 +127,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+ await self.identity_handler.ratelimit_request_token_requests(
+ request, "email", email
+ )
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
@@ -212,7 +213,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- self.identity_handler.ratelimit_request_token_requests(
+ await self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)
@@ -404,7 +405,7 @@ class RegisterRestServlet(RestServlet):
client_addr = request.getClientIP()
- self.ratelimiter.ratelimit(client_addr, update=False)
+ await self.ratelimiter.ratelimit(None, client_addr, update=False)
kind = b"user"
if b"kind" in request.args:
@@ -443,7 +444,16 @@ class RegisterRestServlet(RestServlet):
# different registration flows to normal users
# == Application Service Registration ==
- if appservice:
+ if body.get("type") == APP_SERVICE_REGISTRATION_TYPE:
+ if not self.auth.has_access_token(request):
+ raise SynapseError(
+ 400,
+ "Appservice token must be provided when using a type of m.login.application_service",
+ )
+
+ # Verify the AS
+ self.auth.get_appservice_by_req(request)
+
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
@@ -464,6 +474,11 @@ class RegisterRestServlet(RestServlet):
)
return 200, result
+ elif self.auth.has_access_token(request):
+ raise SynapseError(
+ 400,
+ "An access token should not be provided on requests to /register (except if type is m.login.application_service)",
+ )
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
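The new appservice branch expects `m.login.application_service` together with the appservice's token; roughly this request shape (host and token are placeholders):

    import requests

    resp = requests.post(
        "https://synapse.example.com/_matrix/client/r0/register",
        # The appservice's as_token, not a user access token
        headers={"Authorization": "Bearer <as_token>"},
        json={"type": "m.login.application_service", "username": "_bridge_alice"},
    )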
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index c4ed9dfdb4..814145a04a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -175,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR,
- )
+ ) # type: ExpiringCache[str, ObservableDeferred]
if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call(
diff --git a/synapse/server.py b/synapse/server.py
index e85b9391fa..cfb55c230d 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -51,6 +51,7 @@ from synapse.crypto import context_factory
from synapse.crypto.context_factory import RegularPolicyForHTTPS
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
+from synapse.events.presence_router import PresenceRouter
from synapse.events.spamcheck import SpamChecker
from synapse.events.third_party_rules import ThirdPartyEventRules
from synapse.events.utils import EventClientSerializer
@@ -329,6 +330,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:
return Ratelimiter(
+ store=self.get_datastore(),
clock=self.get_clock(),
rate_hz=self.config.rc_registration.per_second,
burst_count=self.config.rc_registration.burst_count,
@@ -425,6 +427,10 @@ class HomeServer(metaclass=abc.ABCMeta):
raise Exception("Workers cannot write typing")
@cache_in_self
+ def get_presence_router(self) -> PresenceRouter:
+ return PresenceRouter(self)
+
+ @cache_in_self
def get_typing_handler(self) -> FollowerTypingHandler:
if self.config.worker.writers.typing == self.get_instance_name():
# Use get_typing_writer_handler to ensure that we use the same
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c3d6e80c49..c0f79ffdc8 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -22,6 +22,7 @@ from typing import (
Callable,
DefaultDict,
Dict,
+ FrozenSet,
Iterable,
List,
Optional,
@@ -515,7 +516,7 @@ class StateResolutionHandler:
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
- )
+ ) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]
#
# stuff for tracking time spent on state-res by room
@@ -536,7 +537,7 @@ class StateResolutionHandler:
state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
- ):
+ ) -> _StateCacheEntry:
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 94590e7b45..77ef29ec71 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -488,7 +488,7 @@ class DatabasePool:
exception_callbacks: List[_CallbackListEntry],
func: "Callable[..., R]",
*args: Any,
- **kwargs: Any
+ **kwargs: Any,
) -> R:
"""Start a new database transaction with the given connection.
@@ -622,7 +622,7 @@ class DatabasePool:
func: "Callable[..., R]",
*args: Any,
db_autocommit: bool = False,
- **kwargs: Any
+ **kwargs: Any,
) -> R:
"""Starts a transaction on the database and runs a given function
@@ -682,7 +682,7 @@ class DatabasePool:
func: "Callable[..., R]",
*args: Any,
db_autocommit: bool = False,
- **kwargs: Any
+ **kwargs: Any,
) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
@@ -775,7 +775,7 @@ class DatabasePool:
desc: str,
decoder: Optional[Callable[[Cursor], R]],
query: str,
- *args: Any
+ *args: Any,
) -> R:
"""Runs a single query for a result set.
@@ -900,7 +900,7 @@ class DatabasePool:
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
- insertion_values: Dict[str, Any] = {},
+ insertion_values: Optional[Dict[str, Any]] = None,
desc: str = "simple_upsert",
lock: bool = True,
) -> Optional[bool]:
@@ -927,6 +927,8 @@ class DatabasePool:
Native upserts always return None. Emulated upserts return True if a
new entry was created, False if an existing one was updated.
"""
+ insertion_values = insertion_values or {}
+
attempts = 0
while True:
try:
@@ -964,7 +966,7 @@ class DatabasePool:
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
- insertion_values: Dict[str, Any] = {},
+ insertion_values: Optional[Dict[str, Any]] = None,
lock: bool = True,
) -> Optional[bool]:
"""
@@ -982,6 +984,8 @@ class DatabasePool:
Native upserts always return None. Emulated upserts return True if a
new entry was created, False if an existing one was updated.
"""
+ insertion_values = insertion_values or {}
+
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
@@ -1003,7 +1007,7 @@ class DatabasePool:
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
- insertion_values: Dict[str, Any] = {},
+ insertion_values: Optional[Dict[str, Any]] = None,
lock: bool = True,
) -> bool:
"""
@@ -1017,6 +1021,8 @@ class DatabasePool:
Returns True if a new entry was created, False if an existing
one was updated.
"""
+ insertion_values = insertion_values or {}
+
# We need to lock the table :(, unless we're *really* careful
if lock:
self.engine.lock_table(txn, table)
@@ -1077,7 +1083,7 @@ class DatabasePool:
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
- insertion_values: Dict[str, Any] = {},
+ insertion_values: Optional[Dict[str, Any]] = None,
) -> None:
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
@@ -1090,7 +1096,7 @@ class DatabasePool:
"""
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
- allvalues.update(insertion_values)
+ allvalues.update(insertion_values or {})
if not values:
latter = "NOTHING"
@@ -1513,7 +1519,7 @@ class DatabasePool:
column: str,
iterable: Iterable[Any],
retcols: Iterable[str],
- keyvalues: Dict[str, Any] = {},
+ keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
) -> List[Any]:
@@ -1531,6 +1537,8 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
batch_size: the number of rows for each select query
"""
+ keyvalues = keyvalues or {}
+
results = [] # type: List[Dict[str, Any]]
if not iterable:
@@ -2059,69 +2067,18 @@ def make_in_list_sql_clause(
KV = TypeVar("KV")
-def make_tuple_comparison_clause(
- database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
-) -> Tuple[str, List[KV]]:
+def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]:
"""Returns a tuple comparison SQL clause
- Depending what the SQL engine supports, builds a SQL clause that looks like either
- "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)".
+ Builds a SQL clause that looks like "(a, b) > (?, ?)"
Args:
- database_engine
keys: A set of (column, value) pairs to be compared.
Returns:
A tuple of SQL query and the args
"""
- if database_engine.supports_tuple_comparison:
- return (
- "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
- [k[1] for k in keys],
- )
-
- # we want to build a clause
- # (a > ?) OR
- # (a == ? AND b > ?) OR
- # (a == ? AND b == ? AND c > ?)
- # ...
- # (a == ? AND b == ? AND ... AND z > ?)
- #
- # or, equivalently:
- #
- # (a > ? OR (a == ? AND
- # (b > ? OR (b == ? AND
- # ...
- # (y > ? OR (y == ? AND
- # z > ?
- # ))
- # ...
- # ))
- # ))
- #
- # which itself is equivalent to (and apparently easier for the query optimiser):
- #
- # (a >= ? AND (a > ? OR
- # (b >= ? AND (b > ? OR
- # ...
- # (y >= ? AND (y > ? OR
- # z > ?
- # ))
- # ...
- # ))
- # ))
- #
- #
-
- clause = ""
- args = [] # type: List[KV]
- for k, v in keys[:-1]:
- clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k)
- args.extend([v, v])
-
- (k, v) = keys[-1]
- clause += "%s > ?" % (k,)
- args.append(v)
-
- clause += "))" * (len(keys) - 1)
- return clause, args
+ return (
+ "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
+ [k[1] for k in keys],
+ )
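For illustration, the simplified helper now always produces the native tuple form, e.g. when paginating by user and device:

    from synapse.storage.database import make_tuple_comparison_clause

    clause, args = make_tuple_comparison_clause(
        [("user_id", "@alice:test"), ("device_id", "ABC")]
    )
    # clause == "(user_id,device_id) > (?,?)"
    # args == ["@alice:test", "ABC"]

Dropping the emulated `(a > ?) OR (a == ? AND b > ?)` fallback is possible because the SQLite floor rises to 3.22 later in this patch (native tuple comparison needs SQLite 3.15+).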
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1d44c3aa2c..b3d16ca7ac 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -21,6 +21,7 @@ from typing import List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
IdGenerator,
@@ -292,6 +293,8 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
+ order_by: UserSortOrder = UserSortOrder.USER_ID.value,
+ direction: str = "f",
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@@ -304,6 +307,8 @@ class DataStore(
name: search for local part of user_id or display name
guests: whether to in include guest users
deactivated: whether to include deactivated users
+ order_by: the sort order of the returned list
+ direction: sort ascending or descending
Returns:
A tuple of a list of mappings from user to information and a count of total users.
"""
@@ -312,6 +317,14 @@ class DataStore(
filters = []
args = [self.hs.config.server_name]
+ # Set ordering
+ order_by_column = UserSortOrder(order_by).value
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
# `name` is in database already in lower case
if name:
filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
@@ -339,10 +352,15 @@ class DataStore(
txn.execute(sql, args)
count = txn.fetchone()[0]
- sql = (
- "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
- + sql_base
- + " ORDER BY u.name LIMIT ? OFFSET ?"
+ sql = """
+ SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url
+ {sql_base}
+ ORDER BY {order_by_column} {order}, u.name ASC
+ LIMIT ? OFFSET ?
+ """.format(
+ sql_base=sql_base,
+ order_by_column=order_by_column,
+ order=order,
)
args += [limit, start]
txn.execute(sql, args)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 6d18e692b0..ea3c15fd0e 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -298,7 +298,6 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# times, which is fine.
where_clause, where_args = make_tuple_comparison_clause(
- self.database_engine,
[("user_id", last_user_id), ("device_id", last_device_id)],
)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d327e9aa0b..9bf8ba888f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -985,7 +985,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
def _txn(txn):
clause, args = make_tuple_comparison_clause(
- self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
+ [(x, last_row[x]) for x in KEY_COLS]
)
sql = """
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 98dac19a95..ad17123915 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -320,8 +320,8 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
- state_delta_for_room: Dict[str, DeltaState] = {},
- new_forward_extremeties: Dict[str, List[str]] = {},
+ state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
+ new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
):
"""Insert some number of room events into the necessary database tables.
@@ -342,6 +342,9 @@ class PersistEventsStore:
extremities.
"""
+ state_delta_for_room = state_delta_for_room or {}
+ new_forward_extremeties = new_forward_extremeties or {}
+
all_events_and_contexts = events_and_contexts
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 78367ea58d..79e7df6ca9 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -838,7 +838,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
# comparison, but that is not supported on older SQLite versions
tuple_clause, tuple_args = make_tuple_comparison_clause(
- self.database_engine,
[
("events.room_id", last_room_id),
("topological_ordering", last_depth),
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index dfb638ea54..c00780969f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,7 +16,7 @@
import logging
import threading
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple, overload
+from typing import Container, Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
from typing_extensions import Literal
@@ -544,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
- state_types_to_include: List[str],
+ state_types_to_include: Container[str],
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ac07e0197b..bd7826f4e9 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1027,8 +1027,8 @@ class GroupServerStore(GroupServerWorkerStore):
user_id: str,
is_admin: bool = False,
is_public: bool = True,
- local_attestation: dict = None,
- remote_attestation: dict = None,
+ local_attestation: Optional[dict] = None,
+ remote_attestation: Optional[dict] = None,
) -> None:
"""Add a user to the group server.
@@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
user_id: str,
membership: str,
is_admin: bool = False,
- content: JsonDict = {},
+ content: Optional[JsonDict] = None,
local_attestation: Optional[dict] = None,
remote_attestation: Optional[dict] = None,
is_publicised: bool = False,
@@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_publicised: Whether this should be publicised.
"""
+ content = content or {}
+
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
self.db_pool.simple_delete_txn(
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4f3d192562..b7820ac7ff 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -22,6 +22,9 @@ from synapse.storage.database import DatabasePool
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
"media_repository_drop_index_wo_method"
)
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
+ "media_repository_drop_index_wo_method_2"
+)
class MediaSortOrder(Enum):
@@ -85,23 +88,35 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
unique=True,
)
+ # the original impl of _drop_media_index_without_method was broken (see
+ # https://github.com/matrix-org/synapse/issues/8649), so we replace the original
+ # impl with a no-op and run the fixed migration as
+ # media_repository_drop_index_wo_method_2.
+ self.db_pool.updates.register_noop_background_update(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+ )
self.db_pool.updates.register_background_update_handler(
- BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
self._drop_media_index_without_method,
)
async def _drop_media_index_without_method(self, progress, batch_size):
+ """background update handler which removes the old constraints.
+
+ Note that this is only run on postgres.
+ """
+
def f(txn):
txn.execute(
"ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
)
txn.execute(
- "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
)
await self.db_pool.runInteraction("drop_media_indices_without_method", f)
await self.db_pool.updates._end_background_update(
- BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
)
return 1
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 8db6f1396a..a76e9ae2e7 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -541,13 +541,11 @@ class RoomWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- async def get_ratelimit_for_user(self, user_id):
- """Check if there are any overrides for ratelimiting for the given
- user
+ async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]:
+ """Check if there are any overrides for ratelimiting for the given user
Args:
- user_id (str)
-
+ user_id: user ID of the user
Returns:
RatelimitOverride if there is an override, else None. If the contents
of RatelimitOverride are None or 0 then ratelimitng has been
@@ -569,6 +567,62 @@ class RoomWorkerStore(SQLBaseStore):
else:
return None
+ async def set_ratelimit_for_user(
+ self, user_id: str, messages_per_second: int, burst_count: int
+ ) -> None:
+ """Sets whether a user is set an overridden ratelimit.
+ Args:
+ user_id: user ID of the user
+ messages_per_second: The number of actions that can be performed in a second.
+ burst_count: How many actions can be performed before being limited.
+ """
+
+ def set_ratelimit_txn(txn):
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="ratelimit_override",
+ keyvalues={"user_id": user_id},
+ values={
+ "messages_per_second": messages_per_second,
+ "burst_count": burst_count,
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_ratelimit_for_user, (user_id,)
+ )
+
+ await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn)
+
+ async def delete_ratelimit_for_user(self, user_id: str) -> None:
+ """Delete an overridden ratelimit for a user.
+ Args:
+ user_id: user ID of the user
+ """
+
+ def delete_ratelimit_txn(txn):
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ table="ratelimit_override",
+ keyvalues={"user_id": user_id},
+ retcols=["user_id"],
+ allow_none=True,
+ )
+
+ if not row:
+ return
+
+ # They are there, delete them.
+ self.db_pool.simple_delete_one_txn(
+ txn, "ratelimit_override", keyvalues={"user_id": user_id}
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_ratelimit_for_user, (user_id,)
+ )
+
+ await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
+
@cached()
async def get_retention_policy_for_room(self, room_id):
"""Get the retention policy for a given room.
diff --git a/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres
new file mode 100644
index 0000000000..54c1bca3b1
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres
@@ -0,0 +1,22 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- drop old constraints on remote_media_cache_thumbnails
+--
+-- This was originally part of 57.07, but it was done wrong, per
+-- https://github.com/matrix-org/synapse/issues/8649, so we do it again.
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5911, 'media_repository_drop_index_wo_method_2', '{}', 'remote_media_repository_thumbnails_method_idx');
+
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index a7f371732f..93431efe00 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# FIXME: how should this be cached?
async def get_filtered_current_state_ids(
- self, room_id: str, state_filter: StateFilter = StateFilter.all()
+ self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
@@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Map from type/state_key to event ID.
"""
- where_clause, where_args = state_filter.make_sql_filter_clause()
+ where_clause, where_args = (
+ state_filter or StateFilter.all()
+ ).make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 38adecc78a..b33c93da2d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -67,18 +67,37 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
class UserSortOrder(Enum):
"""
Enum to define the sorting method used when returning users
- with get_users_media_usage_paginate
+ with get_users_paginate in __init__.py
+ and get_users_media_usage_paginate in stats.py
- MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest.
- MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest.
+ Note: this enum cannot be moved to __init__.py, as doing so raises
+ `builtins.ImportError` (`most likely due to a circular import`).
+
+ MEDIA_LENGTH = ordered by size of uploaded media.
+ MEDIA_COUNT = ordered by number of uploaded media.
USER_ID = ordered alphabetically by `user_id`.
+ NAME = ordered alphabetically by `user_id`. This exists for compatibility,
+ as the user ID is returned in the `name` field of the List Users admin API response.
DISPLAYNAME = ordered alphabetically by `displayname`
+ GUEST = ordered by `is_guest`
+ ADMIN = ordered by `admin`
+ DEACTIVATED = ordered by `deactivated`
+ USER_TYPE = ordered alphabetically by `user_type`
+ AVATAR_URL = ordered alphabetically by `avatar_url`
+ SHADOW_BANNED = ordered by `shadow_banned`
"""
MEDIA_LENGTH = "media_length"
MEDIA_COUNT = "media_count"
USER_ID = "user_id"
+ NAME = "name"
DISPLAYNAME = "displayname"
+ GUEST = "is_guest"
+ ADMIN = "admin"
+ DEACTIVATED = "deactivated"
+ USER_TYPE = "user_type"
+ AVATAR_URL = "avatar_url"
+ SHADOW_BANNED = "shadow_banned"
class StatsStore(StateDeltasStore):
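On the wire, the new sorting is driven by the `order_by` and `dir` query parameters of the v2 list-users endpoint (sketch; host and token are placeholders):

    import requests

    resp = requests.get(
        "https://synapse.example.com/_synapse/admin/v2/users",
        headers={"Authorization": "Bearer <admin_access_token>"},
        # "b" sorts descending, "f" (the default) ascending
        params={"order_by": "displayname", "dir": "b", "from": 0, "limit": 10},
    )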
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 1fd333b707..75c09b3687 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all()
+ self, txn, groups, state_filter: Optional[StateFilter] = None
):
+ state_filter = state_filter or StateFilter.all()
+
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 97ec65f757..dfcf89d91c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -15,7 +15,7 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
@@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
async def _get_state_for_groups(
- self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
+ state_filter = state_filter or StateFilter.all()
member_filter, non_member_filter = state_filter.get_member_split()
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index cca839c70f..21db1645d3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -44,14 +44,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
- def supports_tuple_comparison(self) -> bool:
- """
- Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
- """
- ...
-
- @property
- @abc.abstractmethod
def supports_using_any_list(self) -> bool:
"""
Do we support using `a = ANY(?)` and passing a list
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 80a3558aec..dba8cc51d3 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -47,8 +47,8 @@ class PostgresEngine(BaseDatabaseEngine):
self._version = db_conn.server_version
# Are we on a supported PostgreSQL version?
- if not allow_outdated_version and self._version < 90500:
- raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
+ if not allow_outdated_version and self._version < 90600:
+ raise RuntimeError("Synapse requires PostgreSQL 9.6 or above.")
with db_conn.cursor() as txn:
txn.execute("SHOW SERVER_ENCODING")
@@ -130,13 +130,6 @@ class PostgresEngine(BaseDatabaseEngine):
return True
@property
- def supports_tuple_comparison(self):
- """
- Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
- """
- return True
-
- @property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list"""
return True
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index b87e7798da..f4f16456f2 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -57,14 +57,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
return self.module.sqlite_version_info >= (3, 24, 0)
@property
- def supports_tuple_comparison(self):
- """
- Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
- SQLite 3.15+.
- """
- return self.module.sqlite_version_info >= (3, 15, 0)
-
- @property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list"""
return False
@@ -72,8 +64,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def check_database(self, db_conn, allow_outdated_version: bool = False):
if not allow_outdated_version:
version = self.module.sqlite_version_info
- if version < (3, 11, 0):
- raise RuntimeError("Synapse requires sqlite 3.11 or above.")
+ # Synapse is untested against older SQLite versions, and we don't want
+ # to let users upgrade to a version of Synapse with broken support for their
+ # sqlite version, because it risks leaving them with a half-upgraded db.
+ if version < (3, 22, 0):
+ raise RuntimeError("Synapse requires sqlite 3.22 or above.")
def check_new_database(self, txn):
"""Gets called when setting up a brand new database. This allows us to
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6c3c2da520..c7f0b8ccb5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import imp
+import importlib.util
import logging
import os
import re
@@ -454,8 +454,13 @@ def _upgrade_existing_database(
)
module_name = "synapse.storage.v%d_%s" % (v, root_name)
- with open(absolute_path) as python_file:
- module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
+
+ spec = importlib.util.spec_from_file_location(
+ module_name, absolute_path
+ )
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module) # type: ignore
+
logger.info("Running script %s", relative_path)
module.run_create(cur, database_engine) # type: ignore
if not is_empty:
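The `imp.load_source` replacement above is the standard `importlib` recipe; standalone it reads:

    import importlib.util

    def load_module_from_path(module_name: str, path: str):
        spec = importlib.util.spec_from_file_location(module_name, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)  # runs the file's top-level code
        return module

`imp` has been deprecated since Python 3.4, so this also silences a DeprecationWarning and drops a module slated for removal.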
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2e277a21c4..c1c147c62a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -449,7 +449,7 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
- self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+ self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -465,7 +465,7 @@ class StateGroupStorage:
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
- groups, state_filter
+ groups, state_filter or StateFilter.all()
)
state_event_map = await self.stores.main.get_events(
@@ -485,7 +485,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
- self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+ self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -502,7 +502,7 @@ class StateGroupStorage:
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
- groups, state_filter
+ groups, state_filter or StateFilter.all()
)
event_to_state = {
@@ -513,7 +513,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids}
async def get_state_for_event(
- self, event_id: str, state_filter: StateFilter = StateFilter.all()
+ self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
@@ -525,11 +525,13 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event
"""
- state_map = await self.get_state_for_events([event_id], state_filter)
+ state_map = await self.get_state_for_events(
+ [event_id], state_filter or StateFilter.all()
+ )
return state_map[event_id]
async def get_state_ids_for_event(
- self, event_id: str, state_filter: StateFilter = StateFilter.all()
+ self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
@@ -541,11 +543,13 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event
"""
- state_map = await self.get_state_ids_for_events([event_id], state_filter)
+ state_map = await self.get_state_ids_for_events(
+ [event_id], state_filter or StateFilter.all()
+ )
return state_map[event_id]
def _get_state_for_groups(
- self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -558,7 +562,9 @@ class StateGroupStorage:
Returns:
Dict of state group to state map.
"""
- return self.stores.state._get_state_for_groups(groups, state_filter)
+ return self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
async def store_state_group(
self,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d4643c4fdf..32d6cc16b9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
@@ -91,7 +91,14 @@ class StreamIdGenerator:
# ... persist event ...
"""
- def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+ def __init__(
+ self,
+ db_conn,
+ table,
+ column,
+ extra_tables: Iterable[Tuple[str, str]] = (),
+ step=1,
+ ):
assert step != 0
self._lock = threading.Lock()
self._step = step
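Both this `extra_tables` change and the `StateFilter` changes above replace defaults that Python evaluates once, at function-definition time, with a per-call substitute: an `Optional[...] = None` sentinel resolved via `x or default`, or an immutable empty tuple. A minimal, self-contained illustration of the pitfall being avoided (the function names are invented for the example):

from typing import List, Optional

def append_bad(item: int, acc: List[int] = []) -> List[int]:
    # The [] default is built once, when the function is defined, and is
    # then shared by every call that omits `acc`.
    acc.append(item)
    return acc

def append_good(item: int, acc: Optional[List[int]] = None) -> List[int]:
    # A None sentinel means each call that omits `acc` gets a fresh list.
    if acc is None:
        acc = []
    acc.append(item)
    return acc

assert append_bad(1) == [1]
assert append_bad(2) == [1, 2]  # state leaked from the previous call
assert append_good(1) == [1]
assert append_good(2) == [2]    # independent calls stay independent

The `state_filter or StateFilter.all()` spelling additionally treats any falsy filter as absent; that is safe here on the assumption that a real StateFilter instance is always truthy.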
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1adc92eb90..dd392cf694 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -283,7 +283,9 @@ class DeferredCache(Generic[KT, VT]):
# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
- def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
+ def prefill(
+ self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
+ ):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index e15f7ee698..4dc3477e89 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@
import logging
from collections import OrderedDict
+from typing import Any, Generic, Optional, TypeVar, Union, overload
+
+import attr
+from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-SENTINEL = object()
+SENTINEL = object() # type: Any
+
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
-class ExpiringCache:
+
+class ExpiringCache(Generic[KT, VT]):
def __init__(
self,
- cache_name,
- clock,
- max_len=0,
- expiry_ms=0,
- reset_expiry_on_get=False,
- iterable=False,
+ cache_name: str,
+ clock: Clock,
+ max_len: int = 0,
+ expiry_ms: int = 0,
+ reset_expiry_on_get: bool = False,
+ iterable: bool = False,
):
"""
Args:
- cache_name (str): Name of this cache, used for logging.
- clock (Clock)
- max_len (int): Max size of dict. If the dict grows larger than this
+ cache_name: Name of this cache, used for logging.
+ clock: The homeserver Clock, used to timestamp entries and schedule expiry.
+ max_len: Max size of dict. If the dict grows larger than this
then the oldest items get automatically evicted. Default is 0,
which indicates there is no max limit.
- expiry_ms (int): How long before an item is evicted from the cache
+ expiry_ms: How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
- reset_expiry_on_get (bool): If true, will reset the expiry time for
+ reset_expiry_on_get: If true, will reset the expiry time for
an item on access. Defaults to False.
- iterable (bool): If true, the size is calculated by summing the
+ iterable: If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
"""
self._cache_name = cache_name
@@ -62,7 +72,7 @@ class ExpiringCache:
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get
- self._cache = OrderedDict()
+ self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry]
self.iterable = iterable
@@ -79,12 +89,12 @@ class ExpiringCache:
self._clock.looping_call(f, self._expiry_ms / 2)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: KT, value: VT) -> None:
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
self.evict()
- def evict(self):
+ def evict(self) -> None:
# Evict if there are now too many items
while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
@@ -93,7 +103,7 @@ class ExpiringCache:
else:
self.metrics.inc_evictions()
- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
try:
entry = self._cache[key]
self.metrics.inc_hits()
@@ -106,7 +116,7 @@ class ExpiringCache:
return entry.value
- def pop(self, key, default=SENTINEL):
+ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Removes and returns the value with the given key from the cache.
If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ class ExpiringCache:
Identical functionality to `dict.pop(..)`.
"""
- value = self._cache.pop(key, default)
+ value = self._cache.pop(key, SENTINEL)
+ # The key was not found.
if value is SENTINEL:
- raise KeyError(key)
+ if default is SENTINEL:
+ raise KeyError(key)
+ return default
- return value
+ return value.value
- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return key in self._cache
- def get(self, key, default=None):
+ @overload
+ def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+ ...
+
+ @overload
+ def get(self, key: KT, default: T) -> Union[VT, T]:
+ ...
+
+ def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
try:
return self[key]
except KeyError:
return default
- def setdefault(self, key, value):
+ def setdefault(self, key: KT, value: VT) -> VT:
try:
return self[key]
except KeyError:
self[key] = value
return value
- def _prune_cache(self):
+ def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
@@ -166,7 +187,7 @@ class ExpiringCache:
len(self),
)
- def __len__(self):
+ def __len__(self) -> int:
if self.iterable:
return sum(len(entry.value) for entry in self._cache.values())
else:
@@ -190,9 +211,7 @@ class ExpiringCache:
return False
+@attr.s(slots=True)
class _CacheEntry:
- __slots__ = ["time", "value"]
-
- def __init__(self, time, value):
- self.time = time
- self.value = value
+ time = attr.ib(type=int)
+ value = attr.ib()
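Two typing idioms drive this rewrite: a module-level SENTINEL to distinguish "no default supplied" from an explicit default=None, and @overload stubs so that get() only widens to Optional[VT] when no default is passed. A condensed sketch of both, independent of ExpiringCache (TinyCache and its str/int key and value types are invented for the example):

from typing import Any, Dict, Optional, TypeVar, Union, overload

from typing_extensions import Literal

T = TypeVar("T")

_SENTINEL = object()  # type: Any  # private object no caller can pass by accident


class TinyCache:
    def __init__(self) -> None:
        self._data = {}  # type: Dict[str, int]

    def set(self, key: str, value: int) -> None:
        self._data[key] = value

    def pop(self, key: str, default: T = _SENTINEL) -> Union[int, T]:
        # Pop with our own sentinel, not the caller's default, so a miss
        # with no default supplied can still raise KeyError.
        value = self._data.pop(key, _SENTINEL)
        if value is _SENTINEL:
            if default is _SENTINEL:
                raise KeyError(key)
            return default
        return value

    @overload
    def get(self, key: str, default: Literal[None] = None) -> Optional[int]:
        ...

    @overload
    def get(self, key: str, default: T) -> Union[int, T]:
        ...

    def get(self, key: str, default: Optional[T] = None) -> Union[int, Optional[T]]:
        try:
            return self._data[key]
        except KeyError:
            return default


cache = TinyCache()
cache.set("k", 1)
assert cache.pop("k") == 1
assert cache.get("k") is None      # first overload: no default given
assert cache.get("k", 0) == 0      # second overload: typed default

Note the behavioural fix folded into the real pop(): the old code passed the caller's default straight to dict.pop and so returned the internal _CacheEntry wrapper on a hit, whereas the new code unwraps it with value.value.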
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 60bb6ff642..20c8e2d9f5 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -57,12 +57,14 @@ def enumerate_leaves(node, depth):
class _Node:
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
- def __init__(self, prev_node, next_node, key, value, callbacks=set()):
+ def __init__(
+ self, prev_node, next_node, key, value, callbacks: Optional[set] = None
+ ):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
- self.callbacks = callbacks
+ self.callbacks = callbacks or set()
class LruCache(Generic[KT, VT]):
@@ -176,10 +178,10 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
- def add_node(key, value, callbacks=set()):
+ def add_node(key, value, callbacks: Optional[set] = None):
prev_node = list_root
next_node = prev_node.next_node
- node = _Node(prev_node, next_node, key, value, callbacks)
+ node = _Node(prev_node, next_node, key, value, callbacks or set())
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@@ -237,7 +239,7 @@ class LruCache(Generic[KT, VT]):
def cache_get(
key: KT,
default: Optional[T] = None,
- callbacks: Iterable[Callable[[], None]] = [],
+ callbacks: Iterable[Callable[[], None]] = (),
update_metrics: bool = True,
):
node = cache.get(key, None)
@@ -253,7 +255,7 @@ class LruCache(Generic[KT, VT]):
return default
@synchronized
- def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
+ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 1023c856d1..019cfa17cc 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -105,7 +105,13 @@ class Measure:
"start",
]
- def __init__(self, clock, name):
+ def __init__(self, clock, name: str):
+ """
+ Args:
+ clock: An object with a "time()" method, which returns the current
+ time in seconds.
+ name: The name of the metric to report.
+ """
self.clock = clock
self.name = name
curr_context = current_context()
@@ -118,10 +124,8 @@ class Measure:
else:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
- self._logging_context = LoggingContext(
- "Measure[%s]" % (self.name,), parent_context
- )
- self.start = None
+ self._logging_context = LoggingContext(str(curr_context), parent_context)
+ self.start = None # type: Optional[int]
def __enter__(self) -> "Measure":
if self.start is not None:
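For context on the new docstring: Measure is used as a context manager (its __enter__ is visible just above), timing the enclosed block and reporting it under the given name. A hedged usage sketch — timed_block is invented, and in Synapse code the clock normally comes from hs.get_clock():

from twisted.internet import reactor

from synapse.util import Clock
from synapse.util.metrics import Measure


def timed_block():
    clock = Clock(reactor)
    # The enclosed block is timed and reported under the name "timed_block".
    with Measure(clock, "timed_block"):
        total = sum(range(1000))  # stand-in for real work
    return total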
diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py
index c306891b27..b3abc6b254 100644
--- a/synmark/suites/logging.py
+++ b/synmark/suites/logging.py
@@ -16,8 +16,7 @@
import logging
import warnings
from io import StringIO
-
-from mock import Mock
+from unittest.mock import Mock
from pyperf import perf_counter
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 34f72ae795..28d77f0ca2 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import pymacaroons
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester
from tests import unittest
-class TestRatelimiter(unittest.TestCase):
+class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_via_can_do_action(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
- self.assertTrue(allowed)
- self.assertEquals(10.0, time_allowed)
-
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
- self.assertFalse(allowed)
- self.assertEquals(10.0, time_allowed)
-
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
- self.assertTrue(allowed)
- self.assertEquals(20.0, time_allowed)
-
- def test_allowed_user_via_can_requester_do_action(self):
- user_requester = create_requester("@user:example.com")
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=5)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
@@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
)
as_requester = create_requester("@user:example.com", app_service=appservice)
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=5)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
@@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
)
as_requester = create_requester("@user:example.com", app_service=appservice)
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=5)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
def test_allowed_via_ratelimit(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# Shouldn't raise
- limiter.ratelimit(key="test_id", _time_now_s=0)
+ self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
# Should raise
with self.assertRaises(LimitExceededError) as context:
- limiter.ratelimit(key="test_id", _time_now_s=5)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key="test_id", _time_now_s=5)
+ )
self.assertEqual(context.exception.retry_after_ms, 5000)
# Shouldn't raise
- limiter.ratelimit(key="test_id", _time_now_s=10)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key="test_id", _time_now_s=10)
+ )
def test_allowed_via_can_do_action_and_overriding_parameters(self):
"""Test that we can override options of can_do_action that would otherwise fail
an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# First attempt should be allowed
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",),
- _time_now_s=0,
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(
+ None,
+ ("test_id",),
+ _time_now_s=0,
+ )
)
self.assertTrue(allowed)
self.assertEqual(10.0, time_allowed)
# Second attempt, 1s later, will fail
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",),
- _time_now_s=1,
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(
+ None,
+ ("test_id",),
+ _time_now_s=1,
+ )
)
self.assertFalse(allowed)
self.assertEqual(10.0, time_allowed)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",), _time_now_s=1, rate_hz=10.0
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
)
self.assertTrue(allowed)
self.assertEqual(1.1, time_allowed)
# Similarly if we allow a burst of 10 actions
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",), _time_now_s=1, burst_count=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
)
self.assertTrue(allowed)
self.assertEqual(1.0, time_allowed)
@@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
fail an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# First attempt should be allowed
- limiter.ratelimit(key=("test_id",), _time_now_s=0)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
+ )
# Second attempt, 1s later, will fail
with self.assertRaises(LimitExceededError) as context:
- limiter.ratelimit(key=("test_id",), _time_now_s=1)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
+ )
self.assertEqual(context.exception.retry_after_ms, 9000)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
- limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
+ )
# Similarly if we allow a burst of 10 actions
- limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
+ )
def test_pruning(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- limiter.can_do_action(key="test_id_1", _time_now_s=0)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
+ )
self.assertIn("test_id_1", limiter.actions)
- limiter.can_do_action(key="test_id_2", _time_now_s=10)
+ self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
+ )
self.assertNotIn("test_id_1", limiter.actions)
+
+ def test_db_user_override(self):
+ """Test that users that have ratelimiting disabled in the DB aren't
+ ratelimited.
+ """
+ store = self.hs.get_datastore()
+
+ user_id = "@user:test"
+ requester = create_requester(user_id)
+
+ self.get_success(
+ store.db_pool.simple_insert(
+ table="ratelimit_override",
+ values={
+ "user_id": user_id,
+ "messages_per_second": None,
+ "burst_count": None,
+ },
+ desc="test_db_user_override",
+ )
+ )
+
+ limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
+
+ # Shouldn't raise
+ for _ in range(20):
+ self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 467033e201..33a37fe35e 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, patch
+from unittest.mock import Mock, patch
from parameterized import parameterized
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 0bffeb1150..03a7440eec 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 97f8cad0dd..3c27d797fb 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 734a9983e8..c109425671 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -20,6 +20,7 @@ from io import StringIO
import yaml
+from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
@@ -35,9 +36,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_load_fails_if_server_name_missing(self):
self.generate_config_and_remove_lines_containing("server_name")
- with self.assertRaises(Exception):
+ with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.file])
- with self.assertRaises(Exception):
+ with self.assertRaises(ConfigError):
HomeServerConfig.load_or_generate_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self):
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 30fcc4c1bf..a56063315b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
+from unittest.mock import Mock
-from mock import Mock
-
+import attr
import canonicaljson
import signedjson.key
import signedjson.sign
@@ -68,6 +68,11 @@ class MockPerspectiveServer:
signedjson.sign.sign_json(res, self.server_name, self.key)
+@attr.s(slots=True)
+class FakeRequest:
+ id = attr.ib()
+
+
@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
def check_context(self, val, expected):
@@ -89,7 +94,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
first_lookup_deferred = Deferred()
async def first_lookup_fetch(keys_to_fetch):
- self.assertEquals(current_context().request, "context_11")
+ self.assertEquals(current_context().request.id, "context_11")
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
await make_deferred_yieldable(first_lookup_deferred)
@@ -102,9 +107,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher.get_keys.side_effect = first_lookup_fetch
async def first_lookup():
- with LoggingContext("context_11") as context_11:
- context_11.request = "context_11"
-
+ with LoggingContext("context_11", request=FakeRequest("context_11")):
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
)
@@ -130,7 +133,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should block rather than start a second call
async def second_lookup_fetch(keys_to_fetch):
- self.assertEquals(current_context().request, "context_12")
+ self.assertEquals(current_context().request.id, "context_12")
return {
"server10": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
@@ -142,9 +145,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
second_lookup_state = [0]
async def second_lookup():
- with LoggingContext("context_12") as context_12:
- context_12.request = "context_12"
-
+ with LoggingContext("context_12", request=FakeRequest("context_12")):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
@@ -589,10 +590,7 @@ def get_key_id(key):
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
- with LoggingContext("testctx") as ctx:
- # we set the "request" prop to make it easier to follow what's going on in the
- # logs.
- ctx.request = "testctx"
+ with LoggingContext("testctx"):
rv = yield f(*args, **kwargs)
return rv
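FakeRequest above, like the _CacheEntry rewrite in the expiringcache diff, uses attrs in place of a hand-written __slots__ class. A minimal sketch of what the decorator buys (Point is an invented example):

import attr


@attr.s(slots=True)
class Point:
    x = attr.ib(type=int)
    y = attr.ib(type=int)


# Roughly the hand-written equivalent, minus the generated extras:
class PointManual:
    __slots__ = ["x", "y"]

    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y


assert Point(1, 2) == Point(1, 2)  # attrs also generates __eq__ and __repr__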
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
new file mode 100644
index 0000000000..c996ecc221
--- /dev/null
+++ b/tests/events/test_presence_router.py
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from unittest.mock import Mock
+
+import attr
+
+from synapse.api.constants import EduTypes
+from synapse.events.presence_router import PresenceRouter
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, presence, room
+from synapse.types import JsonDict, StreamToken, create_requester
+
+from tests.handlers.test_sync import generate_sync_config
+from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+
+
+@attr.s
+class PresenceRouterTestConfig:
+ users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
+
+
+class PresenceRouterTestModule:
+ def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
+ self._config = config
+ self._module_api = module_api
+
+ async def get_users_for_states(
+ self, state_updates: Iterable[UserPresenceState]
+ ) -> Dict[str, Set[UserPresenceState]]:
+ users_to_state = {
+ user_id: set(state_updates)
+ for user_id in self._config.users_who_should_receive_all_presence
+ }
+ return users_to_state
+
+ async def get_interested_users(
+ self, user_id: str
+ ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+ if user_id in self._config.users_who_should_receive_all_presence:
+ return PresenceRouter.ALL_USERS
+
+ return set()
+
+ @staticmethod
+ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+ """Parse a configuration dictionary from the homeserver config, do
+ some validation and return a typed PresenceRouterTestConfig.
+
+ Args:
+ config_dict: The configuration dictionary.
+
+ Returns:
+ A validated config object.
+ """
+ # Initialise a typed config object
+ config = PresenceRouterTestConfig()
+
+ config.users_who_should_receive_all_presence = config_dict.get(
+ "users_who_should_receive_all_presence"
+ )
+
+ return config
+
+
+class PresenceRouterTestCase(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ presence.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def prepare(self, reactor, clock, homeserver):
+ self.sync_handler = self.hs.get_sync_handler()
+ self.module_api = homeserver.get_module_api()
+
+ @override_config(
+ {
+ "presence": {
+ "presence_router": {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler:test",
+ ]
+ },
+ }
+ },
+ "send_federation": True,
+ }
+ )
+ def test_receiving_all_presence(self):
+ """Test that a user that does not share a room with another other can receive
+ presence for them, due to presence routing.
+ """
+ # Create a user who should receive all presence of others
+ self.presence_receiving_user_id = self.register_user(
+ "presence_gobbler", "monkey"
+ )
+ self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey")
+
+ # And two users who should not have any special routing
+ self.other_user_one_id = self.register_user("other_user_one", "monkey")
+ self.other_user_one_tok = self.login("other_user_one", "monkey")
+ self.other_user_two_id = self.register_user("other_user_two", "monkey")
+ self.other_user_two_tok = self.login("other_user_two", "monkey")
+
+ # Put the other two users in a room with each other
+ room_id = self.helper.create_room_as(
+ self.other_user_one_id, tok=self.other_user_one_tok
+ )
+
+ self.helper.invite(
+ room_id,
+ self.other_user_one_id,
+ self.other_user_two_id,
+ tok=self.other_user_one_tok,
+ )
+ self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok)
+ # User one sends some presence
+ send_presence_update(
+ self,
+ self.other_user_one_id,
+ self.other_user_one_tok,
+ "online",
+ "boop",
+ )
+
+ # Check that the presence receiving user gets user one's presence when syncing
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiving_user_id
+ )
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.other_user_one_id)
+ self.assertEqual(presence_update.state, "online")
+ self.assertEqual(presence_update.status_msg, "boop")
+
+ # Have all three users send presence
+ send_presence_update(
+ self,
+ self.other_user_one_id,
+ self.other_user_one_tok,
+ "online",
+ "user_one",
+ )
+ send_presence_update(
+ self,
+ self.other_user_two_id,
+ self.other_user_two_tok,
+ "online",
+ "user_two",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_id,
+ self.presence_receiving_user_tok,
+ "online",
+ "presence_gobbler",
+ )
+
+ # Check that the presence receiving user gets everyone's presence
+ presence_updates, _ = sync_presence(
+ self, self.presence_receiving_user_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 3)
+
+ # But User One should only get their own and User Two's presence
+ presence_updates, _ = sync_presence(self, self.other_user_one_id)
+ self.assertEqual(len(presence_updates), 2)
+
+ found = False
+ for update in presence_updates:
+ if update.user_id == self.other_user_two_id:
+ self.assertEqual(update.state, "online")
+ self.assertEqual(update.status_msg, "user_two")
+ found = True
+
+ self.assertTrue(found)
+
+ @override_config(
+ {
+ "presence": {
+ "presence_router": {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler1:test",
+ "@presence_gobbler2:test",
+ "@far_away_person:island",
+ ]
+ },
+ }
+ },
+ "send_federation": True,
+ }
+ )
+ def test_send_local_online_presence_to_with_module(self):
+ """Tests that send_local_presence_to_users sends local online presence to a set
+ of specified local and remote users, with a custom PresenceRouter module enabled.
+ """
+ # Create a user who will send presence updates
+ self.other_user_id = self.register_user("other_user", "monkey")
+ self.other_user_tok = self.login("other_user", "monkey")
+
+ # And another two users that will also send out presence updates, as well as receive
+ # their own and everyone else's
+ self.presence_receiving_user_one_id = self.register_user(
+ "presence_gobbler1", "monkey"
+ )
+ self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey")
+ self.presence_receiving_user_two_id = self.register_user(
+ "presence_gobbler2", "monkey"
+ )
+ self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey")
+
+ # Have all three users send some presence updates
+ send_presence_update(
+ self,
+ self.other_user_id,
+ self.other_user_tok,
+ "online",
+ "I'm online!",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_one_tok,
+ "online",
+ "I'm also online!",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_two_id,
+ self.presence_receiving_user_two_tok,
+ "unavailable",
+ "I'm in a meeting!",
+ )
+
+ # Mark each presence-receiving user for receiving all user presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_two_id,
+ ]
+ )
+ )
+
+ # Perform a sync for each user
+
+ # The other user should only receive their own presence
+ presence_updates, _ = sync_presence(self, self.other_user_id)
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.other_user_id)
+ self.assertEqual(presence_update.state, "online")
+ self.assertEqual(presence_update.status_msg, "I'm online!")
+
+ # Whereas both presence receiving users should receive everyone's presence updates
+ presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id)
+ self.assertEqual(len(presence_updates), 3)
+ presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id)
+ self.assertEqual(len(presence_updates), 3)
+
+ # Test that sending to a remote user works
+ remote_user_id = "@far_away_person:island"
+
+ # Note that due to the remote user being in our module's
+ # users_who_should_receive_all_presence config, they would have
+ # received user presence updates already.
+ #
+ # Thus we reset the mock, and try sending all online local user
+ # presence again
+ self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+ # Broadcast local user online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to([remote_user_id])
+ )
+
+ # Check that the expected presence updates were sent
+ expected_users = [
+ self.other_user_id,
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_two_id,
+ ]
+
+ calls = (
+ self.hs.get_federation_transport_client().send_transaction.call_args_list
+ )
+ for call in calls:
+ call_args = call[0]
+ federation_transaction = call_args[0] # type: Transaction
+
+ # Get the sent EDUs in this transaction
+ edus = federation_transaction.get_dict()["edus"]
+
+ for edu in edus:
+ # Make sure we're only checking presence-type EDUs
+ if edu["edu_type"] != EduTypes.Presence:
+ continue
+
+ # EDUs can contain multiple presence updates
+ for presence_update in edu["content"]["push"]:
+ # Check for presence updates that contain the user IDs we're after
+ expected_users.remove(presence_update["user_id"])
+
+ # Ensure that no offline states are being sent out
+ self.assertNotEqual(presence_update["presence"], "offline")
+
+ self.assertEqual(len(expected_users), 0)
+
+
+def send_presence_update(
+ testcase: TestCase,
+ user_id: str,
+ access_token: str,
+ presence_state: str,
+ status_message: Optional[str] = None,
+) -> JsonDict:
+ # Build the presence body
+ body = {"presence": presence_state}
+ if status_message:
+ body["status_msg"] = status_message
+
+ # Update the user's presence state
+ channel = testcase.make_request(
+ "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token
+ )
+ testcase.assertEqual(channel.code, 200)
+
+ return channel.json_body
+
+
+def sync_presence(
+ testcase: TestCase,
+ user_id: str,
+ since_token: Optional[StreamToken] = None,
+) -> Tuple[List[UserPresenceState], StreamToken]:
+ """Perform a sync request for the given user and return the user presence updates
+ they've received, as well as the next_batch token.
+
+ This method assumes testcase.sync_handler points to the homeserver's sync handler.
+
+ Args:
+ testcase: The testcase that is currently being run.
+ user_id: The ID of the user to generate a sync response for.
+ since_token: An optional token indicating the point from which to sync.
+
+ Returns:
+ A tuple containing a list of presence updates, and the sync response's
+ next_batch token.
+ """
+ requester = create_requester(user_id)
+ sync_config = generate_sync_config(requester.user.to_string())
+ sync_result = testcase.get_success(
+ testcase.sync_handler.wait_for_sync_for_user(
+ requester, sync_config, since_token
+ )
+ )
+
+ return sync_result.presence, sync_result.next_batch
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 8186b8ca01..701fa8379f 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 95eac6a5a3..802c5ad299 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,6 +1,5 @@
from typing import List, Tuple
-
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.constants import EventTypes
from synapse.events import EventBase
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ecc3faa572..deb12433cf 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
-
-from mock import Mock
+from unittest.mock import Mock
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index a01fdd0839..32669ae9ce 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -14,8 +14,7 @@
# limitations under the License.
from collections import Counter
-
-from mock import Mock
+from unittest.mock import Mock
import synapse.api.errors
import synapse.handlers.admin
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index d5d3fdd99a..6e325b24ce 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c9f889b511..321c5ba045 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import pymacaroons
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 7975af243c..0444b26798 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.handlers.cas_handler import CasResponse
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index fadec16e13..a8d0cf6603 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import synapse
import synapse.api.errors
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 5e86c5e56b..6915ac0205 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import mock
+from unittest import mock
from signedjson import key as key, sign as sign
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index d7498aa51a..07893302ec 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -16,8 +16,7 @@
# limitations under the License.
import copy
-
-import mock
+from unittest import mock
from synapse.api.errors import SynapseError
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index c7796fb837..8702ee70e0 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -14,10 +14,9 @@
# limitations under the License.
import json
import os
+from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
-from mock import ANY, Mock, patch
-
import pymacaroons
from synapse.handlers.sso import MappingException
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index a98a65ae67..e28e4159eb 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -16,8 +16,7 @@
"""Tests for the password_auth_provider interface"""
from typing import Any, Type, Union
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 77330f59a9..9f16cc65fc 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from mock import Mock, call
+from unittest.mock import Mock, call
from signedjson.key import generate_signing_key
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cbbe7280c7..60f2458c98 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import synapse.types
from synapse.api.errors import AuthError, SynapseError
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 00a0bc5274..c30b414d99 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 30efd43b40..8cfc184fef 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -13,8 +13,7 @@
# limitations under the License.
from typing import Optional
-
-from mock import Mock
+from unittest.mock import Mock
import attr
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e62586142e..8e950f25c5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
- sync_config = self._generate_sync_config(user_id1)
+ sync_config = generate_sync_config(user_id1)
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
@@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = False
- sync_config = self._generate_sync_config(user_id2)
+ sync_config = generate_sync_config(user_id2)
requester = create_requester(user_id2)
e = self.get_failure(
@@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def _generate_sync_config(self, user_id):
- return SyncConfig(
- user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
- filter_collection=DEFAULT_FILTER_COLLECTION,
- is_guest=False,
- request_key="request_key",
- device_id="device_id",
- )
+
+def generate_sync_config(user_id: str) -> SyncConfig:
+ return SyncConfig(
+ user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+ filter_collection=DEFAULT_FILTER_COLLECTION,
+ is_guest=False,
+ request_key="request_key",
+ device_id="device_id",
+ )
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 24e7138196..9fa231a37a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -16,8 +16,7 @@
import json
from typing import Dict
-
-from mock import ANY, Mock, call
+from unittest.mock import ANY, Mock, call
from twisted.internet import defer
from twisted.web.resource import Resource
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index dbe68bb058..67a8e49945 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 3972abb038..e6b20799e5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import Mock
+from typing import Optional
+from unittest.mock import Mock
import treq
from netaddr import IPSet
@@ -180,7 +180,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
_check_logcontext(context)
def _handle_well_known_connection(
- self, client_factory, expected_sni, content, response_headers={}
+ self,
+ client_factory,
+ expected_sni,
+ content,
+ response_headers: Optional[dict] = None,
):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
@@ -202,10 +206,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(
request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
)
- self._send_well_known_response(request, content, headers=response_headers)
+ self._send_well_known_response(request, content, headers=response_headers or {})
return well_known_server
- def _send_well_known_response(self, request, content, headers={}):
+ def _send_well_known_response(
+ self, request, content, headers: Optional[dict] = None
+ ):
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
@@ -213,7 +219,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(request.path, b"/.well-known/matrix/server")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
# send back a response
- for k, v in headers.items():
+ for k, v in (headers or {}).items():
request.setHeader(k, v)
request.write(content)
request.finish()
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index fee2985d35..466ce722d9 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
from twisted.internet.defer import Deferred
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 0ce181a51e..7e2f2a01cc 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -13,8 +13,7 @@
# limitations under the License.
from io import BytesIO
-
-from mock import Mock
+from unittest.mock import Mock
from netaddr import IPSet
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 9c52c8fdca..21c1297171 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from netaddr import IPSet
from parameterized import parameterized
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 45089158ce..f979c96f7c 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -14,8 +14,7 @@
# limitations under the License.
import json
from io import BytesIO
-
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
index a1cf0862d4..cc4cae320d 100644
--- a/tests/http/test_simple_client.py
+++ b/tests/http/test_simple_client.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from netaddr import IPSet
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 48a74e2eee..ecf873e2ab 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -12,15 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import logging
-from io import StringIO
+from io import BytesIO, StringIO
+from unittest.mock import Mock, patch
+
+from twisted.web.server import Request
+from synapse.http.site import SynapseRequest
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin
+from tests.server import FakeChannel
from tests.unittest import TestCase
@@ -120,7 +124,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
- with LoggingContext(request="test"):
+ with LoggingContext("name"):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
@@ -134,4 +138,63 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
- self.assertEqual(log["request"], "test")
+ self.assertEqual(log["request"], "name")
+
+ def test_with_request_context(self):
+ """
+ Information from the logging context request should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ handler.addFilter(LoggingContextFilter())
+ logger = self.get_logger(handler)
+
+ # A full request isn't needed here.
+ site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
+ site.site_tag = "test-site"
+ site.server_version_string = "Server v1"
+ request = SynapseRequest(FakeChannel(site, None))
+ # Call requestReceived to finish instantiating the object.
+ request.content = BytesIO()
+ # Partially skip some of the internal processing of SynapseRequest.
+ request._started_processing = Mock()
+ request.request_metrics = Mock(spec=["name"])
+ with patch.object(Request, "render"):
+ request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
+
+ # Also set the requester to ensure the processing works.
+ request.requester = "@foo:test"
+
+ with LoggingContext(
+ request.get_request_id(), parent_context=request.logcontext
+ ):
+ logger.info("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
+
+ # The terse logger includes additional request information, if possible.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ "request",
+ "ip_address",
+ "site_tag",
+ "requester",
+ "authenticated_entity",
+ "method",
+ "url",
+ "protocol",
+ "user_agent",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertTrue(log["request"].startswith("POST-"))
+ self.assertEqual(log["ip_address"], "127.0.0.1")
+ self.assertEqual(log["site_tag"], "test-site")
+ self.assertEqual(log["requester"], "@foo:test")
+ self.assertEqual(log["authenticated_entity"], "@foo:test")
+ self.assertEqual(log["method"], "POST")
+ self.assertEqual(log["url"], "/_matrix/client/versions")
+ self.assertEqual(log["protocol"], "1.1")
+ self.assertEqual(log["user_agent"], "")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index edacd1b566..349f93560e 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -12,27 +12,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
+from synapse.api.constants import EduTypes
from synapse.events import EventBase
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v1 import login, presence, room
from synapse.types import create_requester
-from tests.unittest import HomeserverTestCase
+from tests.events.test_presence_router import send_presence_update, sync_presence
+from tests.test_utils.event_injection import inject_member_event
+from tests.unittest import FederatingHomeserverTestCase, override_config
-class ModuleApiTestCase(HomeserverTestCase):
+class ModuleApiTestCase(FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
+ presence.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.module_api = homeserver.get_module_api()
self.event_creation_handler = homeserver.get_event_creation_handler()
+ self.sync_handler = homeserver.get_sync_handler()
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
def test_can_register_user(self):
"""Tests that an external module can register a user"""
@@ -205,3 +217,161 @@ class ModuleApiTestCase(HomeserverTestCase):
)
)
self.assertFalse(is_in_public_rooms)
+
+ # The ability to send federation is required by send_local_online_presence_to.
+ @override_config({"send_federation": True})
+ def test_send_local_online_presence_to(self):
+ """Tests that send_local_presence_to_users sends local online presence to local users."""
+ # Create a user who will send presence updates
+ self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
+ self.presence_receiver_tok = self.login("presence_receiver", "monkey")
+
+ # And another user that will send presence updates out
+ self.presence_sender_id = self.register_user("presence_sender", "monkey")
+ self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+ # Put them in a room together so they will receive each other's presence updates
+ room_id = self.helper.create_room_as(
+ self.presence_receiver_id,
+ tok=self.presence_receiver_tok,
+ )
+ self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
+
+ # Presence sender comes online
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Presence receiver should have received it
+ presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.presence_sender_id)
+ self.assertEqual(presence_update.state, "online")
+
+ # Syncing again should result in no presence updates
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 0)
+
+ # Trigger sending local online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiver_id,
+ ]
+ )
+ )
+
+ # Presence receiver should have received online presence again
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.presence_sender_id)
+ self.assertEqual(presence_update.state, "online")
+
+ # Presence sender goes offline
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "offline",
+ "I slink back into the darkness.",
+ )
+
+ # Trigger sending local online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiver_id,
+ ]
+ )
+ )
+
+ # Presence receiver should *not* have received offline state
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 0)
+
+ @override_config({"send_federation": True})
+ def test_send_local_online_presence_to_federation(self):
+ """Tests that send_local_presence_to_users sends local online presence to remote users."""
+ # Create a user who will send presence updates
+ self.presence_sender_id = self.register_user("presence_sender", "monkey")
+ self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+ # And a room they're a part of
+ room_id = self.helper.create_room_as(
+ self.presence_sender_id,
+ tok=self.presence_sender_tok,
+ )
+
+ # Mark them as online
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Make up a remote user to send presence to
+ remote_user_id = "@far_away_person:island"
+
+ # Create a join membership event for the remote user into the room.
+ # This allows presence information to flow from one user to the other.
+ self.get_success(
+ inject_member_event(
+ self.hs,
+ room_id,
+ sender=remote_user_id,
+ target=remote_user_id,
+ membership="join",
+ )
+ )
+
+ # The remote user would have received the existing room members' presence
+ # when they joined the room.
+ #
+ # Thus we reset the mock, and try sending local user online
+ # presence again
+ self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+ # Broadcast local user online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to([remote_user_id])
+ )
+
+ # Check that a presence update was sent as part of a federation transaction
+ found_update = False
+ calls = (
+ self.hs.get_federation_transport_client().send_transaction.call_args_list
+ )
+ for call in calls:
+ call_args = call[0]
+ federation_transaction = call_args[0] # type: Transaction
+
+ # Get the sent EDUs in this transaction
+ edus = federation_transaction.get_dict()["edus"]
+
+ for edu in edus:
+ # Make sure we're only checking presence-type EDUs
+ if edu["edu_type"] != EduTypes.Presence:
+ continue
+
+ # EDUs can contain multiple presence updates
+ for presence_update in edu["content"]["push"]:
+ if presence_update["user_id"] == self.presence_sender_id:
+ found_update = True
+
+ self.assertTrue(found_update)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index a3b304d316..f590e8d21c 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet.defer import Deferred
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 1d4a592862..aff19d9fb3 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource
def make_worker_hs(
- self, worker_app: str, extra_config: dict = {}, **kwargs
+ self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
@@ -283,7 +283,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config = self._get_worker_hs_config()
config["worker_app"] = worker_app
- config.update(extra_config)
+ config.update(extra_config or {})
worker_hs = self.setup_test_homeserver(
homeserver_to_use=GenericWorkerServer,
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 56497b8476..83e89383f6 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from tests.replication._base import BaseStreamTestCase
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0ceb0f935c..db80a0bdbd 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Iterable, Optional
from canonicaljson import encode_canonical_json
@@ -332,15 +333,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
room_id=ROOM_ID,
type="m.room.message",
key=None,
- internal={},
+ internal: Optional[dict] = None,
depth=None,
- prev_events=[],
- auth_events=[],
- prev_state=[],
+ prev_events: Optional[list] = None,
+ auth_events: Optional[list] = None,
+ prev_state: Optional[list] = None,
redacts=None,
- push_actions=[],
- **content
+ push_actions: Iterable = frozenset(),
+ **content,
):
+ prev_events = prev_events or []
+ auth_events = auth_events or []
+ prev_state = prev_state or []
if depth is None:
depth = self.event_id
@@ -369,7 +373,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if redacts is not None:
event_dict["redacts"] = redacts
- event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
+ event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})
self.event_id += 1
state_handler = self.hs.get_state_handler()
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 56b062ecc1..7d848e41ff 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -15,7 +15,7 @@
# type: ignore
-from mock import Mock
+from unittest.mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index ca49d4dd3a..4a0b342264 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.tcp.streams import TypingStream
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 0d9e3bb11d..44ad5eec57 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import mock
+from unittest import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 2f2d117858..8ca595c3ee 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index ab2988a6ba..1f12bde1aa 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index c9b773fbd2..6c2e1674cb 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import patch
+from unittest.mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 057e27372e..4abcbe3f55 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -17,8 +17,7 @@ import json
import os
import urllib.parse
from binascii import unhexlify
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet.defer import Deferred
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index b55160b70a..85f77c0a65 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -16,8 +16,7 @@
import json
import urllib.parse
from typing import List, Optional
-
-from mock import Mock
+from unittest.mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 79a05b519b..a7b600a1d4 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -19,8 +19,7 @@ import json
import urllib.parse
from binascii import unhexlify
from typing import List, Optional
-
-from mock import Mock
+from unittest.mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
@@ -28,7 +27,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeSite, make_request
@@ -467,6 +466,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users"
def prepare(self, reactor, clock, hs):
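+ # The ordering tests below use the datastore to modify user attributes directly.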
+ self.store = hs.get_datastore()
+
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -634,6 +635,26 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ # unknown order_by
+ channel = self.make_request(
+ "GET",
+ self.url + "?order_by=bar",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid search order
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
def test_limit(self):
"""
Testing list of users with limit
@@ -759,6 +780,103 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
+ def test_order_by(self):
+ """
+ Testing the ordering of the user list with the parameter `order_by`
+ """
+
+ user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z")
+ user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y")
+
+ # Modify user1: deactivate and shadow-ban them
+ self.get_success(self.store.set_user_deactivated_status(user1, True))
+ self.get_success(self.store.set_shadow_banned(UserID.from_string(user1), True))
+
+ # Set an avatar URL for all users so that no user has a NULL value, avoiding
+ # a different sort order between SQLite and PostgreSQL
+ self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3"))
+ self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2"))
+ self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1"))
+
+ # order by default (name)
+ self._order_test([self.admin_user, user1, user2], None)
+ self._order_test([self.admin_user, user1, user2], None, "f")
+ self._order_test([user2, user1, self.admin_user], None, "b")
+
+ # order by name
+ self._order_test([self.admin_user, user1, user2], "name")
+ self._order_test([self.admin_user, user1, user2], "name", "f")
+ self._order_test([user2, user1, self.admin_user], "name", "b")
+
+ # order by displayname
+ self._order_test([user2, user1, self.admin_user], "displayname")
+ self._order_test([user2, user1, self.admin_user], "displayname", "f")
+ self._order_test([self.admin_user, user1, user2], "displayname", "b")
+
+ # order by is_guest
+ # equivalent to sorting by ascending name, as there are no guest users here
+ self._order_test([self.admin_user, user1, user2], "is_guest")
+ self._order_test([self.admin_user, user1, user2], "is_guest", "f")
+ self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+ # order by admin
+ self._order_test([user1, user2, self.admin_user], "admin")
+ self._order_test([user1, user2, self.admin_user], "admin", "f")
+ self._order_test([self.admin_user, user1, user2], "admin", "b")
+
+ # order by deactivated
+ self._order_test([self.admin_user, user2, user1], "deactivated")
+ self._order_test([self.admin_user, user2, user1], "deactivated", "f")
+ self._order_test([user1, self.admin_user, user2], "deactivated", "b")
+
+ # order by user_type
+ # equivalent to sorting by ascending name, as there are no special user types here
+ self._order_test([self.admin_user, user1, user2], "user_type")
+ self._order_test([self.admin_user, user1, user2], "user_type", "f")
+ self._order_test([self.admin_user, user1, user2], "user_type", "b")
+
+ # order by shadow_banned
+ self._order_test([self.admin_user, user2, user1], "shadow_banned")
+ self._order_test([self.admin_user, user2, user1], "shadow_banned", "f")
+ self._order_test([user1, self.admin_user, user2], "shadow_banned", "b")
+
+ # order by avatar_url
+ self._order_test([self.admin_user, user2, user1], "avatar_url")
+ self._order_test([self.admin_user, user2, user1], "avatar_url", "f")
+ self._order_test([user1, user2, self.admin_user], "avatar_url", "b")
+
+ def _order_test(
+ self,
+ expected_user_list: List[str],
+ order_by: Optional[str],
+ dir: Optional[str] = None,
+ ):
+ """Request the list of users in a certain order. Assert that order is what
+ we expect
+ Args:
+ expected_user_list: The list of user_id in the order we expect to get
+ back from the server
+ order_by: The type of ordering to give the server
+ dir: The direction of ordering to give the server
+ """
+
+ url = self.url + "?deactivated=true&"
+ if order_by is not None:
+ url += "order_by=%s&" % (order_by,)
+ if dir is not None and dir in ("b", "f"):
+ url += "dir=%s" % (dir,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], len(expected_user_list))
+
+ returned_order = [row["name"] for row in channel.json_body["users"]]
+ self.assertEqual(expected_user_list, returned_order)
+ self._check_fields(channel.json_body["users"])
+
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
@@ -2908,3 +3026,287 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# Ensure the user is shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
self.assertTrue(result.shadow_banned)
+
+
+class RateLimitTestCase(unittest.HomeserverTestCase):
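+ """Tests the /_synapse/admin/v1/users/<user_id>/override_ratelimit admin API."""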
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = (
+ "/_synapse/admin/v1/users/%s/override_ratelimit"
+ % urllib.parse.quote(self.other_user)
+ )
+
+ def test_no_auth(self):
+ """
+ Try to use the ratelimit override API without authentication.
+ """
+ channel = self.make_request("GET", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ channel = self.make_request("POST", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ channel = self.make_request("DELETE", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not local returns a 400
+ """
+ url = (
+ "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
+ )
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Only local users can be ratelimited", channel.json_body["error"]
+ )
+
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Only local users can be ratelimited", channel.json_body["error"]
+ )
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+ # messages_per_second is a string
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"messages_per_second": "string"},
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # messages_per_second is negative
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"messages_per_second": -1},
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # burst_count is a string
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"burst_count": "string"},
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # burst_count is negative
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"burst_count": -1},
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_return_zero_when_null(self):
+ """
+ If values in the database are `null`, the API should return an int `0`.
+ """
+
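+ # Write explicit NULLs into the ratelimit override row.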
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ table="ratelimit_override",
+ keyvalues={"user_id": self.other_user},
+ values={
+ "messages_per_second": None,
+ "burst_count": None,
+ },
+ )
+ )
+
+ # request status
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(0, channel.json_body["messages_per_second"])
+ self.assertEqual(0, channel.json_body["burst_count"])
+
+ def test_success(self):
+ """
+ Rate-limiting (set/update/delete) should succeed for an admin.
+ """
+ # request status
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertNotIn("messages_per_second", channel.json_body)
+ self.assertNotIn("burst_count", channel.json_body)
+
+ # set ratelimit
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"messages_per_second": 10, "burst_count": 11},
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(10, channel.json_body["messages_per_second"])
+ self.assertEqual(11, channel.json_body["burst_count"])
+
+ # update ratelimit
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"messages_per_second": 20, "burst_count": 21},
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(20, channel.json_body["messages_per_second"])
+ self.assertEqual(21, channel.json_body["burst_count"])
+
+ # request status
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(20, channel.json_body["messages_per_second"])
+ self.assertEqual(21, channel.json_body["burst_count"])
+
+ # delete ratelimit
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertNotIn("messages_per_second", channel.json_body)
+ self.assertNotIn("burst_count", channel.json_body)
+
+ # request status
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertNotIn("messages_per_second", channel.json_body)
+ self.assertNotIn("burst_count", channel.json_body)
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index b8285f3240..be1211dbce 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.constants import EventTypes
from synapse.rest import admin
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d2cce44032..288ee12888 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, patch
+from unittest.mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventTypes
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index bf39014277..a7ebe0c3e9 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -14,8 +14,7 @@
# limitations under the License.
import threading
from typing import Dict
-
-from mock import Mock
+from unittest.mock import Mock
from synapse.events import EventBase
from synapse.module_api import ModuleApi
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 171632e195..3b5747cb12 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,4 +1,4 @@
-from mock import Mock, call
+from unittest.mock import Mock, call
from twisted.internet import defer, reactor
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 2ae896db1e..87a18d2cb9 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -15,7 +15,7 @@
""" Tests REST events for /events paths."""
-from mock import Mock
+from unittest.mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import events, login, room
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 988821b16f..c7b79ab8a7 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -16,10 +16,9 @@
import time
import urllib.parse
from typing import Any, Dict, List, Optional, Union
+from unittest.mock import Mock
from urllib.parse import urlencode
-from mock import Mock
-
import pymacaroons
from twisted.web.resource import Resource
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 94a5154834..c136827f79 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index ed65f645fc..4df20c90fd 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,10 +19,10 @@
"""Tests REST events for /rooms paths."""
import json
+from typing import Iterable
+from unittest.mock import Mock
from urllib import parse as urlparse
-from mock import Mock
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
@@ -207,7 +207,9 @@ class RoomPermissionsTestCase(RoomBase):
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- def _test_get_membership(self, room=None, members=[], expect_code=None):
+ def _test_get_membership(
+ self, room=None, members: Iterable = frozenset(), expect_code=None
+ ):
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
channel = self.make_request("GET", path)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 329dbd06de..0b8f565121 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -16,7 +16,7 @@
"""Tests REST events for /rooms paths."""
-from mock import Mock
+from unittest.mock import Mock
from synapse.rest.client.v1 import room
from synapse.types import UserID
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 946740aa5d..a6a292b20c 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -21,8 +21,7 @@ import re
import time
import urllib.parse
from typing import Any, Dict, Mapping, MutableMapping, Optional
-
-from mock import patch
+from unittest.mock import patch
import attr
@@ -132,7 +131,7 @@ class RestHelper:
src: str,
targ: str,
membership: str,
- extra_data: dict = {},
+ extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
) -> None:
@@ -156,7 +155,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
data = {"membership": membership}
- data.update(extra_data)
+ data.update(extra_data or {})
channel = make_request(
self.hs.get_reactor(),
@@ -187,7 +186,13 @@ class RestHelper:
)
def send_event(
- self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
+ self,
+ room_id,
+ type,
+ content: Optional[dict] = None,
+ txn_id=None,
+ tok=None,
+ expect_code=200,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -201,7 +206,7 @@ class RestHelper:
self.site,
"PUT",
path,
- json.dumps(content).encode("utf8"),
+ json.dumps(content or {}).encode("utf8"),
)
assert (
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 9734a2159a..ed433d9333 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Union
+from typing import Optional, Union
from twisted.internet.defer import succeed
@@ -74,7 +74,10 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
return channel
def recaptcha(
- self, session: str, expected_post_response: int, post_session: str = None
+ self,
+ session: str,
+ expected_post_response: int,
+ post_session: Optional[str] = None,
) -> None:
"""Get and respond to a fallback recaptcha. Returns the second request."""
if post_session is None:
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2d4ce871eb..41e52c701f 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import datetime
import json
import os
@@ -28,7 +27,7 @@ import pkg_resources
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import LoginType
+from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
@@ -65,7 +64,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastore().services_cache.append(appservice)
- request_data = json.dumps({"username": "as_user_kermit"})
+ request_data = json.dumps(
+ {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
+ )
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -75,9 +76,31 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body)
+ def test_POST_appservice_registration_no_type(self):
+ as_token = "i_am_an_app_service"
+
+ appservice = ApplicationService(
+ as_token,
+ self.hs.config.server_name,
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
+ )
+
+ self.hs.get_datastore().services_cache.append(appservice)
+ request_data = json.dumps({"username": "as_user_kermit"})
+
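+ # Without a registration "type" the request should be rejected.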
+ channel = self.make_request(
+ b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+ )
+
+ self.assertEquals(channel.result["code"], b"400", channel.result)
+
def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists
- request_data = json.dumps({"username": "kermit"})
+ request_data = json.dumps(
+ {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
+ )
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index e7bb5583fc..21ee436b91 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -16,6 +16,7 @@
import itertools
import json
import urllib
+from typing import Optional
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
@@ -681,7 +682,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
relation_type,
event_type,
key=None,
- content={},
+ content: Optional[dict] = None,
access_token=None,
parent_id=None,
):
@@ -713,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
- json.dumps(content).encode("utf-8"),
+ json.dumps(content or {}).encode("utf-8"),
access_token=access_token,
)
return channel
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 9d0d0ef414..eb8687ce68 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -14,8 +14,7 @@
# limitations under the License.
import urllib.parse
from io import BytesIO, StringIO
-
-from mock import Mock
+from unittest.mock import Mock
import signedjson.key
from canonicaljson import encode_canonical_json
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 9f77125fd4..375f0b7977 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -18,10 +18,9 @@ import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Optional
+from unittest.mock import Mock
from urllib import parse
-from mock import Mock
-
import attr
from parameterized import parameterized_class
from PIL import Image as Image
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 6968502433..9067463e54 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -15,8 +15,7 @@
import json
import os
import re
-
-from mock import patch
+from unittest.mock import patch
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 6f56893f5e..885b95a51f 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse._scripts.register_new_matrix_user import request_registration
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index d40d65b06a..450b4ec710 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1ce29af5fd..e755a4db62 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,8 +15,7 @@
import json
import os
import tempfile
-
-from mock import Mock
+from unittest.mock import Mock
import yaml
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 1b4fae0bb5..069db0edc4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,4 +1,4 @@
-from mock import Mock
+from unittest.mock import Mock
from synapse.storage.background_updates import BackgroundUpdater
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index eac7e4dcd2..54e9e7f6fe 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -15,8 +15,7 @@
from collections import OrderedDict
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7791138688..b02fb32ced 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -14,9 +14,7 @@
# limitations under the License.
import os.path
-from unittest.mock import patch
-
-from mock import Mock
+from unittest.mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventTypes
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 34e6526097..f7f75320ba 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
@@ -390,7 +390,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
+ synapse.rest.admin.register_servlets,
login.register_servlets,
]
@@ -434,7 +434,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
self.reactor,
self.site,
"GET",
- "/_synapse/admin/v1/users/" + self.user_id,
+ "/_synapse/admin/v2/users/" + self.user_id,
access_token=access_token,
custom_headers=headers1.items(),
**make_request_args,
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 5a77c84962..a906d30e73 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -36,17 +36,6 @@ def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
class TupleComparisonClauseTestCase(unittest.TestCase):
def test_native_tuple_comparison(self):
- db_engine = _stub_db_engine(supports_tuple_comparison=True)
- clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)])
+ clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
-
- def test_emulated_tuple_comparison(self):
- db_engine = _stub_db_engine(supports_tuple_comparison=False)
- clause, args = make_tuple_comparison_clause(
- db_engine, [("a", 1), ("b", 2), ("c", 3)]
- )
- self.assertEqual(
- clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))"
- )
- self.assertEqual(args, [1, 1, 2, 2, 3])
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,32 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import synapse.api.errors
-import tests.unittest
-import tests.utils
-
-
-class DeviceStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
+from tests.unittest import HomeserverTestCase
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class DeviceStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_store_new_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res,
)
- @defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+ res = self.get_success(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"],
)
- @defer.inlineCallbacks
def test_count_devices_by_users(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ res = self.get_success(self.store.count_devices_by_users())
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ res = self.get_success(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ res = self.get_success(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.count_devices_by_users(["user_id", "user_id2"])
)
self.assertEqual(3, res)
- @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield defer.ensureDeferred(
+ now_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- @defer.inlineCallbacks
def test_update_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ self.get_success(self.store.update_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
- yield defer.ensureDeferred(
+ self.get_success(
self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2"
)
)
# check it worked
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
- @defer.inlineCallbacks
def test_update_unknown_device(self):
- with self.assertRaises(synapse.api.errors.StoreError) as cm:
- yield defer.ensureDeferred(
- self.store.update_device(
- "user_id", "unknown_device_id", new_display_name="display_name 2"
- )
- )
- self.assertEqual(404, cm.exception.code)
+ exc = self.get_failure(
+ self.store.update_device(
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
+ ),
+ synapse.api.errors.StoreError,
+ )
+ self.assertEqual(404, exc.value.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,28 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import RoomAlias, RoomID
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
-class DirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
-
+class DirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
- @defer.inlineCallbacks
def test_room_to_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
- (
- yield defer.ensureDeferred(
- self.store.get_aliases_for_room(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_alias_to_room(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- ),
+ (self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
- @defer.inlineCallbacks
def test_delete_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
)
- room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
+ room_id = self.get_success(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- )
+ (self.get_success(self.store.get_association_from_room_alias(self.alias)))
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from tests.unittest import HomeserverTestCase
-import tests.unittest
-import tests.utils
-
-class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EndToEndKeyStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
- @defer.inlineCallbacks
def test_reupload_key(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertFalse(changed)
- @defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
- yield defer.ensureDeferred(
- self.store.store_device("user", "device", "display_name")
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
+ self.get_success(self.store.store_device("user", "device", "display_name"))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
- @defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
- yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
+ self.get_success(self.store.store_device("user1", "device1", None))
+ self.get_success(self.store.store_device("user1", "device2", None))
+ self.get_success(self.store.store_device("user2", "device1", None))
+ self.get_success(self.store.store_device("user2", "device2", None))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2"))
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..0289942f88 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
@@ -30,37 +27,31 @@ HIGHLIGHT = [
]
-class EventPushActionsStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventPushActionsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.persist_events_store = hs.get_datastores().persist_events
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_count_aggregation(self):
room_id = "!foo:example.com"
user_id = "@user1235:example.com"
- @defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield defer.ensureDeferred(
+ counts = self.get_success(
self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
@@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
},
)
- @defer.inlineCallbacks
def _inject_actions(stream, action):
event = Mock()
event.room_id = room_id
@@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_push_actions_to_staging(
event.event_id,
{user_id: action},
False,
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
@@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
)
def _mark_read(stream, depth):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
@@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
)
- yield _assert_counts(0, 0)
- yield _inject_actions(1, PlAIN_NOTIF)
- yield _assert_counts(1, 0)
- yield _rotate(2)
- yield _assert_counts(1, 0)
+ _assert_counts(0, 0)
+ _inject_actions(1, PlAIN_NOTIF)
+ _assert_counts(1, 0)
+ _rotate(2)
+ _assert_counts(1, 0)
- yield _inject_actions(3, PlAIN_NOTIF)
- yield _assert_counts(2, 0)
- yield _rotate(4)
- yield _assert_counts(2, 0)
+ _inject_actions(3, PlAIN_NOTIF)
+ _assert_counts(2, 0)
+ _rotate(4)
+ _assert_counts(2, 0)
- yield _inject_actions(5, PlAIN_NOTIF)
- yield _mark_read(3, 3)
- yield _assert_counts(1, 0)
+ _inject_actions(5, PlAIN_NOTIF)
+ _mark_read(3, 3)
+ _assert_counts(1, 0)
- yield _mark_read(5, 5)
- yield _assert_counts(0, 0)
+ _mark_read(5, 5)
+ _assert_counts(0, 0)
- yield _inject_actions(6, PlAIN_NOTIF)
- yield _rotate(7)
+ _inject_actions(6, PlAIN_NOTIF)
+ _rotate(7)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
)
- yield _assert_counts(1, 0)
+ _assert_counts(1, 0)
- yield _mark_read(7, 7)
- yield _assert_counts(0, 0)
+ _mark_read(7, 7)
+ _assert_counts(0, 0)
- yield _inject_actions(8, HIGHLIGHT)
- yield _assert_counts(1, 1)
- yield _rotate(9)
- yield _assert_counts(1, 1)
- yield _rotate(10)
- yield _assert_counts(1, 1)
+ _inject_actions(8, HIGHLIGHT)
+ _assert_counts(1, 1)
+ _rotate(9)
+ _assert_counts(1, 1)
+ _rotate(10)
+ _assert_counts(1, 1)
- @defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_insert(
"events",
{
@@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
# start with the base case where there are no events in the table
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 0)
# now with one event
- yield add_event(2, 10)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(9)
- )
+ add_event(2, 10)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(10)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
(10, 130),
(20, 140),
):
- yield add_event(stream_ordering, ts)
+ add_event(stream_ordering, ts)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(110)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(120)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(129)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(140)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(160)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
self.assertEqual(r, 21)
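+        # (21 is one past the highest stream ordering in the table, since no
+        # event has a timestamp after 160ms)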
# check we can find an event at ordering zero
- yield add_event(0, 5)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(1)
- )
+ add_event(0, 5)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
self.assertEqual(r, 0)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index aad6bc907e..6c389fe9ac 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional
+
from synapse.storage.database import DatabasePool
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -43,7 +45,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers=["master"]
+ self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
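+        # `writers` defaults to None rather than ["master"] so that a mutable
+        # default list is not shared between calls.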
def _create(conn):
return MultiWriterIdGenerator(
@@ -53,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
- writers=writers,
+ writers=writers or ["master"],
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
@@ -476,7 +478,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers=["master"]
+ self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
@@ -486,7 +488,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
- writers=writers,
+ writers=writers or ["master"],
positive=False,
)
@@ -612,7 +614,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers=["master"]
+ self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
@@ -625,7 +627,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
- writers=writers,
+ writers=writers or ["master"],
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 5858c7fcc4..47556791f4 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index b7dde51224..c6256fce86 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,59 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
-
-class ProfileStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class ProfileStoreTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
- @defer.inlineCallbacks
def test_displayname(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
+ self.get_success(
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.u_frank.localpart, None, 2)
+ self.get_success(
+ self.store.set_profile_displayname(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
)
)
- @defer.inlineCallbacks
def test_avatar_url(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
-        yield defer.ensureDeferred(
+        self.get_success(
            self.store.set_profile_avatar_url(
-                self.u_frank.localpart, "http://my.site/here", 1
+                self.u_frank.localpart, "http://my.site/here"
            )
@@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals(
"http://my.site/here",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
- self.store.set_profile_avatar_url(self.u_frank.localpart, None, 2)
+ self.get_success(
+ self.store.set_profile_avatar_url(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2d2f58903c 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
@@ -50,10 +48,15 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1
def inject_room_member(
- self, room, user, membership, replaces_state=None, extra_content={}
+ self,
+ room,
+ user,
+ membership,
+ replaces_state=None,
+ extra_content: Optional[dict] = None,
):
content = {"membership": membership}
- content.update(extra_content)
+ content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -230,10 +233,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder
self._event_id = event_id
- @defer.inlineCallbacks
- def build(self, prev_event_ids, auth_event_ids):
- built_event = yield defer.ensureDeferred(
- self._base_builder.build(prev_event_ids, auth_event_ids)
+ async def build(self, prev_event_ids, auth_event_ids):
+ built_event = await self._base_builder.build(
+ prev_event_ids, auth_event_ids
)
built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RegistrationStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RegistrationStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id = "@my-user:test"
@@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
- @defer.inlineCallbacks
def test_register(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
- "creation_ts": 1000,
+ "creation_ts": 0,
"user_type": None,
"deactivated": 0,
"shadow_banned": 0,
},
- (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
+ (self.get_success(self.store.get_user_by_id(self.user_id))),
)
- @defer.inlineCallbacks
def test_add_tokens(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
- result = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
- @defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
# now delete some
- yield defer.ensureDeferred(
+ self.get_success(
self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
- yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
+ self.get_success(self.store.user_delete_access_tokens(self.user_id))
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertIsNone(user, "access token was not deleted without device_id")
- @defer.inlineCallbacks
def test_is_support_user(self):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = yield defer.ensureDeferred(self.store.is_support_user(None))
+ res = self.get_success(self.store.is_support_user(None))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
+ res = self.get_success(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
+ res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
- @defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
@@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
# Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Unknown session_id", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
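+        # get_failure returns the Failure, so the exception itself is at
+        # e.value.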
+ self.assertEquals(e.value.msg, "Unknown session_id", e)
# Set the config setting to true.
self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Validation token not found or has expired", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,22 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RoomStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RoomStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastore()
@@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
@@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def test_get_room(self):
self.assertDictContainsSubset(
{
@@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+ (self.get_success(self.store.get_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone(
- (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
- )
+        self.assertIsNone(self.get_success(self.store.get_room("!unknown:test")))
- @defer.inlineCallbacks
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
{
@@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats("!uknown:test")
- )
- ),
+            (self.get_success(self.store.get_room_with_stats("!unknown:test"))),
)
-class RoomEventsStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = setup_test_homeserver(self.addCleanup)
-
+class RoomEventsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
@@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield defer.ensureDeferred(
+ self.get_success(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
- @defer.inlineCallbacks
def STALE_test_room_name(self):
name = "A-Room-Name"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
@@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
state[0],
)
- @defer.inlineCallbacks
def STALE_test_room_topic(self):
topic = "A place for things"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 2471f1267d..f06b452fa9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,24 +15,18 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
-class StateStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class StateStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_datastore = self.storage.state.stores.state
@@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
- event, context = yield defer.ensureDeferred(
+ event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- yield defer.ensureDeferred(
- self.storage.persistence.persist_event(event, context)
- )
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- @defer.inlineCallbacks
def test_get_state_groups_ids(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- @defer.inlineCallbacks
def test_get_state_groups(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- @defer.inlineCallbacks
def test_get_state_for_event(self):
        # this produces a linear DAG, as each new injection's prev_events
        # default to whatever forward extremities are currently in the DB for
        # this room.
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- e3 = yield self.inject_state_event(
+ e3 = self.inject_state_event(
self.room,
self.u_alice,
EventTypes.Member,
self.u_alice.to_string(),
{"membership": Membership.JOIN},
)
- e4 = yield self.inject_state_event(
+ e4 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
self.u_bob.to_string(),
{"membership": Membership.JOIN},
)
- e5 = yield self.inject_state_event(
+ e5 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
@@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield defer.ensureDeferred(
- self.storage.state.get_state_for_event(e5.event_id)
- )
+ state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
self.assertIsNotNone(e4)
@@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
@@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
@@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield defer.ensureDeferred(
+ group_ids = self.get_success(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase, override_config
ALICE = "@alice:a"
BOB = "@bob:b"
@@ -25,73 +22,52 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a"
-class UserDirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = self.hs.get_datastore()
+class UserDirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(ALICE, "alice", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOB, "bob", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BELA, "Bela", None)
- )
- yield defer.ensureDeferred(
- self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
- )
+ self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
+ self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
+ self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
- @defer.inlineCallbacks
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
)
- @defer.inlineCallbacks
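+    # override_config applies the given config overrides for this test only,
+    # replacing the manual toggling (and try/finally restore) of the setting.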
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_all_users(self):
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(2, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BOB, "display_name": "bob", "avatar_url": None},
- )
- self.assertDictEqual(
- r["results"][1],
- {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+ )
+ self.assertDictEqual(
+ r["results"][1],
+ {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+ )
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self):
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
usually be ignored in full text searches.
"""
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(1, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index b57f36e6ac..6a6cf709f6 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, patch
+from unittest.mock import Mock, patch
from synapse.util.distributor import Distributor
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 3f2691ee6b..b5f18344dc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -207,6 +207,226 @@ class EventAuthTestCase(unittest.TestCase):
do_sig_check=False,
)
+ def test_join_rules_public(self):
+ """
+ Test joining a public room.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
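+        # The minimal auth state: the room's create event, the creator's own
+        # join, and the room's join rules.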
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "public"),
+ }
+
+ # Check join.
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left can re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ def test_join_rules_invite(self):
+ """
+ Test joining an invite only room.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "invite"),
+ }
+
+ # A join without an invite is rejected.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left cannot re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ def test_join_rules_msc3083_restricted(self):
+ """
+ Test joining a restricted room from MSC3083.
+
+ This is pretty much the same test as public.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
+ }
+
+ # Older room versions don't understand this join rule
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Check join.
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left can re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
# helpers for making events
@@ -225,19 +445,24 @@ def _create_event(user_id):
)
-def _join_event(user_id):
+def _member_event(user_id, membership, sender=None):
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
"type": "m.room.member",
- "sender": user_id,
+ "sender": sender or user_id,
"state_key": user_id,
- "content": {"membership": "join"},
+ "content": {"membership": membership},
+ "prev_events": [],
}
)
+def _join_event(user_id):
+ return _member_event(user_id, "join")
+
+
def _power_levels_event(sender, content):
return make_event_from_dict(
{
@@ -277,6 +502,21 @@ def _random_state_event(sender):
)
+def _join_rules_event(sender, join_rule):
+ return make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.join_rules",
+ "sender": sender,
+ "state_key": "",
+ "content": {
+ "join_rule": join_rule,
+ },
+ }
+ )
+
+
event_count = 0
diff --git a/tests/test_federation.py b/tests/test_federation.py
index fc9aab32d0..382cedbd5d 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet.defer import succeed
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- with LoggingContext():
+ with LoggingContext("test-context"):
failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 75d28a42df..7d92a16a8d 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -15,9 +15,7 @@
"""Tests REST events for /rooms paths."""
-import json
-
-from synapse.api.constants import LoginType
+from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha import register, sync
@@ -113,7 +111,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
)
)
- self.create_user("as_kermit4", token=as_token)
+ self.create_user("as_kermit4", token=as_token, appservice=True)
def test_allowed_after_a_month_mau(self):
# Create and sync so that the MAU counts get updated
@@ -232,14 +230,15 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count))
- def create_user(self, localpart, token=None):
- request_data = json.dumps(
- {
- "username": localpart,
- "password": "monkey",
- "auth": {"type": LoginType.DUMMY},
- }
- )
+ def create_user(self, localpart, token=None, appservice=False):
+ request_data = {
+ "username": localpart,
+ "password": "monkey",
+ "auth": {"type": LoginType.DUMMY},
+ }
+
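+        # Appservice registrations are identified by the appservice
+        # registration type ("m.login.application_service") in the body.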
+ if appservice:
+ request_data["type"] = APP_SERVICE_REGISTRATION_TYPE
channel = self.make_request(
"POST",
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index e7aed092c2..0f800a075b 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -14,8 +14,7 @@
# limitations under the License.
import resource
-
-import mock
+from unittest import mock
from synapse.app.phone_stats_home import phone_stats_home
diff --git a/tests/test_state.py b/tests/test_state.py
index 6227a3ba95..0d626f49f6 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -12,8 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
+from typing import List, Optional
+from unittest.mock import Mock
from twisted.internet import defer
@@ -37,8 +37,8 @@ def create_event(
state_key=None,
depth=2,
event_id=None,
- prev_events=[],
- **kwargs
+ prev_events: Optional[List[str]] = None,
+ **kwargs,
):
global _next_event_id
@@ -58,7 +58,7 @@ def create_event(
"sender": "@user_id:example.com",
"room_id": "!room_id:example.com",
"depth": depth,
- "prev_events": prev_events,
+ "prev_events": prev_events or [],
}
if state_key is not None:
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index a743cdc3a9..0df480db9f 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -13,8 +13,7 @@
# limitations under the License.
import json
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactorClock
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 43898d8142..b557ffd692 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -21,8 +21,7 @@ import sys
import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
-
-from mock import Mock
+from unittest.mock import Mock
import attr
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index c3c4a93e1f..3dfbf8f8a9 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -33,7 +33,7 @@ async def inject_member_event(
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
- **kwargs
+ **kwargs,
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
@@ -58,7 +58,7 @@ async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs
+ **kwargs,
) -> EventBase:
"""Inject a generic event into a room
@@ -83,7 +83,7 @@ async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs
+ **kwargs,
) -> Tuple[EventBase, EventContext]:
if room_version is None:
room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 510b630114..e502ac197e 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import Mock
+from typing import Optional
+from unittest.mock import Mock
from twisted.internet import defer
from twisted.internet.defer import succeed
@@ -147,9 +147,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
return event
@defer.inlineCallbacks
- def inject_room_member(self, user_id, membership="join", extra_content={}):
+ def inject_room_member(
+ self, user_id, membership="join", extra_content: Optional[dict] = None
+ ):
content = {"membership": membership}
- content.update(extra_content)
+ content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
diff --git a/tests/unittest.py b/tests/unittest.py
index 58a4daa1ec..92764434bd 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -21,8 +21,7 @@ import inspect
import logging
import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
-
-from mock import Mock, patch
+from unittest.mock import Mock, patch
from canonicaljson import json
@@ -471,7 +470,7 @@ class HomeserverTestCase(TestCase):
kwargs["config"] = config_obj
async def run_bg_updates():
- with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+ with LoggingContext("run_bg_updates"):
while not await stor.db_pool.updates.has_completed_background_updates():
await stor.db_pool.updates.do_next_background_update(1)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index afb11b9caf..8c082e7432 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -15,8 +15,7 @@
# limitations under the License.
import logging
from typing import Set
-
-import mock
+from unittest import mock
from twisted.internet import defer, reactor
@@ -232,8 +231,7 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup():
- with LoggingContext() as c1:
- c1.name = "c1"
+ with LoggingContext("c1") as c1:
r = yield obj.fn(1)
self.assertEqual(current_context(), c1)
return r
@@ -275,8 +273,7 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup():
- with LoggingContext() as c1:
- c1.name = "c1"
+ with LoggingContext("c1") as c1:
try:
d = obj.fn(1)
self.assertEqual(
@@ -661,14 +658,13 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1")
async def list_fn(self, args1, arg2):
- assert current_context().request == "c1"
+ assert current_context().name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
- assert current_context().request == "c1"
+ assert current_context().name == "c1"
return self.mock(args1, arg2)
- with LoggingContext() as c1:
- c1.request = "c1"
+ with LoggingContext("c1") as c1:
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index 816795c136..23018081e5 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.util.caches.ttlcache import TTLCache
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 2012263184..d1372f6bc2 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -16,8 +16,7 @@
import threading
from io import StringIO
-
-from mock import NonCallableMock
+from unittest.mock import NonCallableMock
from twisted.internet import defer, reactor
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 58ee918f65..5d9c4665aa 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -17,11 +17,10 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
- self.assertEquals(current_context().request, value)
+ self.assertEquals(current_context().name, value)
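+        # (the context's name is now set via the LoggingContext constructor,
+        # rather than by assigning to a `request` attribute afterwards)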
def test_with_context(self):
- with LoggingContext() as context_one:
- context_one.request = "test"
+ with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
@@ -30,15 +29,13 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
- with LoggingContext() as competing_context:
- competing_context.request = "competing"
+ with LoggingContext("competing"):
yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
- with LoggingContext() as context_one:
- context_one.request = "one"
+ with LoggingContext("one"):
yield clock.sleep(0)
self._check_test_key("one")
@@ -47,9 +44,7 @@ class LoggingContextTestCase(unittest.TestCase):
callback_completed = [False]
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
# fire off function, but don't wait on it.
d2 = run_in_background(function)
@@ -133,9 +128,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@@ -149,9 +142,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_make_deferred_yieldable_with_chained_deferreds(self):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@@ -166,9 +157,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred"""
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable("bum")
self._check_test_key("one")
@@ -177,9 +166,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
def test_nested_logging_context(self):
- with LoggingContext(request="foo"):
+ with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
- self.assertEqual(nested_context.request, "foo-bar")
+ self.assertEqual(nested_context.name, "foo-bar")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
@@ -193,9 +182,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
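
The `test_nested_logging_context` hunk above also shows the renamed attribute in action: a nested context derives its `name` from the parent's name plus the suffix. A minimal sketch, assuming a Synapse checkout with this patch applied:

```python
# Illustrative only: assumes a patched Synapse checkout.
from synapse.logging.context import LoggingContext, nested_logging_context

with LoggingContext("foo"):
    nested = nested_logging_context(suffix="bar")
    # the nested context's name is "<parent name>-<suffix>"
    assert nested.name == "foo-bar"
```
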
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index a739a6aaaf..ce4f1cc30a 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 4d1aee91d5..3fed55090a 100644
--- a/tests/util/test_ratelimitutils.py
+++ b/tests/util/test_ratelimitutils.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
from synapse.config.homeserver import HomeServerConfig
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -89,9 +91,9 @@ def _await_resolution(reactor, d):
return (reactor.seconds() - start_time) * 1000
-def build_rc_config(settings={}):
+def build_rc_config(settings: Optional[dict] = None):
config_dict = default_config("test")
- config_dict.update(settings)
+ config_dict.update(settings or {})
config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
return config.rc_federation
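
The `build_rc_config` change fixes a classic Python pitfall: a mutable default argument is evaluated once, at definition time, and then shared across every call. The hypothetical `bad`/`good` functions below (not part of this patch) illustrate why the `Optional[dict] = None` form is safer:

```python
# Hypothetical example, not from the patch: why `settings={}` is risky.
from typing import Optional

def bad(settings={}):
    settings.setdefault("count", 0)
    settings["count"] += 1
    return settings["count"]

assert bad() == 1
assert bad() == 2  # the same dict leaks state between calls

def good(settings: Optional[dict] = None):
    merged = dict(settings or {})  # fresh dict on every call
    merged.setdefault("count", 0)
    merged["count"] += 1
    return merged["count"]

assert good() == 1
assert good() == 1
```
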
diff --git a/tests/utils.py b/tests/utils.py
index 5d299f766f..65d7ad58d9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,10 +21,9 @@ import time
import uuid
import warnings
from typing import Type
+from unittest.mock import Mock, patch
from urllib import parse as urlparse
-from mock import Mock, patch
-
from twisted.internet import defer
from synapse.api.constants import EventTypes
@@ -122,7 +121,6 @@ def default_config(name, parse=False):
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
"trusted_third_party_id_servers": [],
- "room_invite_state_types": [],
"password_providers": [],
"worker_replication_url": "",
"worker_app": None,
@@ -198,7 +196,7 @@ def setup_test_homeserver(
config=None,
reactor=None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs
+ **kwargs,
):
"""
Set up a homeserver suitable for running tests against. Keyword arguments
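
The trailing comma added after `**kwargs` is only valid syntax in a function definition from Python 3.6 onwards (bpo-9232), which this changeset can now rely on given that the tox.ini diff below drops the py35 environment. A small illustration with a hypothetical function:

```python
# Hypothetical example: a trailing comma after **kwargs in a definition
# is a SyntaxError on Python 3.5 and earlier, but fine on 3.6+.
def make_thing(
    name,
    config=None,
    **kwargs,
):
    return name, config, kwargs

assert make_thing("test", x=1) == ("test", None, {"x": 1})
```
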
diff --git a/tox.ini b/tox.ini
index 5365939e10..8e27efaebe 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,8 @@
[tox]
-envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort
+envlist = packaging, py36, py37, py38, py39, check_codestyle, check_isort
+
+# we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208
+minversion = 2.3.2
[base]
deps =
@@ -48,6 +51,7 @@ deps =
extras =
 # install the optional dependencies for tox environments without
# '-noextras' in their name
+ # (this requires tox 3)
!noextras: all
test
@@ -74,8 +78,6 @@ commands =
# we use "env" rather than putting a value in `setenv` so that it is not
# inherited by other tox environments.
#
- # keep this in sync with the copy in `testenv:py3-old`.
- #
/usr/bin/env COVERAGE_PROCESS_START={toxinidir}/.coveragerc "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
# As of twisted 16.4, trial tries to import the tests as a package (previously
@@ -121,11 +123,7 @@ commands =
# Install Synapse itself. This won't update any libraries.
pip install -e ".[test]"
- # we have to duplicate the command from `testenv` rather than refer to it
- # as `{[testenv]commands}`, because we run on ubuntu xenial, which has
- # tox 2.3.1, and https://github.com/tox-dev/tox/issues/208.
- #
- /usr/bin/env COVERAGE_PROCESS_START={toxinidir}/.coveragerc "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
+ {[testenv]commands}
[testenv:benchmark]
deps =
@@ -137,7 +135,8 @@ commands =
python -m synmark {posargs:}
[testenv:packaging]
-skip_install=True
+skip_install = true
+usedevelop = false
deps =
check-manifest
commands =
@@ -155,7 +154,8 @@ extras = lint
commands = isort -c --df --sp setup.cfg {[base]lint_targets}
[testenv:check-newsfragment]
-skip_install = True
+skip_install = true
+usedevelop = false
deps = towncrier>=18.6.0rc1
commands =
python -m towncrier.check --compare-with=origin/dinsic
@@ -164,7 +164,8 @@ commands =
commands = {toxinidir}/scripts-dev/generate_sample_config --check
[testenv:combine]
-skip_install = True
+skip_install = true
+usedevelop = false
deps =
coverage
pip>=10 ; python_version >= '3.6'
@@ -174,14 +175,16 @@ commands=
coverage report
[testenv:cov-erase]
-skip_install = True
+skip_install = true
+usedevelop = false
deps =
coverage
commands=
coverage erase
[testenv:cov-html]
-skip_install = True
+skip_install = true
+usedevelop = false
deps =
coverage
commands=