diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist
index fd98cbbaf6..5975cb98cf 100644
--- a/.buildkite/worker-blacklist
+++ b/.buildkite/worker-blacklist
@@ -1,41 +1,10 @@
# This file serves as a blacklist for SyTest tests that we expect will fail in
# Synapse when run under worker mode. For more details, see sytest-blacklist.
-Message history can be paginated
-
Can re-join room if re-invited
-The only membership state included in an initial sync is for all the senders in the timeline
-
-Local device key changes get to remote servers
-
-If remote user leaves room we no longer receive device updates
-
-Forgotten room messages cannot be paginated
-
-Inbound federation can get public room list
-
-Members from the gap are included in gappy incr LL sync
-
-Leaves are present in non-gapped incremental syncs
-
-Old leaves are present in gapped incremental syncs
-
-User sees updates to presence from other users in the incremental sync.
-
-Gapped incremental syncs include all state changes
-
-Old members are included in gappy incr LL sync if they start speaking
-
# new failures as of https://github.com/matrix-org/sytest/pull/732
Device list doesn't change if remote server is down
-Remote servers cannot set power levels in rooms without existing powerlevels
-Remote servers should reject attempts by non-creators to set the power levels
# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
Server correctly handles incoming m.device_list_update
-
-# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
-Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
-
-Can get rooms/{roomId}/members at a given point
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5bd2ab2b76..b10cbedd6d 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1,22 +1,36 @@
-version: 2
+version: 2.1
jobs:
dockerhubuploadrelease:
- machine: true
+ docker:
+ - image: docker:git
steps:
- checkout
- - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} .
+ - setup_remote_docker
+ - docker_prepare
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
- - run: docker push matrixdotorg/synapse:${CIRCLE_TAG}
+ - docker_build:
+ tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
+ platforms: linux/amd64
+ - docker_build:
+ tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
+ platforms: linux/amd64,linux/arm/v7,linux/arm64
+
dockerhubuploadlatest:
- machine: true
+ docker:
+ - image: docker:git
steps:
- checkout
- - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest .
+ - setup_remote_docker
+ - docker_prepare
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
- - run: docker push matrixdotorg/synapse:latest
+ - docker_build:
+ tag: -t matrixdotorg/synapse:latest
+ platforms: linux/amd64
+ - docker_build:
+ tag: -t matrixdotorg/synapse:latest
+ platforms: linux/amd64,linux/arm/v7,linux/arm64
workflows:
- version: 2
build:
jobs:
- dockerhubuploadrelease:
@@ -29,3 +43,33 @@ workflows:
filters:
branches:
only: master
+
+commands:
+ docker_prepare:
+ description: Downloads the buildx cli plugin and enables multiarch images
+ parameters:
+ buildx_version:
+ type: string
+ default: "v0.4.1"
+ steps:
+ - run: apk add --no-cache curl
+ - run: mkdir -vp ~/.docker/cli-plugins/ ~/dockercache
+ - run: curl --silent -L "https://github.com/docker/buildx/releases/download/<< parameters.buildx_version >>/buildx-<< parameters.buildx_version >>.linux-amd64" > ~/.docker/cli-plugins/docker-buildx
+ - run: chmod a+x ~/.docker/cli-plugins/docker-buildx
+ # install qemu links in /proc/sys/fs/binfmt_misc on the docker instance running the circleci job
+ - run: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes
+ # create a context named `builder` for the builds
+ - run: docker context create builder
+ # create a buildx builder using the new context, and set it as the default
+ - run: docker buildx create builder --use
+
+ docker_build:
+ description: Builds and pushed images to dockerhub using buildx
+ parameters:
+ platforms:
+ type: string
+ default: linux/amd64
+ tag:
+ type: string
+ steps:
+ - run: docker buildx build -f docker/Dockerfile --push --platform << parameters.platforms >> --label gitsha1=${CIRCLE_SHA1} << parameters.tag >> --progress=plain .
diff --git a/README.rst b/README.rst
index 4a189c8bc4..d609b4b62e 100644
--- a/README.rst
+++ b/README.rst
@@ -1,10 +1,6 @@
-================
-Synapse |shield|
-================
-
-.. |shield| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix
- :alt: (get support on #synapse:matrix.org)
- :target: https://matrix.to/#/#synapse:matrix.org
+=========================================================
+Synapse |support| |development| |license| |pypi| |python|
+=========================================================
.. contents::
@@ -290,19 +286,6 @@ Testing with SyTest is recommended for verifying that changes related to the
Client-Server API are functioning correctly. See the `installation instructions
<https://github.com/matrix-org/sytest#installing>`_ for details.
-Building Internal API Documentation
-===================================
-
-Before building internal API documentation install sphinx and
-sphinxcontrib-napoleon::
-
- pip install sphinx
- pip install sphinxcontrib-napoleon
-
-Building internal API documentation::
-
- python setup.py build_sphinx
-
Troubleshooting
===============
@@ -387,3 +370,23 @@ something like the following in their logs::
This is normally caused by a misconfiguration in your reverse-proxy. See
`<docs/reverse_proxy.md>`_ and double-check that your settings are correct.
+
+.. |support| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix
+ :alt: (get support on #synapse:matrix.org)
+ :target: https://matrix.to/#/#synapse:matrix.org
+
+.. |development| image:: https://img.shields.io/matrix/synapse-dev:matrix.org?label=development&logo=matrix
+ :alt: (discuss development on #synapse-dev:matrix.org)
+ :target: https://matrix.to/#/#synapse-dev:matrix.org
+
+.. |license| image:: https://img.shields.io/github/license/matrix-org/synapse
+ :alt: (check license in LICENSE file)
+ :target: LICENSE
+
+.. |pypi| image:: https://img.shields.io/pypi/v/matrix-synapse
+ :alt: (latest version released on PyPi)
+ :target: https://pypi.org/project/matrix-synapse
+
+.. |python| image:: https://img.shields.io/pypi/pyversions/matrix-synapse
+ :alt: (supported python versions)
+ :target: https://pypi.org/project/matrix-synapse
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 49e86e628f..5a68312217 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -75,6 +75,23 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.22.0
+====================
+
+ThirdPartyEventRules breaking changes
+-------------------------------------
+
+This release introduces a backwards-incompatible change to modules making use of
+``ThirdPartyEventRules`` in Synapse. If you make use of a module defined under the
+``third_party_event_rules`` config option, please make sure it is updated to handle
+the below change:
+
+The ``http_client`` argument is no longer passed to modules as they are initialised. Instead,
+modules are expected to make use of the ``http_client`` property on the ``ModuleApi`` class.
+Modules are now passed a ``module_api`` argument during initialisation, which is an instance of
+``ModuleApi``. ``ModuleApi`` instances have a ``http_client`` property which acts the same as
+the ``http_client`` argument previously passed to ``ThirdPartyEventRules`` modules.
+
Upgrading to v1.21.0
====================
diff --git a/changelog.d/7658.feature b/changelog.d/7658.feature
new file mode 100644
index 0000000000..fbf345988d
--- /dev/null
+++ b/changelog.d/7658.feature
@@ -0,0 +1 @@
+Add a configuration option for always using the "userinfo endpoint" for OpenID Connect. This fixes support for some identity providers, e.g. GitLab. Contributed by Benjamin Koch.
diff --git a/changelog.d/7921.docker b/changelog.d/7921.docker
new file mode 100644
index 0000000000..7cecd67c6a
--- /dev/null
+++ b/changelog.d/7921.docker
@@ -0,0 +1 @@
+Added multi-arch support (arm64,arm/v7) for the docker images. Contributed by @maquis196.
diff --git a/changelog.d/8292.feature b/changelog.d/8292.feature
new file mode 100644
index 0000000000..6d0335e2c8
--- /dev/null
+++ b/changelog.d/8292.feature
@@ -0,0 +1 @@
+Allow `ThirdPartyEventRules` modules to query and manipulate whether a room is in the public rooms directory.
\ No newline at end of file
diff --git a/changelog.d/8312.feature b/changelog.d/8312.feature
new file mode 100644
index 0000000000..222a1b032a
--- /dev/null
+++ b/changelog.d/8312.feature
@@ -0,0 +1 @@
+Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).
\ No newline at end of file
diff --git a/changelog.d/8369.feature b/changelog.d/8369.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8369.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8380.feature b/changelog.d/8380.feature
new file mode 100644
index 0000000000..05ccea19dc
--- /dev/null
+++ b/changelog.d/8380.feature
@@ -0,0 +1 @@
+Add support for device dehydration ([MSC2697](https://github.com/matrix-org/matrix-doc/pull/2697)).
diff --git a/changelog.d/8390.docker b/changelog.d/8390.docker
new file mode 100644
index 0000000000..f71b8e4bbf
--- /dev/null
+++ b/changelog.d/8390.docker
@@ -0,0 +1 @@
+Add support for passing commandline args to the synapse process. Contributed by @samuel-p.
diff --git a/changelog.d/8407.misc b/changelog.d/8407.misc
new file mode 100644
index 0000000000..d37002d75b
--- /dev/null
+++ b/changelog.d/8407.misc
@@ -0,0 +1 @@
+Add typing information to the device handler.
diff --git a/changelog.d/8432.misc b/changelog.d/8432.misc
new file mode 100644
index 0000000000..01fdad4caf
--- /dev/null
+++ b/changelog.d/8432.misc
@@ -0,0 +1 @@
+Check for unreachable code with mypy.
diff --git a/changelog.d/8433.misc b/changelog.d/8433.misc
new file mode 100644
index 0000000000..05f8b5bbf4
--- /dev/null
+++ b/changelog.d/8433.misc
@@ -0,0 +1 @@
+Add unit test for event persister sharding.
diff --git a/changelog.d/8439.misc b/changelog.d/8439.misc
new file mode 100644
index 0000000000..237cb3b311
--- /dev/null
+++ b/changelog.d/8439.misc
@@ -0,0 +1 @@
+Allow events to be sent to clients sooner when using sharded event persisters.
diff --git a/changelog.d/8443.misc b/changelog.d/8443.misc
new file mode 100644
index 0000000000..633598e6b3
--- /dev/null
+++ b/changelog.d/8443.misc
@@ -0,0 +1 @@
+Configure `public_baseurl` when using demo scripts.
diff --git a/changelog.d/8448.misc b/changelog.d/8448.misc
new file mode 100644
index 0000000000..5ddda1803b
--- /dev/null
+++ b/changelog.d/8448.misc
@@ -0,0 +1 @@
+Add SQL logging on queries that happen during startup.
diff --git a/changelog.d/8450.misc b/changelog.d/8450.misc
new file mode 100644
index 0000000000..4e04c523ab
--- /dev/null
+++ b/changelog.d/8450.misc
@@ -0,0 +1 @@
+Speed up unit tests when using PostgreSQL.
diff --git a/changelog.d/8452.misc b/changelog.d/8452.misc
new file mode 100644
index 0000000000..8288d91c78
--- /dev/null
+++ b/changelog.d/8452.misc
@@ -0,0 +1 @@
+Remove redundant databae loads of stream_ordering for events we already have.
diff --git a/changelog.d/8454.bugfix b/changelog.d/8454.bugfix
new file mode 100644
index 0000000000..c06d490b6f
--- /dev/null
+++ b/changelog.d/8454.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where invalid ignored users in account data could break clients.
diff --git a/changelog.d/8457.bugfix b/changelog.d/8457.bugfix
new file mode 100644
index 0000000000..545b06d180
--- /dev/null
+++ b/changelog.d/8457.bugfix
@@ -0,0 +1 @@
+Fix a bug where backfilling a room with an event that was missing the `redacts` field would break.
diff --git a/changelog.d/8458.feature b/changelog.d/8458.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8458.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8461.feature b/changelog.d/8461.feature
new file mode 100644
index 0000000000..3665d670e1
--- /dev/null
+++ b/changelog.d/8461.feature
@@ -0,0 +1 @@
+Change default room version to "6", per [MSC2788](https://github.com/matrix-org/matrix-doc/pull/2788).
diff --git a/changelog.d/8462.doc b/changelog.d/8462.doc
new file mode 100644
index 0000000000..cf84db6db7
--- /dev/null
+++ b/changelog.d/8462.doc
@@ -0,0 +1 @@
+Update the directions for using the manhole with coroutines.
diff --git a/changelog.d/8463.misc b/changelog.d/8463.misc
new file mode 100644
index 0000000000..040c9bb90f
--- /dev/null
+++ b/changelog.d/8463.misc
@@ -0,0 +1 @@
+Reduce inconsistencies between codepaths for membership and non-membership events.
diff --git a/changelog.d/8464.misc b/changelog.d/8464.misc
new file mode 100644
index 0000000000..a552e88f9f
--- /dev/null
+++ b/changelog.d/8464.misc
@@ -0,0 +1 @@
+Combine `SpamCheckerApi` with the more generic `ModuleApi`.
diff --git a/changelog.d/8465.bugfix b/changelog.d/8465.bugfix
new file mode 100644
index 0000000000..73f895b268
--- /dev/null
+++ b/changelog.d/8465.bugfix
@@ -0,0 +1 @@
+Don't attempt to respond to some requests if the client has already disconnected.
\ No newline at end of file
diff --git a/changelog.d/8467.feature b/changelog.d/8467.feature
new file mode 100644
index 0000000000..6d0335e2c8
--- /dev/null
+++ b/changelog.d/8467.feature
@@ -0,0 +1 @@
+Allow `ThirdPartyEventRules` modules to query and manipulate whether a room is in the public rooms directory.
\ No newline at end of file
diff --git a/changelog.d/8468.misc b/changelog.d/8468.misc
new file mode 100644
index 0000000000..32ba991e64
--- /dev/null
+++ b/changelog.d/8468.misc
@@ -0,0 +1 @@
+Additional testing for `ThirdPartyEventRules`.
diff --git a/changelog.d/8474.misc b/changelog.d/8474.misc
new file mode 100644
index 0000000000..65e329a6e3
--- /dev/null
+++ b/changelog.d/8474.misc
@@ -0,0 +1 @@
+Unblacklist some sytests.
diff --git a/changelog.d/8477.misc b/changelog.d/8477.misc
new file mode 100644
index 0000000000..2ee1606b6e
--- /dev/null
+++ b/changelog.d/8477.misc
@@ -0,0 +1 @@
+Include the log level in the phone home stats.
diff --git a/changelog.d/8479.feature b/changelog.d/8479.feature
new file mode 100644
index 0000000000..11adeec8a9
--- /dev/null
+++ b/changelog.d/8479.feature
@@ -0,0 +1 @@
+Add the ability to send non-membership events into a room via the `ModuleApi`.
\ No newline at end of file
diff --git a/changelog.d/8480.misc b/changelog.d/8480.misc
new file mode 100644
index 0000000000..81633af296
--- /dev/null
+++ b/changelog.d/8480.misc
@@ -0,0 +1 @@
+Remove outdated sphinx documentation, scripts and configuration.
\ No newline at end of file
diff --git a/changelog.d/8486.bugfix b/changelog.d/8486.bugfix
new file mode 100644
index 0000000000..63fc091ba6
--- /dev/null
+++ b/changelog.d/8486.bugfix
@@ -0,0 +1 @@
+Fix incremental sync returning an incorrect `prev_batch` token in timeline section, which when used to paginate returned events that were included in the incremental sync. Broken since v0.16.0.
diff --git a/changelog.d/8489.feature b/changelog.d/8489.feature
new file mode 100644
index 0000000000..22591870a4
--- /dev/null
+++ b/changelog.d/8489.feature
@@ -0,0 +1 @@
+ Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8492.misc b/changelog.d/8492.misc
new file mode 100644
index 0000000000..a344aee791
--- /dev/null
+++ b/changelog.d/8492.misc
@@ -0,0 +1 @@
+Clarify error message when plugin config parsers raise an error.
diff --git a/changelog.d/8493.doc b/changelog.d/8493.doc
new file mode 100644
index 0000000000..26797cd99e
--- /dev/null
+++ b/changelog.d/8493.doc
@@ -0,0 +1 @@
+Improve readme by adding new shield.io badges.
diff --git a/changelog.d/8494.misc b/changelog.d/8494.misc
new file mode 100644
index 0000000000..6e56c6b854
--- /dev/null
+++ b/changelog.d/8494.misc
@@ -0,0 +1 @@
+Remove the deprecated `Handlers` object.
diff --git a/changelog.d/8496.misc b/changelog.d/8496.misc
new file mode 100644
index 0000000000..237cb3b311
--- /dev/null
+++ b/changelog.d/8496.misc
@@ -0,0 +1 @@
+Allow events to be sent to clients sooner when using sharded event persisters.
diff --git a/changelog.d/8497.misc b/changelog.d/8497.misc
new file mode 100644
index 0000000000..8bc05e8df6
--- /dev/null
+++ b/changelog.d/8497.misc
@@ -0,0 +1 @@
+Fix a threadsafety bug in unit tests.
diff --git a/changelog.d/8499.misc b/changelog.d/8499.misc
new file mode 100644
index 0000000000..237cb3b311
--- /dev/null
+++ b/changelog.d/8499.misc
@@ -0,0 +1 @@
+Allow events to be sent to clients sooner when using sharded event persisters.
diff --git a/changelog.d/8501.feature b/changelog.d/8501.feature
new file mode 100644
index 0000000000..5220ddd482
--- /dev/null
+++ b/changelog.d/8501.feature
@@ -0,0 +1 @@
+Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).
diff --git a/changelog.d/8502.feature b/changelog.d/8502.feature
new file mode 100644
index 0000000000..faab8d3042
--- /dev/null
+++ b/changelog.d/8502.feature
@@ -0,0 +1 @@
+Increase default upload size limit from 10M to 50M. Contributed by @Akkowicz.
diff --git a/changelog.d/8505.misc b/changelog.d/8505.misc
new file mode 100644
index 0000000000..5aa5c113bd
--- /dev/null
+++ b/changelog.d/8505.misc
@@ -0,0 +1 @@
+Add type hints to various parts of the code base.
diff --git a/changelog.d/8507.misc b/changelog.d/8507.misc
new file mode 100644
index 0000000000..724da8a996
--- /dev/null
+++ b/changelog.d/8507.misc
@@ -0,0 +1 @@
+ Add type hints to various parts of the code base.
diff --git a/changelog.d/8514.misc b/changelog.d/8514.misc
new file mode 100644
index 0000000000..0e7ac4f220
--- /dev/null
+++ b/changelog.d/8514.misc
@@ -0,0 +1 @@
+Remove unused code from the test framework.
diff --git a/demo/start.sh b/demo/start.sh
index 83396e5c33..f6b5ea137f 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -30,6 +30,8 @@ for port in 8080 8081 8082; do
if ! grep -F "Customisation made by demo/start.sh" -q $DIR/etc/$port.config; then
printf '\n\n# Customisation made by demo/start.sh\n' >> $DIR/etc/$port.config
+ echo "public_baseurl: http://localhost:$port/" >> $DIR/etc/$port.config
+
echo 'enable_registration: true' >> $DIR/etc/$port.config
# Warning, this heredoc depends on the interaction of tabs and spaces. Please don't
diff --git a/docker/README.md b/docker/README.md
index d0da34778e..c8f27b8566 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -83,7 +83,7 @@ docker logs synapse
If all is well, you should now be able to connect to http://localhost:8008 and
see a confirmation message.
-The following environment variables are supported in run mode:
+The following environment variables are supported in `run` mode:
* `SYNAPSE_CONFIG_DIR`: where additional config files are stored. Defaults to
`/data`.
@@ -94,6 +94,20 @@ The following environment variables are supported in run mode:
* `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`.
* `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`.
+For more complex setups (e.g. for workers) you can also pass your args directly to synapse using `run` mode. For example like this:
+
+```
+docker run -d --name synapse \
+ --mount type=volume,src=synapse-data,dst=/data \
+ -p 8008:8008 \
+ matrixdotorg/synapse:latest run \
+ -m synapse.app.generic_worker \
+ --config-path=/data/homeserver.yaml \
+ --config-path=/data/generic_worker.yaml
+```
+
+If you do not provide `-m`, the value of the `SYNAPSE_WORKER` environment variable is used. If you do not provide at least one `--config-path` or `-c`, the value of the `SYNAPSE_CONFIG_PATH` environment variable is used instead.
+
## Generating an (admin) user
After synapse is running, you may wish to create a user via `register_new_matrix_user`.
diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index c1110f0f53..a808485c12 100644
--- a/docker/conf/homeserver.yaml
+++ b/docker/conf/homeserver.yaml
@@ -90,7 +90,7 @@ federation_rc_concurrent: 3
media_store_path: "/data/media"
uploads_path: "/data/uploads"
-max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "10M" }}"
+max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}"
max_image_pixels: "32M"
dynamic_thumbnails: false
diff --git a/docker/start.py b/docker/start.py
index 9f08134158..0d2c590b88 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -179,7 +179,7 @@ def run_generate_config(environ, ownership):
def main(args, environ):
- mode = args[1] if len(args) > 1 else None
+ mode = args[1] if len(args) > 1 else "run"
desired_uid = int(environ.get("UID", "991"))
desired_gid = int(environ.get("GID", "991"))
synapse_worker = environ.get("SYNAPSE_WORKER", "synapse.app.homeserver")
@@ -205,36 +205,47 @@ def main(args, environ):
config_dir, config_path, environ, ownership
)
- if mode is not None:
+ if mode != "run":
error("Unknown execution mode '%s'" % (mode,))
- config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
- config_path = environ.get("SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml")
+ args = args[2:]
- if not os.path.exists(config_path):
- if "SYNAPSE_SERVER_NAME" in environ:
- error(
- """\
+ if "-m" not in args:
+ args = ["-m", synapse_worker] + args
+
+ # if there are no config files passed to synapse, try adding the default file
+ if not any(p.startswith("--config-path") or p.startswith("-c") for p in args):
+ config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
+ config_path = environ.get(
+ "SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml"
+ )
+
+ if not os.path.exists(config_path):
+ if "SYNAPSE_SERVER_NAME" in environ:
+ error(
+ """\
Config file '%s' does not exist.
The synapse docker image no longer supports generating a config file on-the-fly
based on environment variables. You can migrate to a static config file by
running with 'migrate_config'. See the README for more details.
"""
+ % (config_path,)
+ )
+
+ error(
+ "Config file '%s' does not exist. You should either create a new "
+ "config file by running with the `generate` argument (and then edit "
+ "the resulting file before restarting) or specify the path to an "
+ "existing config file with the SYNAPSE_CONFIG_PATH variable."
% (config_path,)
)
- error(
- "Config file '%s' does not exist. You should either create a new "
- "config file by running with the `generate` argument (and then edit "
- "the resulting file before restarting) or specify the path to an "
- "existing config file with the SYNAPSE_CONFIG_PATH variable."
- % (config_path,)
- )
+ args += ["--config-path", config_path]
- log("Starting synapse with config file " + config_path)
+ log("Starting synapse with args " + " ".join(args))
- args = ["python", "-m", synapse_worker, "--config-path", config_path]
+ args = ["python"] + args
if ownership is not None:
args = ["gosu", ownership] + args
os.execv("/usr/sbin/gosu", args)
diff --git a/docs/code_style.md b/docs/code_style.md
index 6ef6f80290..f6c825d7d4 100644
--- a/docs/code_style.md
+++ b/docs/code_style.md
@@ -64,8 +64,6 @@ save as it takes a while and is very resource intensive.
- Use underscores for functions and variables.
- **Docstrings**: should follow the [google code
style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings).
- This is so that we can generate documentation with
- [sphinx](http://sphinxcontrib-napoleon.readthedocs.org/en/latest/).
See the
[examples](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
in the sphinx documentation.
diff --git a/docs/manhole.md b/docs/manhole.md
index 7375f5ad46..75b6ae40e0 100644
--- a/docs/manhole.md
+++ b/docs/manhole.md
@@ -35,9 +35,12 @@ This gives a Python REPL in which `hs` gives access to the
`synapse.server.HomeServer` object - which in turn gives access to many other
parts of the process.
+Note that any call which returns a coroutine will need to be wrapped in `ensureDeferred`.
+
As a simple example, retrieving an event from the database:
-```
->>> hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')
+```pycon
+>>> from twisted.internet import defer
+>>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org'))
<Deferred at 0x7ff253fc6998 current result: <FrozenEvent event_id='$1416420717069yeQaw:matrix.org', type='m.room.create', state_key=''>>
```
diff --git a/docs/openid.md b/docs/openid.md
index 70b37f858b..4873681999 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -238,13 +238,36 @@ Synapse config:
```yaml
oidc_config:
- enabled: true
- issuer: "https://id.twitch.tv/oauth2/"
- client_id: "your-client-id" # TO BE FILLED
- client_secret: "your-client-secret" # TO BE FILLED
- client_auth_method: "client_secret_post"
- user_mapping_provider:
- config:
- localpart_template: '{{ user.preferred_username }}'
- display_name_template: '{{ user.name }}'
+ enabled: true
+ issuer: "https://id.twitch.tv/oauth2/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ client_auth_method: "client_secret_post"
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.preferred_username }}"
+ display_name_template: "{{ user.name }}"
+```
+
+### GitLab
+
+1. Create a [new application](https://gitlab.com/profile/applications).
+2. Add the `read_user` and `openid` scopes.
+3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
+
+Synapse config:
+
+```yaml
+oidc_config:
+ enabled: true
+ issuer: "https://gitlab.com/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ client_auth_method: "client_secret_post"
+ scopes: ["openid", "read_user"]
+ user_profile_method: "userinfo_endpoint"
+ user_mapping_provider:
+ config:
+ localpart_template: '{{ user.nickname }}'
+ display_name_template: '{{ user.name }}'
```
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 46d8f35771..c7020f2df3 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -54,7 +54,7 @@ server {
proxy_set_header X-Forwarded-For $remote_addr;
# Nginx by default only allows file uploads up to 1M in size
# Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
- client_max_body_size 10M;
+ client_max_body_size 50M;
}
}
```
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 8a3206e845..061226ea6f 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -119,7 +119,7 @@ pid_file: DATADIR/homeserver.pid
# For example, for room version 1, default_room_version should be set
# to "1".
#
-#default_room_version: "5"
+#default_room_version: "6"
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
@@ -893,7 +893,7 @@ media_store_path: "DATADIR/media_store"
# The largest allowed upload size in bytes
#
-#max_upload_size: 10M
+#max_upload_size: 50M
# Maximum number of pixels that will be thumbnailed
#
@@ -1714,6 +1714,14 @@ oidc_config:
#
#skip_verification: true
+ # Whether to fetch the user profile from the userinfo endpoint. Valid
+ # values are: "auto" or "userinfo_endpoint".
+ #
+ # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
+ # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+ #
+ #user_profile_method: "userinfo_endpoint"
+
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
#
@@ -2496,6 +2504,11 @@ opentracing:
# events: worker1
# typing: worker1
+# The worker that is used to run background tasks (e.g. cleaning up expired
+# data). If not provided this defaults to the main process.
+#
+#run_background_tasks_on: worker1
+
# Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration).
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index eb10e115f9..7fc08f1b70 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -11,7 +11,7 @@ able to be imported by the running Synapse.
The Python class is instantiated with two objects:
* Any configuration (see below).
-* An instance of `synapse.spam_checker_api.SpamCheckerApi`.
+* An instance of `synapse.module_api.ModuleApi`.
It then implements methods which return a boolean to alter behavior in Synapse.
@@ -26,11 +26,8 @@ well as some specific methods:
The details of the each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
-The `SpamCheckerApi` class provides a way for the custom spam checker class to
-call back into the homeserver internals. It currently implements the following
-methods:
-
-* `get_state_events_in_room`
+The `ModuleApi` class provides a way for the custom spam checker class to
+call back into the homeserver internals.
### Example
diff --git a/docs/sphinx/README.rst b/docs/sphinx/README.rst
deleted file mode 100644
index a7ab7c5500..0000000000
--- a/docs/sphinx/README.rst
+++ /dev/null
@@ -1 +0,0 @@
-TODO: how (if at all) is this actually maintained?
diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py
deleted file mode 100644
index ca4b879526..0000000000
--- a/docs/sphinx/conf.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Synapse documentation build configuration file, created by
-# sphinx-quickstart on Tue Jun 10 17:31:02 2014.
-#
-# This file is execfile()d with the current directory set to its
-# containing dir.
-#
-# Note that not all possible configuration values are present in this
-# autogenerated file.
-#
-# All configuration values have a default; values that are commented out
-# serve to show the default.
-
-import sys
-import os
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath(".."))
-
-# -- General configuration ------------------------------------------------
-
-# If your documentation needs a minimal Sphinx version, state it here.
-# needs_sphinx = '1.0'
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
- "sphinx.ext.autodoc",
- "sphinx.ext.intersphinx",
- "sphinx.ext.coverage",
- "sphinx.ext.ifconfig",
- "sphinxcontrib.napoleon",
-]
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ["_templates"]
-
-# The suffix of source filenames.
-source_suffix = ".rst"
-
-# The encoding of source files.
-# source_encoding = 'utf-8-sig'
-
-# The master toctree document.
-master_doc = "index"
-
-# General information about the project.
-project = "Synapse"
-copyright = (
- "Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd"
-)
-
-# The version info for the project you're documenting, acts as replacement for
-# |version| and |release|, also used in various other places throughout the
-# built documents.
-#
-# The short X.Y version.
-version = "1.0"
-# The full version, including alpha/beta/rc tags.
-release = "1.0"
-
-# The language for content autogenerated by Sphinx. Refer to documentation
-# for a list of supported languages.
-# language = None
-
-# There are two options for replacing |today|: either, you set today to some
-# non-false value, then it is used:
-# today = ''
-# Else, today_fmt is used as the format for a strftime call.
-# today_fmt = '%B %d, %Y'
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-exclude_patterns = ["_build"]
-
-# The reST default role (used for this markup: `text`) to use for all
-# documents.
-# default_role = None
-
-# If true, '()' will be appended to :func: etc. cross-reference text.
-# add_function_parentheses = True
-
-# If true, the current module name will be prepended to all description
-# unit titles (such as .. function::).
-# add_module_names = True
-
-# If true, sectionauthor and moduleauthor directives will be shown in the
-# output. They are ignored by default.
-# show_authors = False
-
-# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = "sphinx"
-
-# A list of ignored prefixes for module index sorting.
-# modindex_common_prefix = []
-
-# If true, keep warnings as "system message" paragraphs in the built documents.
-# keep_warnings = False
-
-
-# -- Options for HTML output ----------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-html_theme = "default"
-
-# Theme options are theme-specific and customize the look and feel of a theme
-# further. For a list of options available for each theme, see the
-# documentation.
-# html_theme_options = {}
-
-# Add any paths that contain custom themes here, relative to this directory.
-# html_theme_path = []
-
-# The name for this set of Sphinx documents. If None, it defaults to
-# "<project> v<release> documentation".
-# html_title = None
-
-# A shorter title for the navigation bar. Default is the same as html_title.
-# html_short_title = None
-
-# The name of an image file (relative to this directory) to place at the top
-# of the sidebar.
-# html_logo = None
-
-# The name of an image file (within the static path) to use as favicon of the
-# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
-# pixels large.
-# html_favicon = None
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ["_static"]
-
-# Add any extra paths that contain custom files (such as robots.txt or
-# .htaccess) here, relative to this directory. These files are copied
-# directly to the root of the documentation.
-# html_extra_path = []
-
-# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
-# using the given strftime format.
-# html_last_updated_fmt = '%b %d, %Y'
-
-# If true, SmartyPants will be used to convert quotes and dashes to
-# typographically correct entities.
-# html_use_smartypants = True
-
-# Custom sidebar templates, maps document names to template names.
-# html_sidebars = {}
-
-# Additional templates that should be rendered to pages, maps page names to
-# template names.
-# html_additional_pages = {}
-
-# If false, no module index is generated.
-# html_domain_indices = True
-
-# If false, no index is generated.
-# html_use_index = True
-
-# If true, the index is split into individual pages for each letter.
-# html_split_index = False
-
-# If true, links to the reST sources are added to the pages.
-# html_show_sourcelink = True
-
-# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-# html_show_sphinx = True
-
-# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-# html_show_copyright = True
-
-# If true, an OpenSearch description file will be output, and all pages will
-# contain a <link> tag referring to it. The value of this option must be the
-# base URL from which the finished HTML is served.
-# html_use_opensearch = ''
-
-# This is the file name suffix for HTML files (e.g. ".xhtml").
-# html_file_suffix = None
-
-# Output file base name for HTML help builder.
-htmlhelp_basename = "Synapsedoc"
-
-
-# -- Options for LaTeX output ---------------------------------------------
-
-latex_elements = {
- # The paper size ('letterpaper' or 'a4paper').
- #'papersize': 'letterpaper',
- # The font size ('10pt', '11pt' or '12pt').
- #'pointsize': '10pt',
- # Additional stuff for the LaTeX preamble.
- #'preamble': '',
-}
-
-# Grouping the document tree into LaTeX files. List of tuples
-# (source start file, target name, title,
-# author, documentclass [howto, manual, or own class]).
-latex_documents = [("index", "Synapse.tex", "Synapse Documentation", "TNG", "manual")]
-
-# The name of an image file (relative to this directory) to place at the top of
-# the title page.
-# latex_logo = None
-
-# For "manual" documents, if this is true, then toplevel headings are parts,
-# not chapters.
-# latex_use_parts = False
-
-# If true, show page references after internal links.
-# latex_show_pagerefs = False
-
-# If true, show URL addresses after external links.
-# latex_show_urls = False
-
-# Documents to append as an appendix to all manuals.
-# latex_appendices = []
-
-# If false, no module index is generated.
-# latex_domain_indices = True
-
-
-# -- Options for manual page output ---------------------------------------
-
-# One entry per manual page. List of tuples
-# (source start file, name, description, authors, manual section).
-man_pages = [("index", "synapse", "Synapse Documentation", ["TNG"], 1)]
-
-# If true, show URL addresses after external links.
-# man_show_urls = False
-
-
-# -- Options for Texinfo output -------------------------------------------
-
-# Grouping the document tree into Texinfo files. List of tuples
-# (source start file, target name, title, author,
-# dir menu entry, description, category)
-texinfo_documents = [
- (
- "index",
- "Synapse",
- "Synapse Documentation",
- "TNG",
- "Synapse",
- "One line description of project.",
- "Miscellaneous",
- )
-]
-
-# Documents to append as an appendix to all manuals.
-# texinfo_appendices = []
-
-# If false, no module index is generated.
-# texinfo_domain_indices = True
-
-# How to display URL addresses: 'footnote', 'no', or 'inline'.
-# texinfo_show_urls = 'footnote'
-
-# If true, do not generate a @detailmenu in the "Top" node's menu.
-# texinfo_no_detailmenu = False
-
-
-# Example configuration for intersphinx: refer to the Python standard library.
-intersphinx_mapping = {"http://docs.python.org/": None}
-
-napoleon_include_special_with_doc = True
-napoleon_use_ivar = True
diff --git a/docs/sphinx/index.rst b/docs/sphinx/index.rst
deleted file mode 100644
index 76a4c0c7bf..0000000000
--- a/docs/sphinx/index.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-.. Synapse documentation master file, created by
- sphinx-quickstart on Tue Jun 10 17:31:02 2014.
- You can adapt this file completely to your liking, but it should at least
- contain the root `toctree` directive.
-
-Welcome to Synapse's documentation!
-===================================
-
-Contents:
-
-.. toctree::
- synapse
-
-Indices and tables
-==================
-
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
-
diff --git a/docs/sphinx/modules.rst b/docs/sphinx/modules.rst
deleted file mode 100644
index 1c7f70bd13..0000000000
--- a/docs/sphinx/modules.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse
-=======
-
-.. toctree::
- :maxdepth: 4
-
- synapse
diff --git a/docs/sphinx/synapse.api.auth.rst b/docs/sphinx/synapse.api.auth.rst
deleted file mode 100644
index 931eb59836..0000000000
--- a/docs/sphinx/synapse.api.auth.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.auth module
-=======================
-
-.. automodule:: synapse.api.auth
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.constants.rst b/docs/sphinx/synapse.api.constants.rst
deleted file mode 100644
index a1e3c47f68..0000000000
--- a/docs/sphinx/synapse.api.constants.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.constants module
-============================
-
-.. automodule:: synapse.api.constants
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.dbobjects.rst b/docs/sphinx/synapse.api.dbobjects.rst
deleted file mode 100644
index e9d31167e0..0000000000
--- a/docs/sphinx/synapse.api.dbobjects.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.dbobjects module
-============================
-
-.. automodule:: synapse.api.dbobjects
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.errors.rst b/docs/sphinx/synapse.api.errors.rst
deleted file mode 100644
index f1c6881478..0000000000
--- a/docs/sphinx/synapse.api.errors.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.errors module
-=========================
-
-.. automodule:: synapse.api.errors
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.event_stream.rst b/docs/sphinx/synapse.api.event_stream.rst
deleted file mode 100644
index 9291cb2dbc..0000000000
--- a/docs/sphinx/synapse.api.event_stream.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.event_stream module
-===============================
-
-.. automodule:: synapse.api.event_stream
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.events.factory.rst b/docs/sphinx/synapse.api.events.factory.rst
deleted file mode 100644
index 2e71ff6070..0000000000
--- a/docs/sphinx/synapse.api.events.factory.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.events.factory module
-=================================
-
-.. automodule:: synapse.api.events.factory
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.events.room.rst b/docs/sphinx/synapse.api.events.room.rst
deleted file mode 100644
index 6cd5998599..0000000000
--- a/docs/sphinx/synapse.api.events.room.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.events.room module
-==============================
-
-.. automodule:: synapse.api.events.room
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.events.rst b/docs/sphinx/synapse.api.events.rst
deleted file mode 100644
index b762da55ee..0000000000
--- a/docs/sphinx/synapse.api.events.rst
+++ /dev/null
@@ -1,18 +0,0 @@
-synapse.api.events package
-==========================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.api.events.factory
- synapse.api.events.room
-
-Module contents
----------------
-
-.. automodule:: synapse.api.events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.events.rst b/docs/sphinx/synapse.api.handlers.events.rst
deleted file mode 100644
index d2e1b54ac0..0000000000
--- a/docs/sphinx/synapse.api.handlers.events.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.events module
-==================================
-
-.. automodule:: synapse.api.handlers.events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.factory.rst b/docs/sphinx/synapse.api.handlers.factory.rst
deleted file mode 100644
index b04a93f740..0000000000
--- a/docs/sphinx/synapse.api.handlers.factory.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.factory module
-===================================
-
-.. automodule:: synapse.api.handlers.factory
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.federation.rst b/docs/sphinx/synapse.api.handlers.federation.rst
deleted file mode 100644
index 61a6542210..0000000000
--- a/docs/sphinx/synapse.api.handlers.federation.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.federation module
-======================================
-
-.. automodule:: synapse.api.handlers.federation
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.register.rst b/docs/sphinx/synapse.api.handlers.register.rst
deleted file mode 100644
index 388f144eca..0000000000
--- a/docs/sphinx/synapse.api.handlers.register.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.register module
-====================================
-
-.. automodule:: synapse.api.handlers.register
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.room.rst b/docs/sphinx/synapse.api.handlers.room.rst
deleted file mode 100644
index 8ca156c7ff..0000000000
--- a/docs/sphinx/synapse.api.handlers.room.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.room module
-================================
-
-.. automodule:: synapse.api.handlers.room
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.rst b/docs/sphinx/synapse.api.handlers.rst
deleted file mode 100644
index e84f563fcb..0000000000
--- a/docs/sphinx/synapse.api.handlers.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-synapse.api.handlers package
-============================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.api.handlers.events
- synapse.api.handlers.factory
- synapse.api.handlers.federation
- synapse.api.handlers.register
- synapse.api.handlers.room
-
-Module contents
----------------
-
-.. automodule:: synapse.api.handlers
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.notifier.rst b/docs/sphinx/synapse.api.notifier.rst
deleted file mode 100644
index 631b42a497..0000000000
--- a/docs/sphinx/synapse.api.notifier.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.notifier module
-===========================
-
-.. automodule:: synapse.api.notifier
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.register_events.rst b/docs/sphinx/synapse.api.register_events.rst
deleted file mode 100644
index 79ad4ce211..0000000000
--- a/docs/sphinx/synapse.api.register_events.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.register_events module
-==================================
-
-.. automodule:: synapse.api.register_events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.room_events.rst b/docs/sphinx/synapse.api.room_events.rst
deleted file mode 100644
index bead1711f5..0000000000
--- a/docs/sphinx/synapse.api.room_events.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.room_events module
-==============================
-
-.. automodule:: synapse.api.room_events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.rst b/docs/sphinx/synapse.api.rst
deleted file mode 100644
index f4d39ff331..0000000000
--- a/docs/sphinx/synapse.api.rst
+++ /dev/null
@@ -1,30 +0,0 @@
-synapse.api package
-===================
-
-Subpackages
------------
-
-.. toctree::
-
- synapse.api.events
- synapse.api.handlers
- synapse.api.streams
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.api.auth
- synapse.api.constants
- synapse.api.errors
- synapse.api.notifier
- synapse.api.storage
-
-Module contents
----------------
-
-.. automodule:: synapse.api
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.server.rst b/docs/sphinx/synapse.api.server.rst
deleted file mode 100644
index b01600235e..0000000000
--- a/docs/sphinx/synapse.api.server.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.server module
-=========================
-
-.. automodule:: synapse.api.server
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.storage.rst b/docs/sphinx/synapse.api.storage.rst
deleted file mode 100644
index afa40685c4..0000000000
--- a/docs/sphinx/synapse.api.storage.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.storage module
-==========================
-
-.. automodule:: synapse.api.storage
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.stream.rst b/docs/sphinx/synapse.api.stream.rst
deleted file mode 100644
index 0d5e3f01bf..0000000000
--- a/docs/sphinx/synapse.api.stream.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.stream module
-=========================
-
-.. automodule:: synapse.api.stream
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.streams.event.rst b/docs/sphinx/synapse.api.streams.event.rst
deleted file mode 100644
index 2ac45a35c8..0000000000
--- a/docs/sphinx/synapse.api.streams.event.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.streams.event module
-================================
-
-.. automodule:: synapse.api.streams.event
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.streams.rst b/docs/sphinx/synapse.api.streams.rst
deleted file mode 100644
index 72eb205caf..0000000000
--- a/docs/sphinx/synapse.api.streams.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-synapse.api.streams package
-===========================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.api.streams.event
-
-Module contents
----------------
-
-.. automodule:: synapse.api.streams
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.app.homeserver.rst b/docs/sphinx/synapse.app.homeserver.rst
deleted file mode 100644
index 54b93da8fe..0000000000
--- a/docs/sphinx/synapse.app.homeserver.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.app.homeserver module
-=============================
-
-.. automodule:: synapse.app.homeserver
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.app.rst b/docs/sphinx/synapse.app.rst
deleted file mode 100644
index 4535b79827..0000000000
--- a/docs/sphinx/synapse.app.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-synapse.app package
-===================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.app.homeserver
-
-Module contents
----------------
-
-.. automodule:: synapse.app
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.db.rst b/docs/sphinx/synapse.db.rst
deleted file mode 100644
index 83df6c03db..0000000000
--- a/docs/sphinx/synapse.db.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-synapse.db package
-==================
-
-Module contents
----------------
-
-.. automodule:: synapse.db
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.handler.rst b/docs/sphinx/synapse.federation.handler.rst
deleted file mode 100644
index 5597f5c46d..0000000000
--- a/docs/sphinx/synapse.federation.handler.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.handler module
-=================================
-
-.. automodule:: synapse.federation.handler
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.messaging.rst b/docs/sphinx/synapse.federation.messaging.rst
deleted file mode 100644
index 4bbaabf3ef..0000000000
--- a/docs/sphinx/synapse.federation.messaging.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.messaging module
-===================================
-
-.. automodule:: synapse.federation.messaging
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.pdu_codec.rst b/docs/sphinx/synapse.federation.pdu_codec.rst
deleted file mode 100644
index 8f0b15a63c..0000000000
--- a/docs/sphinx/synapse.federation.pdu_codec.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.pdu_codec module
-===================================
-
-.. automodule:: synapse.federation.pdu_codec
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.persistence.rst b/docs/sphinx/synapse.federation.persistence.rst
deleted file mode 100644
index db7ab8ade1..0000000000
--- a/docs/sphinx/synapse.federation.persistence.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.persistence module
-=====================================
-
-.. automodule:: synapse.federation.persistence
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.replication.rst b/docs/sphinx/synapse.federation.replication.rst
deleted file mode 100644
index 49e26e0928..0000000000
--- a/docs/sphinx/synapse.federation.replication.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.replication module
-=====================================
-
-.. automodule:: synapse.federation.replication
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.rst b/docs/sphinx/synapse.federation.rst
deleted file mode 100644
index 7240c7901b..0000000000
--- a/docs/sphinx/synapse.federation.rst
+++ /dev/null
@@ -1,22 +0,0 @@
-synapse.federation package
-==========================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.federation.handler
- synapse.federation.pdu_codec
- synapse.federation.persistence
- synapse.federation.replication
- synapse.federation.transport
- synapse.federation.units
-
-Module contents
----------------
-
-.. automodule:: synapse.federation
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.transport.rst b/docs/sphinx/synapse.federation.transport.rst
deleted file mode 100644
index 877956b3c9..0000000000
--- a/docs/sphinx/synapse.federation.transport.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.transport module
-===================================
-
-.. automodule:: synapse.federation.transport
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.units.rst b/docs/sphinx/synapse.federation.units.rst
deleted file mode 100644
index 8f9212b07d..0000000000
--- a/docs/sphinx/synapse.federation.units.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.units module
-===============================
-
-.. automodule:: synapse.federation.units
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.persistence.rst b/docs/sphinx/synapse.persistence.rst
deleted file mode 100644
index 37c0c23720..0000000000
--- a/docs/sphinx/synapse.persistence.rst
+++ /dev/null
@@ -1,19 +0,0 @@
-synapse.persistence package
-===========================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.persistence.service
- synapse.persistence.tables
- synapse.persistence.transactions
-
-Module contents
----------------
-
-.. automodule:: synapse.persistence
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.persistence.service.rst b/docs/sphinx/synapse.persistence.service.rst
deleted file mode 100644
index 3514d3c76f..0000000000
--- a/docs/sphinx/synapse.persistence.service.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.persistence.service module
-==================================
-
-.. automodule:: synapse.persistence.service
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.persistence.tables.rst b/docs/sphinx/synapse.persistence.tables.rst
deleted file mode 100644
index 907b02769d..0000000000
--- a/docs/sphinx/synapse.persistence.tables.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.persistence.tables module
-=================================
-
-.. automodule:: synapse.persistence.tables
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.persistence.transactions.rst b/docs/sphinx/synapse.persistence.transactions.rst
deleted file mode 100644
index 475c02a8c5..0000000000
--- a/docs/sphinx/synapse.persistence.transactions.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.persistence.transactions module
-=======================================
-
-.. automodule:: synapse.persistence.transactions
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.base.rst b/docs/sphinx/synapse.rest.base.rst
deleted file mode 100644
index 84d2d9b31d..0000000000
--- a/docs/sphinx/synapse.rest.base.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.rest.base module
-========================
-
-.. automodule:: synapse.rest.base
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.events.rst b/docs/sphinx/synapse.rest.events.rst
deleted file mode 100644
index ebbe26c746..0000000000
--- a/docs/sphinx/synapse.rest.events.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.rest.events module
-==========================
-
-.. automodule:: synapse.rest.events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.register.rst b/docs/sphinx/synapse.rest.register.rst
deleted file mode 100644
index a4a48a8a8f..0000000000
--- a/docs/sphinx/synapse.rest.register.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.rest.register module
-============================
-
-.. automodule:: synapse.rest.register
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.room.rst b/docs/sphinx/synapse.rest.room.rst
deleted file mode 100644
index 63fc5c2840..0000000000
--- a/docs/sphinx/synapse.rest.room.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.rest.room module
-========================
-
-.. automodule:: synapse.rest.room
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.rst b/docs/sphinx/synapse.rest.rst
deleted file mode 100644
index 016af926b2..0000000000
--- a/docs/sphinx/synapse.rest.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-synapse.rest package
-====================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.rest.base
- synapse.rest.events
- synapse.rest.register
- synapse.rest.room
-
-Module contents
----------------
-
-.. automodule:: synapse.rest
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rst b/docs/sphinx/synapse.rst
deleted file mode 100644
index e7869e0e5d..0000000000
--- a/docs/sphinx/synapse.rst
+++ /dev/null
@@ -1,30 +0,0 @@
-synapse package
-===============
-
-Subpackages
------------
-
-.. toctree::
-
- synapse.api
- synapse.app
- synapse.federation
- synapse.persistence
- synapse.rest
- synapse.util
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.server
- synapse.state
-
-Module contents
----------------
-
-.. automodule:: synapse
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.server.rst b/docs/sphinx/synapse.server.rst
deleted file mode 100644
index 7f33f084d7..0000000000
--- a/docs/sphinx/synapse.server.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.server module
-=====================
-
-.. automodule:: synapse.server
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.state.rst b/docs/sphinx/synapse.state.rst
deleted file mode 100644
index 744be2a8be..0000000000
--- a/docs/sphinx/synapse.state.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.state module
-====================
-
-.. automodule:: synapse.state
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.async.rst b/docs/sphinx/synapse.util.async.rst
deleted file mode 100644
index 542bb54444..0000000000
--- a/docs/sphinx/synapse.util.async.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.async module
-=========================
-
-.. automodule:: synapse.util.async
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.dbutils.rst b/docs/sphinx/synapse.util.dbutils.rst
deleted file mode 100644
index afaa9eb749..0000000000
--- a/docs/sphinx/synapse.util.dbutils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.dbutils module
-===========================
-
-.. automodule:: synapse.util.dbutils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.http.rst b/docs/sphinx/synapse.util.http.rst
deleted file mode 100644
index 344af5a490..0000000000
--- a/docs/sphinx/synapse.util.http.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.http module
-========================
-
-.. automodule:: synapse.util.http
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.lockutils.rst b/docs/sphinx/synapse.util.lockutils.rst
deleted file mode 100644
index 16ee26cabd..0000000000
--- a/docs/sphinx/synapse.util.lockutils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.lockutils module
-=============================
-
-.. automodule:: synapse.util.lockutils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.logutils.rst b/docs/sphinx/synapse.util.logutils.rst
deleted file mode 100644
index 2b79fa7a4b..0000000000
--- a/docs/sphinx/synapse.util.logutils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.logutils module
-============================
-
-.. automodule:: synapse.util.logutils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.rst b/docs/sphinx/synapse.util.rst
deleted file mode 100644
index 01a0c3a591..0000000000
--- a/docs/sphinx/synapse.util.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-synapse.util package
-====================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.util.async
- synapse.util.http
- synapse.util.lockutils
- synapse.util.logutils
- synapse.util.stringutils
-
-Module contents
----------------
-
-.. automodule:: synapse.util
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.stringutils.rst b/docs/sphinx/synapse.util.stringutils.rst
deleted file mode 100644
index ec626eee28..0000000000
--- a/docs/sphinx/synapse.util.stringutils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.stringutils module
-===============================
-
-.. automodule:: synapse.util.stringutils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index db318baa9d..ad145439b4 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -15,7 +15,7 @@ example flow would be (where '>' indicates master to worker and
> SERVER example.com
< REPLICATE
- > POSITION events master 53
+ > POSITION events master 53 53
> RDATA events master 54 ["$foo1:bar.com", ...]
> RDATA events master 55 ["$foo4:bar.com", ...]
@@ -138,9 +138,9 @@ the wire:
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE
- > POSITION events master 1
- > POSITION backfill master 1
- > POSITION caches master 1
+ > POSITION events master 1 1
+ > POSITION backfill master 1 1
+ > POSITION caches master 1 1
> RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
> RDATA events master 14 ["$149019767112vOHxz:localhost:8823",
"!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
@@ -185,6 +185,11 @@ client (C):
updates via HTTP API, rather than via the DB, then processes should make the
request to the appropriate process.
+ Two positions are included, the "new" position and the last position sent respectively.
+ This allows servers to tell instances that the positions have advanced but no
+ data has been written, without clients needlessly checking to see if they
+ have missed any updates.
+
#### ERROR (S, C)
There was an error
diff --git a/docs/workers.md b/docs/workers.md
index ad4d8ca9f2..84a9759e34 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -319,6 +319,23 @@ stream_writers:
events: event_persister1
```
+#### Background tasks
+
+There is also *experimental* support for moving background tasks to a separate
+worker. Background tasks are run periodically or started via replication. Exactly
+which tasks are configured to run depends on your Synapse configuration (e.g. if
+stats is enabled).
+
+To enable this, the worker must have a `worker_name` and can be configured to run
+background tasks. For example, to move background tasks to a dedicated worker,
+the shared configuration would include:
+
+```yaml
+run_background_tasks_on: background_worker
+```
+
+You might also wish to investigate the `update_user_directory` and
+`media_instance_running_background_jobs` settings.
### `synapse.app.pusher`
diff --git a/mypy.ini b/mypy.ini
index 7986781432..f08fe992a4 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -6,6 +6,7 @@ check_untyped_defs = True
show_error_codes = True
show_traceback = True
mypy_path = stubs
+warn_unreachable = True
files =
synapse/api,
synapse/appservice,
@@ -14,8 +15,12 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
+ synapse/handlers/account_data.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
+ synapse/handlers/deactivate_account.py,
+ synapse/handlers/device.py,
+ synapse/handlers/devicemessage.py,
synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py,
@@ -24,7 +29,9 @@ files =
synapse/handlers/message.py,
synapse/handlers/oidc_handler.py,
synapse/handlers/pagination.py,
+ synapse/handlers/password_policy.py,
synapse/handlers/presence.py,
+ synapse/handlers/read_marker.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py,
@@ -58,6 +65,7 @@ files =
synapse/types.py,
synapse/util/async_helpers.py,
synapse/util/caches/descriptors.py,
+ synapse/util/caches/response_cache.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
tests/replication,
@@ -142,3 +150,6 @@ ignore_missing_imports = True
[mypy-nacl.*]
ignore_missing_imports = True
+
+[mypy-hiredis]
+ignore_missing_imports = True
diff --git a/scripts-dev/sphinx_api_docs.sh b/scripts-dev/sphinx_api_docs.sh
deleted file mode 100644
index ee72b29657..0000000000
--- a/scripts-dev/sphinx_api_docs.sh
+++ /dev/null
@@ -1 +0,0 @@
-sphinx-apidoc -o docs/sphinx/ synapse/ -ef
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index ae2887b7d2..2d0b59ab53 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = {
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
"users": ["shadow_banned"],
+ "e2e_fallback_keys_json": ["used"],
}
@@ -489,7 +490,7 @@ class Porter(object):
hs = MockHomeserver(self.hs_config)
- with make_conn(db_config, engine) as db_conn:
+ with make_conn(db_config, engine, "portdb") as db_conn:
engine.check_database(
db_conn, allow_outdated_version=allow_outdated_version
)
diff --git a/setup.cfg b/setup.cfg
index a32278ea8a..f46e43fad0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,8 +1,3 @@
-[build_sphinx]
-source-dir = docs/sphinx
-build-dir = docs/build
-all_files = 1
-
[trial]
test_suite = tests
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index c66413f003..522244bb57 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -16,7 +16,7 @@
"""Contains *incomplete* type hints for txredisapi.
"""
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Type
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
@@ -42,3 +42,21 @@ def lazyConnection(
class SubscriberFactory:
def buildProtocol(self, addr): ...
+
+class ConnectionHandler: ...
+
+class RedisFactory:
+ continueTrying: bool
+ handler: RedisProtocol
+ def __init__(
+ self,
+ uuid: str,
+ dbid: Optional[int],
+ poolsize: int,
+ isLazy: bool = False,
+ handler: Type = ConnectionHandler,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ replyTimeout: Optional[int] = None,
+ convertNumbers: Optional[int] = True,
+ ): ...
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 46013cde15..592abd844b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -155,3 +155,8 @@ class EventContentFields:
class RoomEncryptionAlgorithms:
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
DEFAULT = MEGOLM_V1_AES_SHA2
+
+
+class AccountDataTypes:
+ DIRECT = "m.direct"
+ IGNORED_USER_LIST = "m.ignored_user_list"
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index fb476ddaf5..f6f7b2bf42 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -28,6 +28,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
+from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
@@ -271,9 +272,19 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
+ # Log when we start the shut down process.
+ hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", logger.info, "Shutting down..."
+ )
+
setup_sentry(hs)
setup_sdnotify(hs)
+ # If background tasks are running on the main process, start collecting the
+ # phone home stats.
+ if hs.config.run_background_tasks:
+ start_phone_stats_home(hs)
+
# We now freeze all allocated objects in the hopes that (almost)
# everything currently allocated are things that will be used for the
# rest of time. Doing so means less work each GC (hopefully).
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 7d309b1bb0..b4bd4d8e7a 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -89,7 +89,7 @@ async def export_data_command(hs, args):
user_id = args.user_id
directory = args.output_directory
- res = await hs.get_handlers().admin_handler.export_user_data(
+ res = await hs.get_admin_handler().export_user_data(
user_id, FileExfiltrationWriter(user_id, directory=directory)
)
print(res)
@@ -208,6 +208,7 @@ def start(config_options):
# Explicitly disable background processes
config.update_user_directory = False
+ config.run_background_tasks = False
config.start_pushers = False
config.send_federation = False
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index c38413c893..d53181deb1 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -127,12 +127,16 @@ from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
+from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
+from synapse.storage.databases.main.metrics import ServerMetricsStore
from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.databases.main.presence import UserPresenceState
from synapse.storage.databases.main.search import SearchWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
@@ -454,6 +458,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
+ StatsStore,
UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
@@ -463,6 +468,7 @@ class GenericWorkerSlavedStore(
SlavedAccountDataStore,
SlavedPusherStore,
CensorEventsStore,
+ ClientIpWorkerStore,
SlavedEventStore,
SlavedKeyStore,
RoomStore,
@@ -476,7 +482,9 @@ class GenericWorkerSlavedStore(
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
+ ServerMetricsStore,
SearchWorkerStore,
+ TransactionWorkerStore,
BaseSlavedStore,
):
pass
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index dff739e106..2b5465417f 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -17,14 +17,10 @@
import gc
import logging
-import math
import os
-import resource
import sys
from typing import Iterable
-from prometheus_client import Gauge
-
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
@@ -60,8 +56,6 @@ from synapse.http.server import (
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.module_api import ModuleApi
from synapse.python_dependencies import check_requirements
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -111,7 +105,7 @@ class SynapseHomeServer(HomeServer):
additional_resources = listener_config.http_options.additional_resources
logger.debug("Configuring additional resources: %r", additional_resources)
- module_api = ModuleApi(self, self.get_auth_handler())
+ module_api = self.get_module_api()
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
handler = handler_cls(config, module_api)
@@ -334,20 +328,6 @@ class SynapseHomeServer(HomeServer):
logger.warning("Unrecognized listener type: %s", listener.type)
-# Gauges to expose monthly active user control metrics
-current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
-current_mau_by_service_gauge = Gauge(
- "synapse_admin_mau_current_mau_by_service",
- "Current MAU by service",
- ["app_service"],
-)
-max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
-registered_reserved_users_mau_gauge = Gauge(
- "synapse_admin_mau:registered_reserved_users",
- "Registered users with reserved threepids",
-)
-
-
def setup(config_options):
"""
Args:
@@ -389,8 +369,6 @@ def setup(config_options):
except UpgradeDatabaseException as e:
quit_with_error("Failed to upgrade database: %s" % (e,))
- hs.setup_master()
-
async def do_acme() -> bool:
"""
Reprovision an ACME certificate, if it's required.
@@ -486,92 +464,6 @@ class SynapseService(service.Service):
return self._port.stopListening()
-# Contains the list of processes we will be monitoring
-# currently either 0 or 1
-_stats_process = []
-
-
-async def phone_stats_home(hs, stats, stats_process=_stats_process):
- logger.info("Gathering stats for reporting")
- now = int(hs.get_clock().time())
- uptime = int(now - hs.start_time)
- if uptime < 0:
- uptime = 0
-
- #
- # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test.
- #
- old = stats_process[0]
- new = (now, resource.getrusage(resource.RUSAGE_SELF))
- stats_process[0] = new
-
- # Get RSS in bytes
- stats["memory_rss"] = new[1].ru_maxrss
-
- # Get CPU time in % of a single core, not % of all cores
- used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
- old[1].ru_utime + old[1].ru_stime
- )
- if used_cpu_time == 0 or new[0] == old[0]:
- stats["cpu_average"] = 0
- else:
- stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
-
- #
- # General statistics
- #
-
- stats["homeserver"] = hs.config.server_name
- stats["server_context"] = hs.config.server_context
- stats["timestamp"] = now
- stats["uptime_seconds"] = uptime
- version = sys.version_info
- stats["python_version"] = "{}.{}.{}".format(
- version.major, version.minor, version.micro
- )
- stats["total_users"] = await hs.get_datastore().count_all_users()
-
- total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
- stats["total_nonbridged_users"] = total_nonbridged_users
-
- daily_user_type_results = await hs.get_datastore().count_daily_user_type()
- for name, count in daily_user_type_results.items():
- stats["daily_user_type_" + name] = count
-
- room_count = await hs.get_datastore().get_room_count()
- stats["total_room_count"] = room_count
-
- stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
- stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
-
- r30_results = await hs.get_datastore().count_r30_users()
- for name, count in r30_results.items():
- stats["r30_users_" + name] = count
-
- daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
- stats["cache_factor"] = hs.config.caches.global_factor
- stats["event_cache_size"] = hs.config.caches.event_cache_size
-
- #
- # Database version
- #
-
- # This only reports info about the *main* database.
- stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
- stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
-
- logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
- try:
- await hs.get_proxied_http_client().put_json(
- hs.config.report_stats_endpoint, stats
- )
- except Exception as e:
- logger.warning("Error reporting stats: %s", e)
-
-
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
@@ -597,81 +489,6 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
- clock = hs.get_clock()
-
- stats = {}
-
- def performance_stats_init():
- _stats_process.clear()
- _stats_process.append(
- (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
- )
-
- def start_phone_stats_home():
- return run_as_background_process(
- "phone_stats_home", phone_stats_home, hs, stats
- )
-
- def generate_user_daily_visit_stats():
- return run_as_background_process(
- "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits
- )
-
- # Rather than update on per session basis, batch up the requests.
- # If you increase the loop period, the accuracy of user_daily_visits
- # table will decrease
- clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
-
- # monthly active user limiting functionality
- def reap_monthly_active_users():
- return run_as_background_process(
- "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
- )
-
- clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
- reap_monthly_active_users()
-
- async def generate_monthly_active_users():
- current_mau_count = 0
- current_mau_count_by_service = {}
- reserved_users = ()
- store = hs.get_datastore()
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
- current_mau_count = await store.get_monthly_active_count()
- current_mau_count_by_service = (
- await store.get_monthly_active_count_by_service()
- )
- reserved_users = await store.get_registered_reserved_users()
- current_mau_gauge.set(float(current_mau_count))
-
- for app_service, count in current_mau_count_by_service.items():
- current_mau_by_service_gauge.labels(app_service).set(float(count))
-
- registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
- max_mau_gauge.set(float(hs.config.max_mau_value))
-
- def start_generate_monthly_active_users():
- return run_as_background_process(
- "generate_monthly_active_users", generate_monthly_active_users
- )
-
- start_generate_monthly_active_users()
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
- clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
- # End of monthly active user settings
-
- if hs.config.report_stats:
- logger.info("Scheduling stats reporting for 3 hour intervals")
- clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
-
- # We need to defer this init for the cases that we daemonize
- # otherwise the process ID we get is that of the non-daemon process
- clock.call_later(0, performance_stats_init)
-
- # We wait 5 minutes to send the first set of stats as the server can
- # be quite busy the first few minutes
- clock.call_later(5 * 60, start_phone_stats_home)
-
_base.start_reactor(
"synapse-homeserver",
soft_file_limit=hs.config.soft_file_limit,
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
new file mode 100644
index 0000000000..c38cf8231f
--- /dev/null
+++ b/synapse/app/phone_stats_home.py
@@ -0,0 +1,190 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import math
+import resource
+import sys
+
+from prometheus_client import Gauge
+
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+
+logger = logging.getLogger("synapse.app.homeserver")
+
+# Contains the list of processes we will be monitoring
+# currently either 0 or 1
+_stats_process = []
+
+# Gauges to expose monthly active user control metrics
+current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
+current_mau_by_service_gauge = Gauge(
+ "synapse_admin_mau_current_mau_by_service",
+ "Current MAU by service",
+ ["app_service"],
+)
+max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
+registered_reserved_users_mau_gauge = Gauge(
+ "synapse_admin_mau:registered_reserved_users",
+ "Registered users with reserved threepids",
+)
+
+
+@wrap_as_background_process("phone_stats_home")
+async def phone_stats_home(hs, stats, stats_process=_stats_process):
+ logger.info("Gathering stats for reporting")
+ now = int(hs.get_clock().time())
+ uptime = int(now - hs.start_time)
+ if uptime < 0:
+ uptime = 0
+
+ #
+ # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test.
+ #
+ old = stats_process[0]
+ new = (now, resource.getrusage(resource.RUSAGE_SELF))
+ stats_process[0] = new
+
+ # Get RSS in bytes
+ stats["memory_rss"] = new[1].ru_maxrss
+
+ # Get CPU time in % of a single core, not % of all cores
+ used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
+ old[1].ru_utime + old[1].ru_stime
+ )
+ if used_cpu_time == 0 or new[0] == old[0]:
+ stats["cpu_average"] = 0
+ else:
+ stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
+
+ #
+ # General statistics
+ #
+
+ stats["homeserver"] = hs.config.server_name
+ stats["server_context"] = hs.config.server_context
+ stats["timestamp"] = now
+ stats["uptime_seconds"] = uptime
+ version = sys.version_info
+ stats["python_version"] = "{}.{}.{}".format(
+ version.major, version.minor, version.micro
+ )
+ stats["total_users"] = await hs.get_datastore().count_all_users()
+
+ total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
+ stats["total_nonbridged_users"] = total_nonbridged_users
+
+ daily_user_type_results = await hs.get_datastore().count_daily_user_type()
+ for name, count in daily_user_type_results.items():
+ stats["daily_user_type_" + name] = count
+
+ room_count = await hs.get_datastore().get_room_count()
+ stats["total_room_count"] = room_count
+
+ stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+
+ r30_results = await hs.get_datastore().count_r30_users()
+ for name, count in r30_results.items():
+ stats["r30_users_" + name] = count
+
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = hs.config.caches.global_factor
+ stats["event_cache_size"] = hs.config.caches.event_cache_size
+
+ #
+ # Database version
+ #
+
+ # This only reports info about the *main* database.
+ stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
+
+ #
+ # Logging configuration
+ #
+ synapse_logger = logging.getLogger("synapse")
+ log_level = synapse_logger.getEffectiveLevel()
+ stats["log_level"] = logging.getLevelName(log_level)
+
+ logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
+ try:
+ await hs.get_proxied_http_client().put_json(
+ hs.config.report_stats_endpoint, stats
+ )
+ except Exception as e:
+ logger.warning("Error reporting stats: %s", e)
+
+
+def start_phone_stats_home(hs):
+ """
+ Start the background tasks which report phone home stats.
+ """
+ clock = hs.get_clock()
+
+ stats = {}
+
+ def performance_stats_init():
+ _stats_process.clear()
+ _stats_process.append(
+ (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
+ )
+
+ # Rather than update on per session basis, batch up the requests.
+ # If you increase the loop period, the accuracy of user_daily_visits
+ # table will decrease
+ clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000)
+
+ # monthly active user limiting functionality
+ clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60)
+ hs.get_datastore().reap_monthly_active_users()
+
+ @wrap_as_background_process("generate_monthly_active_users")
+ async def generate_monthly_active_users():
+ current_mau_count = 0
+ current_mau_count_by_service = {}
+ reserved_users = ()
+ store = hs.get_datastore()
+ if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ current_mau_count = await store.get_monthly_active_count()
+ current_mau_count_by_service = (
+ await store.get_monthly_active_count_by_service()
+ )
+ reserved_users = await store.get_registered_reserved_users()
+ current_mau_gauge.set(float(current_mau_count))
+
+ for app_service, count in current_mau_count_by_service.items():
+ current_mau_by_service_gauge.labels(app_service).set(float(count))
+
+ registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
+ max_mau_gauge.set(float(hs.config.max_mau_value))
+
+ if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ generate_monthly_active_users()
+ clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
+ # End of monthly active user settings
+
+ if hs.config.report_stats:
+ logger.info("Scheduling stats reporting for 3 hour intervals")
+ clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats)
+
+ # We need to defer this init for the cases that we daemonize
+ # otherwise the process ID we get is that of the non-daemon process
+ clock.call_later(0, performance_stats_init)
+
+ # We wait 5 minutes to send the first set of stats as the server can
+ # be quite busy the first few minutes
+ clock.call_later(5 * 60, phone_stats_home, hs, stats)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index c526c28b93..e8f0793795 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
from prometheus_client import Counter
@@ -93,7 +93,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.protocol_meta_cache = ResponseCache(
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
- )
+ ) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id):
if service.url is None:
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index f924116819..7597fbc864 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -56,6 +56,7 @@ class OIDCConfig(Config):
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+ self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
ump_config = oidc_config.get("user_mapping_provider", {})
@@ -159,6 +160,14 @@ class OIDCConfig(Config):
#
#skip_verification: true
+ # Whether to fetch the user profile from the userinfo endpoint. Valid
+ # values are: "auto" or "userinfo_endpoint".
+ #
+ # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
+ # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+ #
+ #user_profile_method: "userinfo_endpoint"
+
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
#
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 01009f3924..ba1e9d2361 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -100,7 +100,7 @@ class ContentRepositoryConfig(Config):
"media_instance_running_background_jobs",
)
- self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M"))
+ self.max_upload_size = self.parse_size(config.get("max_upload_size", "50M"))
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
@@ -242,7 +242,7 @@ class ContentRepositoryConfig(Config):
# The largest allowed upload size in bytes
#
- #max_upload_size: 10M
+ #max_upload_size: 50M
# Maximum number of pixels that will be thumbnailed
#
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ef6d70e3f8..85aa49c02d 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -39,7 +39,7 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
-DEFAULT_ROOM_VERSION = "5"
+DEFAULT_ROOM_VERSION = "6"
ROOM_COMPLEXITY_TOO_GREAT = (
"Your homeserver is unable to join rooms this large or complex. "
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 9ddb8b546b..ad37b93c02 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -18,7 +18,7 @@ import os
import warnings
from datetime import datetime
from hashlib import sha256
-from typing import List
+from typing import List, Optional
from unpaddedbase64 import encode_base64
@@ -177,8 +177,8 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
- self.tls_certificate = None
- self.tls_private_key = None
+ self.tls_certificate = None # type: Optional[crypto.X509]
+ self.tls_private_key = None # type: Optional[crypto.PKey]
def is_disk_cert_valid(self, allow_self_signed=True):
"""
@@ -226,12 +226,12 @@ class TlsConfig(Config):
days_remaining = (expires_on - now).days
return days_remaining
- def read_certificate_from_disk(self, require_cert_and_key):
+ def read_certificate_from_disk(self, require_cert_and_key: bool):
"""
Read the certificates and private key from disk.
Args:
- require_cert_and_key (bool): set to True to throw an error if the certificate
+ require_cert_and_key: set to True to throw an error if the certificate
and key file are not given
"""
if require_cert_and_key:
@@ -479,13 +479,13 @@ class TlsConfig(Config):
}
)
- def read_tls_certificate(self):
+ def read_tls_certificate(self) -> crypto.X509:
"""Reads the TLS certificate from the configured file, and returns it
Also checks if it is self-signed, and warns if so
Returns:
- OpenSSL.crypto.X509: the certificate
+ The certificate
"""
cert_path = self.tls_certificate_file
logger.info("Loading TLS certificate from %s", cert_path)
@@ -504,11 +504,11 @@ class TlsConfig(Config):
return cert
- def read_tls_private_key(self):
+ def read_tls_private_key(self) -> crypto.PKey:
"""Reads the TLS private key from the configured file, and returns it
Returns:
- OpenSSL.crypto.PKey: the private key
+ The private key
"""
private_key_path = self.tls_private_key_file
logger.info("Loading TLS key from %s", private_key_path)
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f23e42cdf9..57ab097eba 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -132,6 +132,19 @@ class WorkerConfig(Config):
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
+ # Whether this worker should run background tasks or not.
+ #
+ # As a note for developers, the background tasks guarded by this should
+ # be able to run on only a single instance (meaning that they don't
+ # depend on any in-memory state of a particular worker).
+ #
+ # No effort is made to ensure only a single instance of these tasks is
+ # running.
+ background_tasks_instance = config.get("run_background_tasks_on") or "master"
+ self.run_background_tasks = (
+ self.worker_name is None and background_tasks_instance == "master"
+ ) or self.worker_name == background_tasks_instance
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
## Workers ##
@@ -167,6 +180,11 @@ class WorkerConfig(Config):
#stream_writers:
# events: worker1
# typing: worker1
+
+ # The worker that is used to run background tasks (e.g. cleaning up expired
+ # data). If not provided this defaults to the main process.
+ #
+ #run_background_tasks_on: worker1
"""
def read_arguments(self, args):
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 8c907ad596..56f8dc9caf 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -446,6 +446,8 @@ def check_redaction(
if room_version_obj.event_format == EventFormatVersions.V1:
redacter_domain = get_domain_from_id(event.event_id)
+ if not isinstance(event.redacts, str):
+ return False
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index dc49df0812..7a51d0a22f 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -97,13 +97,16 @@ class DefaultDictProperty(DictProperty):
class _EventInternalMetadata:
- __slots__ = ["_dict"]
+ __slots__ = ["_dict", "stream_ordering"]
def __init__(self, internal_metadata_dict: JsonDict):
# we have to copy the dict, because it turns out that the same dict is
# reused. TODO: fix that
self._dict = dict(internal_metadata_dict)
+ # the stream ordering of this event. None, until it has been persisted.
+ self.stream_ordering = None # type: Optional[int]
+
outlier = DictProperty("outlier") # type: bool
out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str
@@ -113,7 +116,6 @@ class _EventInternalMetadata:
redacted = DictProperty("redacted") # type: bool
txn_id = DictProperty("txn_id") # type: str
token_id = DictProperty("token_id") # type: str
- stream_ordering = DictProperty("stream_ordering") # type: int
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index b0fc859a47..bad18f7fdf 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -17,24 +17,25 @@
import inspect
from typing import Any, Dict, List, Optional, Tuple
-from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
MYPY = False
if MYPY:
+ import synapse.events
import synapse.server
class SpamChecker:
def __init__(self, hs: "synapse.server.HomeServer"):
self.spam_checkers = [] # type: List[Any]
+ api = hs.get_module_api()
for module, config in hs.config.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
- api = SpamCheckerApi(hs)
self.spam_checkers.append(module(config=config, api=api))
else:
self.spam_checkers.append(module(config=config))
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9d5310851c..1535cc5339 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Callable
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Requester
+from synapse.types import Requester, StateMap
class ThirdPartyEventRules:
@@ -38,7 +39,7 @@ class ThirdPartyEventRules:
if module is not None:
self.third_party_rules = module(
- config=config, http_client=hs.get_simple_http_client()
+ config=config, module_api=hs.get_module_api(),
)
async def check_event_allowed(
@@ -59,12 +60,14 @@ class ThirdPartyEventRules:
prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database.
- state_events = {}
- for key, event_id in prev_state_ids.items():
- state_events[key] = await self.store.get_event(event_id, allow_none=True)
+ events = await self.store.get_events(prev_state_ids.values())
+ state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
- ret = await self.third_party_rules.check_event_allowed(event, state_events)
- return ret
+ # The module can modify the event slightly if it wants, but caution should be
+ # exercised, and it's likely to go very wrong if applied to events received over
+ # federation.
+
+ return await self.third_party_rules.check_event_allowed(event, state_events)
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
@@ -106,6 +109,48 @@ class ThirdPartyEventRules:
if self.third_party_rules is None:
return True
+ state_events = await self._get_state_map_for_room(room_id)
+
+ ret = await self.third_party_rules.check_threepid_can_be_invited(
+ medium, address, state_events
+ )
+ return ret
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, new_visibility: str
+ ) -> bool:
+ """Check if a room is allowed to be published to, or removed from, the public room
+ list.
+
+ Args:
+ room_id: The ID of the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ True if the room's visibility can be modified, False if not.
+ """
+ if self.third_party_rules is None:
+ return True
+
+ check_func = getattr(
+ self.third_party_rules, "check_visibility_can_be_modified", None
+ )
+ if not check_func or not isinstance(check_func, Callable):
+ return True
+
+ state_events = await self._get_state_map_for_room(room_id)
+
+ return await check_func(room_id, state_events, new_visibility)
+
+ async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
+ """Given a room ID, return the state events of that room.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ A dict mapping (event type, state key) to state event.
+ """
state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())
@@ -113,7 +158,4 @@ class ThirdPartyEventRules:
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
- ret = await self.third_party_rules.check_threepid_can_be_invited(
- medium, address, state_events
- )
- return ret
+ return state_events
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 32c73d3413..355cbe05f1 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -49,6 +49,11 @@ def prune_event(event: EventBase) -> EventBase:
pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
)
+ # copy the internal fields
+ pruned_event.internal_metadata.stream_ordering = (
+ event.internal_metadata.stream_ordering
+ )
+
# Mark the event as redacted
pruned_event.internal_metadata.redacted = True
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 24329dd0e3..23278e36b7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
Callable,
Dict,
List,
- Match,
Optional,
Tuple,
Union,
@@ -100,10 +99,15 @@ class FederationServer(FederationBase):
super().__init__(hs)
self.auth = hs.get_auth()
- self.handler = hs.get_handlers().federation_handler
+ self.handler = hs.get_federation_handler()
self.state = hs.get_state_handler()
self.device_handler = hs.get_device_handler()
+
+ # Ensure the following handlers are loaded since they register callbacks
+ # with FederationHandlerRegistry.
+ hs.get_directory_handler()
+
self._federation_ratelimiter = hs.get_federation_ratelimiter()
self._server_linearizer = Linearizer("fed_server")
@@ -112,7 +116,7 @@ class FederationServer(FederationBase):
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
hs, "fed_txn_handler", timeout_ms=30000
- )
+ ) # type: ResponseCache[Tuple[str, str]]
self.transaction_actions = TransactionActions(self.store)
@@ -120,10 +124,12 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
- self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
+ self._state_resp_cache = ResponseCache(
+ hs, "state_resp", timeout_ms=30000
+ ) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache(
hs, "state_ids_resp", timeout_ms=30000
- )
+ ) # type: ResponseCache[Tuple[str, str]]
self._federation_metrics_domains = (
hs.get_config().federation.federation_metrics_domains
@@ -825,14 +831,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
return False
-def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
+def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
regex = glob_to_regex(acl_entry)
- return regex.match(server_name)
+ return bool(regex.match(server_name))
class FederationHandlerRegistry:
@@ -862,7 +868,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance = {} # type: Dict[str, str]
def register_edu_handler(
- self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]]
+ self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 8bb17b3a05..e33b29a42c 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -297,6 +297,8 @@ class FederationSender:
sent_pdus_destination_dist_total.inc(len(destinations))
sent_pdus_destination_dist_count.inc()
+ assert pdu.internal_metadata.stream_ordering
+
# track the fact that we have a PDU for these destinations,
# to allow us to perform catch-up later on if the remote is unreachable
# for a while.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index bc99af3fdd..db8e456fe8 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -158,6 +158,7 @@ class PerDestinationQueue:
# yet know if we have anything to catch up (None)
self._pending_pdus.append(pdu)
else:
+ assert pdu.internal_metadata.stream_ordering
self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
self.attempt_new_transaction()
@@ -361,6 +362,7 @@ class PerDestinationQueue:
last_successful_stream_ordering = (
final_pdu.internal_metadata.stream_ordering
)
+ assert last_successful_stream_ordering
await self._store.set_destination_last_successful_stream_ordering(
self._destination, last_successful_stream_ordering
)
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 286f0054be..bfebb0f644 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -12,36 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from .admin import AdminHandler
-from .directory import DirectoryHandler
-from .federation import FederationHandler
-from .identity import IdentityHandler
-from .search import SearchHandler
-
-
-class Handlers:
-
- """ Deprecated. A collection of handlers.
-
- At some point most of the classes whose name ended "Handler" were
- accessed through this class.
-
- However this makes it painful to unit test the handlers and to run cut
- down versions of synapse that only use specific handlers because using a
- single handler required creating all of the handlers. So some of the
- handlers have been lifted out of the Handlers object and are now accessed
- directly through the homeserver object itself.
-
- Any new handlers should follow the new pattern of being accessed through
- the homeserver object and should not be added to the Handlers object.
-
- The remaining handlers should be moved out of the handlers object.
- """
-
- def __init__(self, hs):
- self.federation_handler = FederationHandler(hs)
- self.directory_handler = DirectoryHandler(hs)
- self.admin_handler = AdminHandler(hs)
- self.identity_handler = IdentityHandler(hs)
- self.search_handler = SearchHandler(hs)
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 9112a0ab86..341135822e 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -12,16 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, List, Tuple
+
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
class AccountDataEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- def get_current_key(self, direction="f"):
+ def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_account_data_stream_id()
- async def get_new_events(self, user, from_key, **kwargs):
+ async def get_new_events(
+ self, user: UserID, from_key: int, **kwargs
+ ) -> Tuple[List[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 00eae92052..1d1ddc2245 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -164,7 +164,14 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
+ # we can't use hs.get_module_api() here, because to do so will create an
+ # import loop.
+ #
+ # TODO: refactor this class to separate the lower-level stuff that
+ # ModuleApi can use from the higher-level stuff that uses ModuleApi, as
+ # better way to break the loop
account_handler = ModuleApi(hs, self)
+
self.password_providers = [
module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers
@@ -212,7 +219,7 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
- if hs.config.worker_app is None:
+ if hs.config.run_background_tasks:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
@@ -1073,7 +1080,7 @@ class AuthHandler(BaseHandler):
if medium == "email":
address = canonicalise_email(address)
- identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler = self.hs.get_identity_handler()
result = await identity_handler.try_unbind_threepid(
user_id, {"medium": medium, "address": address, "id_server": id_server}
)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 0635ad5708..58c9f12686 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -22,19 +22,22 @@ from synapse.types import UserID, create_requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
- self._identity_handler = hs.get_handlers().identity_handler
+ self._identity_handler = hs.get_identity_handler()
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
@@ -137,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding
- async def _reject_pending_invites_for_user(self, user_id: str):
+ async def _reject_pending_invites_for_user(self, user_id: str) -> None:
"""Reject pending invites addressed to a given user ID.
Args:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b9d9098104..debb1b4f29 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -29,7 +29,10 @@ from synapse.api.errors import (
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
+ Collection,
+ JsonDict,
StreamToken,
+ UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@@ -41,13 +44,16 @@ from synapse.util.retryutils import NotRetryingDestination
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
MAX_DEVICE_DISPLAY_NAME_LEN = 100
class DeviceWorkerHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
@@ -105,7 +111,9 @@ class DeviceWorkerHandler(BaseHandler):
@trace
@measure_func("device.get_user_ids_changed")
- async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
+ async def get_user_ids_changed(
+ self, user_id: str, from_token: StreamToken
+ ) -> JsonDict:
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
"""
@@ -221,8 +229,8 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = possibly_changed & users_who_share_room
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
else:
- possibly_joined = []
- possibly_left = []
+ possibly_joined = set()
+ possibly_left = set()
result = {"changed": list(possibly_joined), "left": list(possibly_left)}
@@ -230,7 +238,7 @@ class DeviceWorkerHandler(BaseHandler):
return result
- async def on_federation_query_user_devices(self, user_id):
+ async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
user_id
)
@@ -249,7 +257,7 @@ class DeviceWorkerHandler(BaseHandler):
class DeviceHandler(DeviceWorkerHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.federation_sender = hs.get_federation_sender()
@@ -264,7 +272,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- def _check_device_name_length(self, name: str):
+ def _check_device_name_length(self, name: Optional[str]):
"""
Checks whether a device name is longer than the maximum allowed length.
@@ -283,8 +291,11 @@ class DeviceHandler(DeviceWorkerHandler):
)
async def check_device_registered(
- self, user_id, device_id, initial_device_display_name=None
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_device_display_name: Optional[str] = None,
+ ) -> str:
"""
If the given device has not been registered, register it with the
supplied display name.
@@ -292,12 +303,11 @@ class DeviceHandler(DeviceWorkerHandler):
If no device_id is supplied, we make one up.
Args:
- user_id (str): @user:id
- device_id (str | None): device id supplied by client
- initial_device_display_name (str | None): device display name from
- client
+ user_id: @user:id
+ device_id: device id supplied by client
+ initial_device_display_name: device display name from client
Returns:
- str: device id (generated if none was supplied)
+ device id (generated if none was supplied)
"""
self._check_device_name_length(initial_device_display_name)
@@ -316,15 +326,15 @@ class DeviceHandler(DeviceWorkerHandler):
# times in case of a clash.
attempts = 0
while attempts < 5:
- device_id = stringutils.random_string(10).upper()
+ new_device_id = stringutils.random_string(10).upper()
new_device = await self.store.store_device(
user_id=user_id,
- device_id=device_id,
+ device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
- await self.notify_device_update(user_id, [device_id])
- return device_id
+ await self.notify_device_update(user_id, [new_device_id])
+ return new_device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -433,7 +443,9 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
- async def notify_device_update(self, user_id, device_ids):
+ async def notify_device_update(
+ self, user_id: str, device_ids: Collection[str]
+ ) -> None:
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
@@ -445,7 +457,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id
)
- hosts = set()
+ hosts = set() # type: Set[str]
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
@@ -497,7 +509,7 @@ class DeviceHandler(DeviceWorkerHandler):
self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
- async def user_left_room(self, user, room_id):
+ async def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
@@ -505,8 +517,89 @@ class DeviceHandler(DeviceWorkerHandler):
# receive device updates. Mark this in DB.
await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+ async def store_dehydrated_device(
+ self,
+ user_id: str,
+ device_data: JsonDict,
+ initial_device_display_name: Optional[str] = None,
+ ) -> str:
+ """Store a dehydrated device for a user. If the user had a previous
+ dehydrated device, it is removed.
+
+ Args:
+ user_id: the user that we are storing the device for
+ device_data: the dehydrated device information
+ initial_device_display_name: The display name to use for the device
+ Returns:
+ device id of the dehydrated device
+ """
+ device_id = await self.check_device_registered(
+ user_id, None, initial_device_display_name,
+ )
+ old_device_id = await self.store.store_dehydrated_device(
+ user_id, device_id, device_data
+ )
+ if old_device_id is not None:
+ await self.delete_device(user_id, old_device_id)
+ return device_id
+
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ return await self.store.get_dehydrated_device(user_id)
+
+ async def rehydrate_device(
+ self, user_id: str, access_token: str, device_id: str
+ ) -> dict:
+ """Process a rehydration request from the user.
+
+ Args:
+ user_id: the user who is rehydrating the device
+ access_token: the access token used for the request
+ device_id: the ID of the device that will be rehydrated
+ Returns:
+ a dict containing {"success": True}
+ """
+ success = await self.store.remove_dehydrated_device(user_id, device_id)
+
+ if not success:
+ raise errors.NotFoundError()
+
+ # If the dehydrated device was successfully deleted (the device ID
+ # matched the stored dehydrated device), then modify the access
+ # token to use the dehydrated device's ID and copy the old device
+ # display name to the dehydrated device, and destroy the old device
+ # ID
+ old_device_id = await self.store.set_device_for_access_token(
+ access_token, device_id
+ )
+ old_device = await self.store.get_device(user_id, old_device_id)
+ await self.store.update_device(user_id, device_id, old_device["display_name"])
+ # can't call self.delete_device because that will clobber the
+ # access token so call the storage layer directly
+ await self.store.delete_device(user_id, old_device_id)
+ await self.store.delete_e2e_keys_by_device(
+ user_id=user_id, device_id=old_device_id
+ )
-def _update_device_from_client_ips(device, client_ips):
+ # tell everyone that the old device is gone and that the dehydrated
+ # device has a new display name
+ await self.notify_device_update(user_id, [old_device_id, device_id])
+
+ return {"success": True}
+
+
+def _update_device_from_client_ips(
+ device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -514,7 +607,7 @@ def _update_device_from_client_ips(device, client_ips):
class DeviceListUpdater:
"Handles incoming device list updates from federation and updates the DB"
- def __init__(self, hs, device_handler):
+ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
@@ -523,7 +616,9 @@ class DeviceListUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {}
+ self._pending_updates = (
+ {}
+ ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
@@ -546,7 +641,9 @@ class DeviceListUpdater:
)
@trace
- async def incoming_device_list_update(self, origin, edu_content):
+ async def incoming_device_list_update(
+ self, origin: str, edu_content: JsonDict
+ ) -> None:
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
@@ -607,7 +704,7 @@ class DeviceListUpdater:
await self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update")
- async def _handle_device_updates(self, user_id):
+ async def _handle_device_updates(self, user_id: str) -> None:
"Actually handle pending updates."
with (await self._remote_edu_linearizer.queue(user_id)):
@@ -655,7 +752,9 @@ class DeviceListUpdater:
stream_id for _, stream_id, _, _ in pending_updates
)
- async def _need_to_do_resync(self, user_id, updates):
+ async def _need_to_do_resync(
+ self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
+ ) -> bool:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
@@ -686,7 +785,7 @@ class DeviceListUpdater:
return False
@trace
- async def _maybe_retry_device_resync(self):
+ async def _maybe_retry_device_resync(self) -> None:
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
@@ -729,7 +828,7 @@ class DeviceListUpdater:
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
@@ -753,7 +852,7 @@ class DeviceListUpdater:
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)
- return
+ return None
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to handle device list update for %s: %s", user_id, e,
@@ -770,12 +869,12 @@ class DeviceListUpdater:
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
- return
+ return None
except FederationDeniedError as e:
set_tag("error", True)
log_kv({"reason": "FederationDeniedError"})
logger.info(e)
- return
+ return None
except Exception as e:
set_tag("error", True)
log_kv(
@@ -788,7 +887,7 @@ class DeviceListUpdater:
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)
- return
+ return None
log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
@@ -849,7 +948,7 @@ class DeviceListUpdater:
user_id: str,
master_key: Optional[Dict[str, Any]],
self_signing_key: Optional[Dict[str, Any]],
- ) -> list:
+ ) -> List[str]:
"""Process the given new master and self-signing key for the given remote user.
Args:
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 64ef7f63ab..9cac5a8463 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any, Dict
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
@@ -24,18 +24,22 @@ from synapse.logging.opentracing import (
set_tag,
start_active_span,
)
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
logger = logging.getLogger(__name__)
class DeviceMessageHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
"""
Args:
- hs (synapse.server.HomeServer): server
+ hs: server
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -48,7 +52,7 @@ class DeviceMessageHandler:
self._device_list_updater = hs.get_device_handler().device_list_updater
- async def on_direct_to_device_edu(self, origin, content):
+ async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
@@ -95,7 +99,7 @@ class DeviceMessageHandler:
message_type: str,
sender_user_id: str,
by_device: Dict[str, Dict[str, Any]],
- ):
+ ) -> None:
"""Checks inbound device messages for unknown remote devices, and if
found marks the remote cache for the user as stale.
"""
@@ -138,11 +142,16 @@ class DeviceMessageHandler:
self._device_list_updater.user_device_resync, sender_user_id
)
- async def send_device_message(self, sender_user_id, message_type, messages):
+ async def send_device_message(
+ self,
+ sender_user_id: str,
+ message_type: str,
+ messages: Dict[str, Dict[str, JsonDict]],
+ ) -> None:
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
local_messages = {}
- remote_messages = {}
+ remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
for user_id, by_device in messages.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 62aa9a2da8..ad5683d251 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -46,6 +46,7 @@ class DirectoryHandler(BaseHandler):
self.config = hs.config
self.enable_room_list_search = hs.config.enable_room_list_search
self.require_membership = hs.config.require_membership_for_aliases
+ self.third_party_event_rules = hs.get_third_party_event_rules()
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -383,7 +384,7 @@ class DirectoryHandler(BaseHandler):
"""
creator = await self.store.get_room_alias_creator(alias.to_string())
- if creator is not None and creator == user_id:
+ if creator == user_id:
return True
# Resolve the alias to the corresponding room.
@@ -454,6 +455,15 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
+ # Check if publishing is blocked by a third party module
+ allowed_by_third_party_rules = await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
+ )
+ if not allowed_by_third_party_rules:
+ raise SynapseError(403, "Not allowed to publish room")
+
await self.store.set_room_is_public(room_id, making_public)
async def edit_published_appservice_room_list(
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index dd40fd1299..611742ae72 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -496,6 +496,22 @@ class E2eKeysHandler:
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
+ fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
+ if fallback_keys and isinstance(fallback_keys, dict):
+ log_kv(
+ {
+ "message": "Updating fallback_keys for device.",
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ )
+ await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
+ elif fallback_keys:
+ log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
+ else:
+ log_kv(
+ {"message": "Did not update fallback_keys", "reason": "no keys given"}
+ )
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1a8144405a..5ac2fc5656 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -3008,6 +3008,9 @@ class FederationHandler(BaseHandler):
elif event.internal_metadata.is_outlier():
return
+ # the event has been persisted so it should have a stream ordering.
+ assert event.internal_metadata.stream_ordering
+
event_pos = PersistedEventPosition(
self._instance_name, event.internal_metadata.stream_ordering
)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 39a85801c1..98075f48d2 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple
from twisted.internet import defer
@@ -47,12 +47,14 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
- self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
+ self.snapshot_cache = ResponseCache(
+ hs, "initial_sync_cache"
+ ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
- def snapshot_all_rooms(
+ async def snapshot_all_rooms(
self,
user_id: str,
pagin_config: PaginationConfig,
@@ -84,7 +86,7 @@ class InitialSyncHandler(BaseHandler):
include_archived,
)
- return self.snapshot_cache.wrap(
+ return await self.snapshot_cache.wrap(
key,
self._snapshot_all_rooms,
user_id,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ee271e85e5..ad0b7bd868 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -59,6 +59,7 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
+ from synapse.events.third_party_rules import ThirdPartyEventRules
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -393,7 +394,9 @@ class EventCreationHandler:
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
- self.third_party_event_rules = hs.get_third_party_event_rules()
+ self.third_party_event_rules = (
+ self.hs.get_third_party_event_rules()
+ ) # type: ThirdPartyEventRules
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
@@ -635,59 +638,6 @@ class EventCreationHandler:
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
- async def send_nonmember_event(
- self,
- requester: Requester,
- event: EventBase,
- context: EventContext,
- ratelimit: bool = True,
- ignore_shadow_ban: bool = False,
- ) -> int:
- """
- Persists and notifies local clients and federation of an event.
-
- Args:
- requester: The requester sending the event.
- event: The event to send.
- context: The context of the event.
- ratelimit: Whether to rate limit this send.
- ignore_shadow_ban: True if shadow-banned users should be allowed to
- send this event.
-
- Return:
- The stream_id of the persisted event.
-
- Raises:
- ShadowBanError if the requester has been shadow-banned.
- """
- if event.type == EventTypes.Member:
- raise SynapseError(
- 500, "Tried to send member event through non-member codepath"
- )
-
- if not ignore_shadow_ban and requester.shadow_banned:
- # We randomly sleep a bit just to annoy the requester.
- await self.clock.sleep(random.randint(1, 10))
- raise ShadowBanError()
-
- user = UserID.from_string(event.sender)
-
- assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
-
- if event.is_state():
- prev_event = await self.deduplicate_state_event(event, context)
- if prev_event is not None:
- logger.info(
- "Not bothering to persist state event %s duplicated by %s",
- event.event_id,
- prev_event.event_id,
- )
- return await self.store.get_stream_id_for_event(prev_event.event_id)
-
- return await self.handle_new_client_event(
- requester=requester, event=event, context=context, ratelimit=ratelimit
- )
-
async def deduplicate_state_event(
self, event: EventBase, context: EventContext
) -> Optional[EventBase]:
@@ -728,7 +678,7 @@ class EventCreationHandler:
"""
Creates an event, then sends it.
- See self.create_event and self.send_nonmember_event.
+ See self.create_event and self.handle_new_client_event.
Args:
requester: The requester sending the event.
@@ -738,9 +688,19 @@ class EventCreationHandler:
ignore_shadow_ban: True if shadow-banned users should be allowed to
send this event.
+ Returns:
+ The event, and its stream ordering (if state event deduplication happened,
+ the previous, duplicate event).
+
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
+
+ if event_dict["type"] == EventTypes.Member:
+ raise SynapseError(
+ 500, "Tried to send member event through non-member codepath"
+ )
+
if not ignore_shadow_ban and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
@@ -756,20 +716,27 @@ class EventCreationHandler:
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
+ assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
+ event.sender,
+ )
+
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
- stream_id = await self.send_nonmember_event(
- requester,
- event,
- context,
+ ev = await self.handle_new_client_event(
+ requester=requester,
+ event=event,
+ context=context,
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
)
- return event, stream_id
+
+ # we know it was persisted, so must have a stream ordering
+ assert ev.internal_metadata.stream_ordering
+ return ev, ev.internal_metadata.stream_ordering
@measure_func("create_new_client_event")
async def create_new_client_event(
@@ -843,8 +810,11 @@ class EventCreationHandler:
context: EventContext,
ratelimit: bool = True,
extra_users: List[UserID] = [],
- ) -> int:
- """Processes a new event. This includes checking auth, persisting it,
+ ignore_shadow_ban: bool = False,
+ ) -> EventBase:
+ """Processes a new event.
+
+ This includes deduplicating, checking auth, persisting,
notifying users, sending to remote servers, etc.
If called from a worker will hit out to the master process for final
@@ -857,10 +827,39 @@ class EventCreationHandler:
ratelimit
extra_users: Any extra users to notify about event
+ ignore_shadow_ban: True if shadow-banned users should be allowed to
+ send this event.
+
Return:
- The stream_id of the persisted event.
+ If the event was deduplicated, the previous, duplicate, event. Otherwise,
+ `event`.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
"""
+ # we don't apply shadow-banning to membership events here. Invites are blocked
+ # higher up the stack, and we allow shadow-banned users to send join and leave
+ # events as normal.
+ if (
+ event.type != EventTypes.Member
+ and not ignore_shadow_ban
+ and requester.shadow_banned
+ ):
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
+ if event.is_state():
+ prev_event = await self.deduplicate_state_event(event, context)
+ if prev_event is not None:
+ logger.info(
+ "Not bothering to persist state event %s duplicated by %s",
+ event.event_id,
+ prev_event.event_id,
+ )
+ return prev_event
+
if event.is_state() and (event.type, event.state_key) == (
EventTypes.Create,
"",
@@ -915,13 +914,13 @@ class EventCreationHandler:
)
stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id
- return stream_id
+ return event
stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- return stream_id
+ return event
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
@@ -1018,7 +1017,7 @@ class EventCreationHandler:
# Check the alias is currently valid (if it has changed).
room_alias_str = event.content.get("alias", None)
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
if room_alias_str and room_alias_str != original_alias:
await self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id
@@ -1044,7 +1043,7 @@ class EventCreationHandler:
directory_handler, alias_str, event.room_id
)
- federation_handler = self.hs.get_handlers().federation_handler
+ federation_handler = self.hs.get_federation_handler()
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
@@ -1232,7 +1231,7 @@ class EventCreationHandler:
# Since this is a dummy-event it is OK if it is sent by a
# shadow-banned user.
- await self.send_nonmember_event(
+ await self.handle_new_client_event(
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
)
return True
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 19cd652675..05ac86e697 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -96,6 +96,7 @@ class OidcHandler:
self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
+ self._user_profile_method = hs.config.oidc_user_profile_method # type: str
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
@@ -196,11 +197,11 @@ class OidcHandler:
% (m["response_types_supported"],)
)
- # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+ # Ensure there's a userinfo endpoint to fetch from if it is required.
if self._uses_userinfo:
if m.get("userinfo_endpoint") is None:
raise ValueError(
- 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+ 'provider has no "userinfo_endpoint", even though it is required'
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
@@ -220,8 +221,10 @@ class OidcHandler:
``access_token`` with the ``userinfo_endpoint``.
"""
- # Maybe that should be user-configurable and not inferred?
- return "openid" not in self._scopes
+ return (
+ "openid" not in self._scopes
+ or self._user_profile_method == "userinfo_endpoint"
+ )
async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 2c2a633938..085b685959 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -383,7 +383,7 @@ class PaginationHandler:
"room_key", leave_token
)
- await self.hs.get_handlers().federation_handler.maybe_backfill(
+ await self.hs.get_federation_handler().maybe_backfill(
room_id, curr_topo, limit=pagin_config.limit,
)
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index 88e2f87200..6c635cc31b 100644
--- a/synapse/handlers/password_policy.py
+++ b/synapse/handlers/password_policy.py
@@ -16,14 +16,18 @@
import logging
import re
+from typing import TYPE_CHECKING
from synapse.api.errors import Codes, PasswordRefusedError
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class PasswordPolicyHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
@@ -33,11 +37,11 @@ class PasswordPolicyHandler:
self.regexp_uppercase = re.compile("[A-Z]")
self.regexp_lowercase = re.compile("[a-z]")
- def validate_password(self, password):
+ def validate_password(self, password: str) -> None:
"""Checks whether a given password complies with the server's policy.
Args:
- password (str): The password to check against the server's policy.
+ password: The password to check against the server's policy.
Raises:
PasswordRefusedError: The password doesn't comply with the server's policy.
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index c32f314a1c..a7550806e6 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -14,23 +14,29 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class ReadMarkerHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier()
- async def received_client_read_marker(self, room_id, user_id, event_id):
+ async def received_client_read_marker(
+ self, room_id: str, user_id: str, event_id: str
+ ) -> None:
"""Updates the read marker for a given user in a given room if the event ID given
is ahead in the stream relative to the current read marker.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 538f4b2a61..a6f1d21674 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -48,7 +48,7 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
- self.identity_handler = self.hs.get_handlers().identity_handler
+ self.identity_handler = self.hs.get_identity_handler()
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d5f7c78edf..93ed51063a 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -120,7 +120,7 @@ class RoomCreationHandler(BaseHandler):
# subsequent requests
self._upgrade_response_cache = ResponseCache(
hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
- )
+ ) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -185,6 +185,7 @@ class RoomCreationHandler(BaseHandler):
ShadowBanError if the requester is shadow-banned.
"""
user_id = requester.user.to_string()
+ assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,)
# start by allocating a new room id
r = await self.store.get_room(old_room_id)
@@ -229,8 +230,8 @@ class RoomCreationHandler(BaseHandler):
)
# now send the tombstone
- await self.event_creation_handler.send_nonmember_event(
- requester, tombstone_event, tombstone_context
+ await self.event_creation_handler.handle_new_client_event(
+ requester=requester, event=tombstone_event, context=tombstone_context,
)
old_room_state = await tombstone_context.get_current_state_ids()
@@ -681,7 +682,16 @@ class RoomCreationHandler(BaseHandler):
creator_id=user_id, is_public=is_public, room_version=room_version,
)
- directory_handler = self.hs.get_handlers().directory_handler
+ # Check whether this visibility value is blocked by a third party module
+ allowed_by_third_party_rules = await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
+ )
+ if not allowed_by_third_party_rules:
+ raise SynapseError(403, "Room visibility value not allowed.")
+
+ directory_handler = self.hs.get_directory_handler()
if room_alias:
await directory_handler.create_association(
requester=requester,
@@ -962,8 +972,6 @@ class RoomCreationHandler(BaseHandler):
try:
random_string = stringutils.random_string(18)
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
- if isinstance(gen_room_id, bytes):
- gen_room_id = gen_room_id.decode("utf-8")
await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8feba8c90a..ffbc62ff44 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from synapse import types
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -64,9 +64,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.state_handler = hs.get_state_handler()
self.config = hs.config
- self.federation_handler = hs.get_handlers().federation_handler
- self.directory_handler = hs.get_handlers().directory_handler
- self.identity_handler = hs.get_handlers().identity_handler
+ self.federation_handler = hs.get_federation_handler()
+ self.directory_handler = hs.get_directory_handler()
+ self.identity_handler = hs.get_identity_handler()
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -188,15 +188,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
require_consent=require_consent,
)
- # Check if this event matches the previous membership event for the user.
- duplicate = await self.event_creation_handler.deduplicate_state_event(
- event, context
- )
- if duplicate is not None:
- # Discard the new event since this membership change is a no-op.
- _, stream_id = await self.store.get_event_ordering(duplicate.event_id)
- return duplicate.event_id, stream_id
-
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@@ -221,7 +212,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
- stream_id = await self.event_creation_handler.handle_new_client_event(
+ result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
@@ -231,7 +222,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target, room_id)
- return event.event_id, stream_id
+ # we know it was persisted, so should have a stream ordering
+ assert result_event.internal_metadata.stream_ordering
+ return result_event.event_id, result_event.internal_metadata.stream_ordering
async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id
@@ -247,7 +240,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
user_account_data, _ = await self.store.get_account_data_for_user(user_id)
# Copy direct message state if applicable
- direct_rooms = user_account_data.get("m.direct", {})
+ direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
# Check which key this room is under
if isinstance(direct_rooms, dict):
@@ -258,7 +251,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Save back to user's m.direct account data
await self.store.add_account_data_for_user(
- user_id, "m.direct", direct_rooms
+ user_id, AccountDataTypes.DIRECT, direct_rooms
)
break
@@ -441,12 +434,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
- _, stream_id = await self.store.get_event_ordering(
- old_state.event_id
- )
+ # duplicate event.
+ # we know it was persisted, so must have a stream ordering.
+ assert old_state.internal_metadata.stream_ordering
return (
old_state.event_id,
- stream_id,
+ old_state.internal_metadata.stream_ordering,
)
if old_membership in ["ban", "leave"] and action == "kick":
@@ -642,7 +635,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async def send_membership_event(
self,
- requester: Requester,
+ requester: Optional[Requester],
event: EventBase,
context: EventContext,
ratelimit: bool = True,
@@ -672,12 +665,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
else:
requester = types.create_requester(target_user)
- prev_event = await self.event_creation_handler.deduplicate_state_event(
- event, context
- )
- if prev_event is not None:
- return
-
prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
@@ -1185,10 +1172,13 @@ class RoomMemberMasterHandler(RoomMemberHandler):
context = await self.state_handler.compute_event_context(event)
context.app_service = requester.app_service
- stream_id = await self.event_creation_handler.handle_new_client_event(
+ result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)],
)
- return event.event_id, stream_id
+ # we know it was persisted, so must have a stream ordering
+ assert result_event.internal_metadata.stream_ordering
+
+ return result_event.event_id, result_event.internal_metadata.stream_ordering
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 249ffe2a55..dc62b21c06 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -49,7 +49,7 @@ class StatsHandler:
# Guard to ensure we only process deltas one at a time
self._is_processing = False
- if hs.config.stats_enabled:
+ if self.stats_enabled and hs.config.run_background_tasks:
self.notifier.add_replication_callback(self.notify_new_event)
# We kick this off so that we don't have to wait for a change before
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bfe2583002..a306631094 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
import attr
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.events import EventBase
from synapse.logging.context import current_context
@@ -87,7 +87,7 @@ class SyncConfig:
class TimelineBatch:
prev_batch = attr.ib(type=StreamToken)
events = attr.ib(type=List[EventBase])
- limited = attr.ib(bool)
+ limited = attr.ib(type=bool)
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -201,6 +201,8 @@ class SyncResult:
device_lists: List of user_ids whose devices have changed
device_one_time_keys_count: Dict of algorithm to count for one time keys
for this device
+ device_unused_fallback_key_types: List of key types that have an unused fallback
+ key
groups: Group updates, if any
"""
@@ -213,6 +215,7 @@ class SyncResult:
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
device_one_time_keys_count = attr.ib(type=JsonDict)
+ device_unused_fallback_key_types = attr.ib(type=List[str])
groups = attr.ib(type=Optional[GroupsSyncResult])
def __bool__(self) -> bool:
@@ -240,7 +243,9 @@ class SyncHandler:
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
- self.response_cache = ResponseCache(hs, "sync")
+ self.response_cache = ResponseCache(
+ hs, "sync"
+ ) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
self.storage = hs.get_storage()
@@ -457,8 +462,13 @@ class SyncHandler:
recents = []
if not limited or block_all_timeline:
+ prev_batch_token = now_token
+ if recents:
+ room_key = recents[0].internal_metadata.before
+ prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+
return TimelineBatch(
- events=recents, prev_batch=now_token, limited=False
+ events=recents, prev_batch=prev_batch_token, limited=False
)
filtering_factor = 2
@@ -1014,10 +1024,14 @@ class SyncHandler:
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
one_time_key_counts = {} # type: JsonDict
+ unused_fallback_key_types = [] # type: List[str]
if device_id:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
+ unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
+ user_id, device_id
+ )
logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
@@ -1041,6 +1055,7 @@ class SyncHandler:
device_lists=device_lists,
groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
+ device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
@@ -1378,13 +1393,16 @@ class SyncHandler:
return set(), set(), set(), set()
ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id=user_id
+ AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
)
+ # If there is ignored users account data and it matches the proper type,
+ # then use it.
+ ignored_users = frozenset() # type: FrozenSet[str]
if ignored_account_data:
- ignored_users = ignored_account_data.get("ignored_users", {}).keys()
- else:
- ignored_users = frozenset()
+ ignored_users_data = ignored_account_data.get("ignored_users", {})
+ if isinstance(ignored_users_data, dict):
+ ignored_users = frozenset(ignored_users_data.keys())
if since_token:
room_changes = await self._get_rooms_changed(
@@ -1478,7 +1496,7 @@ class SyncHandler:
return False
async def _get_rooms_changed(
- self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
+ self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
"""Gets the the changes that have happened since the last sync.
"""
@@ -1690,7 +1708,7 @@ class SyncHandler:
return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
async def _get_all_rooms(
- self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
+ self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
@@ -1764,7 +1782,7 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
- ignored_users: Set[str],
+ ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 9146dc1a3b..3d66bf305e 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -143,7 +143,7 @@ class _BaseThreepidAuthChecker:
threepid_creds = authdict["threepid_creds"]
- identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler = self.hs.get_identity_handler()
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 996a31a9ec..00b98af3d4 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -257,7 +257,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return
+ callback_return = raw_callback_return # type: ignore
return callback_return
@@ -406,7 +406,7 @@ class JsonResource(DirectServeJsonResource):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return
+ callback_return = raw_callback_return # type: ignore
return callback_return
@@ -651,6 +651,11 @@ def respond_with_json_bytes(
Returns:
twisted.web.server.NOT_DONE_YET if the request is still active.
"""
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 144506c8f2..0fc2ea609e 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import os.path
import sys
@@ -89,14 +88,7 @@ class LogContextObserver:
context = current_context()
# Copy the context information to the log event.
- if context is not None:
- context.copy_to_twisted_log_entry(event)
- else:
- # If there's no logging context, not even the root one, we might be
- # starting up or it might be from non-Synapse code. Log it as if it
- # came from the root logger.
- event["request"] = None
- event["scope"] = None
+ context.copy_to_twisted_log_entry(event)
self.observer(event)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index fcbd5378c4..0142542852 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,12 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from twisted.internet import defer
+from synapse.events import EventBase
+from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.types import UserID
+from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID, create_requester
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
"""
This package defines the 'stable' API which can be used by extension modules which
@@ -43,6 +50,27 @@ class ModuleApi:
self._auth = hs.get_auth()
self._auth_handler = auth_handler
+ # We expose these as properties below in order to attach a helpful docstring.
+ self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
+ self._public_room_list_manager = PublicRoomListManager(hs)
+
+ @property
+ def http_client(self):
+ """Allows making outbound HTTP requests to remote resources.
+
+ An instance of synapse.http.client.SimpleHttpClient
+ """
+ return self._http_client
+
+ @property
+ def public_room_list_manager(self):
+ """Allows adding to, removing from and checking the status of rooms in the
+ public room list.
+
+ An instance of synapse.module_api.PublicRoomListManager
+ """
+ return self._public_room_list_manager
+
def get_user_by_req(self, req, allow_guest=False):
"""Check the access_token provided for a request
@@ -266,3 +294,97 @@ class ModuleApi:
await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url,
)
+
+ @defer.inlineCallbacks
+ def get_state_events_in_room(
+ self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
+ ) -> defer.Deferred:
+ """Gets current state events for the given room.
+
+ (This is exposed for compatibility with the old SpamCheckerApi. We should
+ probably deprecate it and replace it with an async method in a subclass.)
+
+ Args:
+ room_id: The room ID to get state events in.
+ types: The event type and state key (using None
+ to represent 'any') of the room state to acquire.
+
+ Returns:
+ twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
+ The filtered state events in the room.
+ """
+ state_ids = yield defer.ensureDeferred(
+ self._store.get_filtered_current_state_ids(
+ room_id=room_id, state_filter=StateFilter.from_types(types)
+ )
+ )
+ state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
+ return state.values()
+
+ async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
+ """Create and send an event into a room. Membership events are currently not supported.
+
+ Args:
+ event_dict: A dictionary representing the event to send.
+ Required keys are `type`, `room_id`, `sender` and `content`.
+
+ Returns:
+ The event that was sent. If state event deduplication happened, then
+ the previous, duplicate event instead.
+
+ Raises:
+ SynapseError if the event was not allowed.
+ """
+ # Create a requester object
+ requester = create_requester(event_dict["sender"])
+
+ # Create and send the event
+ (
+ event,
+ _,
+ ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
+ requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
+ )
+
+ return event
+
+
+class PublicRoomListManager:
+ """Contains methods for adding to, removing from and querying whether a room
+ is in the public room list.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastore()
+
+ async def room_is_in_public_room_list(self, room_id: str) -> bool:
+ """Checks whether a room is in the public room list.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ Whether the room is in the public room list. Returns False if the room does
+ not exist.
+ """
+ room = await self._store.get_room(room_id)
+ if not room:
+ return False
+
+ return room.get("is_public", False)
+
+ async def add_room_to_public_room_list(self, room_id: str) -> None:
+ """Publishes a room to the public room list.
+
+ Args:
+ room_id: The ID of the room.
+ """
+ await self._store.set_room_is_public(room_id, True)
+
+ async def remove_room_from_public_room_list(self, room_id: str) -> None:
+ """Removes a room from the public room list.
+
+ Args:
+ room_id: The ID of the room.
+ """
+ await self._store.set_room_is_public(room_id, False)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 59415f6f88..13adeed01e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -339,7 +339,7 @@ class Notifier:
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
- users: Collection[UserID] = [],
+ users: Collection[Union[str, UserID]] = [],
rooms: Collection[str] = [],
):
""" Used to inform listeners that something has happened event wise.
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 709ace01e5..3a68ce636f 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
import logging
import re
-from typing import Any, Dict, List, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Union
from synapse.events import EventBase
from synapse.types import UserID
@@ -181,7 +181,7 @@ class PushRuleEvaluatorForEvent:
return r.search(body)
- def _get_value(self, dotted_key: str) -> str:
+ def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 64edadb624..2b3972cb14 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -92,7 +92,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self.CACHE:
self.response_cache = ResponseCache(
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
- )
+ ) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name.
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 5393b9a9e7..b4f4a68b5c 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.clock = hs.get_clock()
- self.federation_handler = hs.get_handlers().federation_handler
+ self.federation_handler = hs.get_federation_handler()
@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 30680baee8..e7cc74a5d2 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -47,7 +47,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
def __init__(self, hs):
super().__init__(hs)
- self.federation_handler = hs.get_handlers().federation_handler
+ self.federation_handler = hs.get_federation_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e165429cad..e27ee216f0 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -191,6 +191,10 @@ class ReplicationDataHandler:
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
+ # We poke the generic "replication" notifier to wake anything up that
+ # may be streaming.
+ self.notifier.notify_replication()
+
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..ac532ed588 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -141,15 +141,23 @@ class RdataCommand(Command):
class PositionCommand(Command):
- """Sent by the server to tell the client the stream position without
- needing to send an RDATA.
+ """Sent by an instance to tell others the stream position without needing to
+ send an RDATA.
+
+ Two tokens are sent, the new position and the last position sent by the
+ instance (in an RDATA or other POSITION). The tokens are chosen so that *no*
+ rows were written by the instance between the `prev_token` and `new_token`.
+ (If an instance hasn't sent a position before then the new position can be
+ used for both.)
Format::
- POSITION <stream_name> <instance_name> <token>
+ POSITION <stream_name> <instance_name> <prev_token> <new_token>
- On receipt of a POSITION command clients should check if they have missed
- any updates, and if so then fetch them out of band.
+ On receipt of a POSITION command instances should check if they have missed
+ any updates, and if so then fetch them out of band. Instances can check this
+ by comparing their view of the current token for the sending instance with
+ the included `prev_token`.
The `<instance_name>` is the process that sent the command and is the source
of the stream.
@@ -157,18 +165,26 @@ class PositionCommand(Command):
NAME = "POSITION"
- def __init__(self, stream_name, instance_name, token):
+ def __init__(self, stream_name, instance_name, prev_token, new_token):
self.stream_name = stream_name
self.instance_name = instance_name
- self.token = token
+ self.prev_token = prev_token
+ self.new_token = new_token
@classmethod
def from_line(cls, line):
- stream_name, instance_name, token = line.split(" ", 2)
- return cls(stream_name, instance_name, int(token))
+ stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
+ return cls(stream_name, instance_name, int(prev_token), int(new_token))
def to_line(self):
- return " ".join((self.stream_name, self.instance_name, str(self.token)))
+ return " ".join(
+ (
+ self.stream_name,
+ self.instance_name,
+ str(self.prev_token),
+ str(self.new_token),
+ )
+ )
class ErrorCommand(_SimpleCommand):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..95e5502bf2 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -101,8 +101,9 @@ class ReplicationCommandHandler:
self._streams_to_replicate = [] # type: List[Stream]
for stream in self._streams.values():
- if stream.NAME == CachesStream.NAME:
- # All workers can write to the cache invalidation stream.
+ if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream when
+ # using redis.
self._streams_to_replicate.append(stream)
continue
@@ -251,10 +252,9 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
- import txredisapi
-
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
+ lazyConnection,
)
logger.info(
@@ -271,7 +271,8 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = txredisapi.lazyConnection(
+ outbound_redis_connection = lazyConnection(
+ reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
@@ -313,11 +314,14 @@ class ReplicationCommandHandler:
# We respond with current position of all streams this instance
# replicates.
for stream in self.get_streams_to_replicate():
+ # Note that we use the current token as the prev token here (rather
+ # than stream.last_token), as we can't be sure that there have been
+ # no rows written between last token and the current token (since we
+ # might be racing with the replication sending bg process).
+ current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(
- stream.NAME,
- self._instance_name,
- stream.current_token(self._instance_name),
+ stream.NAME, self._instance_name, current_token, current_token,
)
)
@@ -511,16 +515,16 @@ class ReplicationCommandHandler:
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
- missing_updates = cmd.token != current_token
+ missing_updates = cmd.prev_token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
- cmd.token,
+ cmd.new_token,
)
(updates, current_token, missing_updates) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
+ cmd.instance_name, current_token, cmd.new_token
)
# TODO: add some tests for this
@@ -536,11 +540,11 @@ class ReplicationCommandHandler:
[stream.parse_row(row) for row in rows],
)
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ cmd.stream_name, cmd.instance_name, cmd.new_token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
import logging
import struct
from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from twisted.internet import task
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
- self.time_we_closed = None # When we requested the connection be closed
+ # When we requested the connection be closed
+ self.time_we_closed = None # type: Optional[int]
- self.received_ping = False # Have we reecived a ping from the other side
+ self.received_ping = False # Have we received a ping from the other side
self.state = ConnectionStates.CONNECTING
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
- self._send_ping_loop = None
+ self._send_ping_loop = None # type: Optional[task.LoopingCall]
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..de19705c1f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
import txredisapi
@@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.password = self.password
return p
+
+
+def lazyConnection(
+ reactor,
+ host: str = "localhost",
+ port: int = 6379,
+ dbid: Optional[int] = None,
+ reconnect: bool = True,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ connectTimeout: Optional[int] = None,
+ replyTimeout: Optional[int] = None,
+ convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+ """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+ reactor.
+ """
+
+ isLazy = True
+ poolsize = 1
+
+ uuid = "%s:%d" % (host, port)
+ factory = txredisapi.RedisFactory(
+ uuid,
+ dbid,
+ poolsize,
+ isLazy,
+ txredisapi.ConnectionHandler,
+ charset,
+ password,
+ replyTimeout,
+ convertNumbers,
+ )
+ factory.continueTrying = reconnect
+ for x in range(poolsize):
+ reactor.connectTCP(host, port, factory, connectTimeout)
+
+ return factory.handler
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..666c13fdb7 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,9 @@ from prometheus_client import Counter
from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import PositionCommand
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -84,6 +86,23 @@ class ReplicationStreamer:
# Set of streams to replicate.
self.streams = self.command_handler.get_streams_to_replicate()
+ # If we have streams then we must have redis enabled or on master
+ assert (
+ not self.streams
+ or hs.config.redis.redis_enabled
+ or not hs.config.worker.worker_app
+ )
+
+ # If we are replicating an event stream we want to periodically check if
+ # we should send updated POSITIONs. We do this as a looping call rather
+ # explicitly poking when the position advances (without new data to
+ # replicate) to reduce replication traffic (otherwise each writer would
+ # likely send a POSITION for each new event received over replication).
+ #
+ # Note that if the position hasn't advanced then we won't send anything.
+ if any(EventsStream.NAME == s.NAME for s in self.streams):
+ self.clock.looping_call(self.on_notifier_poke, 1000)
+
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -91,7 +110,7 @@ class ReplicationStreamer:
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
- if not self.command_handler.connected():
+ if not self.command_handler.connected() or not self.streams:
# Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall behind forever
for stream in self.streams:
@@ -136,6 +155,8 @@ class ReplicationStreamer:
self._replication_torture_level / 1000.0
)
+ last_token = stream.last_token
+
logger.debug(
"Getting stream: %s: %s -> %s",
stream.NAME,
@@ -159,6 +180,30 @@ class ReplicationStreamer:
)
stream_updates_counter.labels(stream.NAME).inc(len(updates))
+ else:
+ # The token has advanced but there is no data to
+ # send, so we send a `POSITION` to inform other
+ # workers of the updated position.
+ if stream.NAME == EventsStream.NAME:
+ # XXX: We only do this for the EventStream as it
+ # turns out that e.g. account data streams share
+ # their "current token" with each other, meaning
+ # that it is *not* safe to send a POSITION.
+ logger.info(
+ "Sending position: %s -> %s",
+ stream.NAME,
+ current_token,
+ )
+ self.command_handler.send_command(
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ last_token,
+ current_token,
+ )
+ )
+ continue
+
# Some streams return multiple rows with the same stream IDs,
# we need to make sure they get sent out in batches. We do
# this by setting the current token to all but the last of
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 54dccd15a6..61b282ab2d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -240,13 +240,18 @@ class BackfillStream(Stream):
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
- store = hs.get_datastore()
+ self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_current_backfill_token),
- store.get_all_new_backfill_event_rows,
+ self._current_token,
+ self.store.get_all_new_backfill_event_rows,
)
+ def _current_token(self, instance_name: str) -> int:
+ # The backfill stream over replication operates on *positive* numbers,
+ # which means we need to negate it.
+ return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
+
class PresenceStream(Stream):
PresenceStreamRow = namedtuple(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..82e9e0d64e 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -155,7 +155,7 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
event_rows = await self._store.get_all_new_forward_event_rows(
- from_token, current_token, target_row_count
+ instance_name, from_token, current_token, target_row_count
) # type: List[Tuple]
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so
@@ -180,7 +180,7 @@ class EventsStream(Stream):
upper_limit,
state_rows_limited,
) = await self._store.get_all_updated_current_state_deltas(
- from_token, upper_limit, target_row_count
+ instance_name, from_token, upper_limit, target_row_count
)
limited = limited or state_rows_limited
@@ -189,7 +189,7 @@ class EventsStream(Stream):
# not to bother with the limit.
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
- from_token, upper_limit
+ instance_name, from_token, upper_limit
) # type: List[Tuple]
# we now need to turn the raw database rows returned into tuples suitable
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 57cac22252..789431ef25 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -57,6 +57,7 @@ from synapse.rest.admin.users import (
UsersRestServletV2,
WhoisRestServlet,
)
+from synapse.types import RoomStreamToken
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
@@ -109,7 +110,9 @@ class PurgeHistoryRestServlet(RestServlet):
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- room_token = await self.store.get_topological_token_for_event(event_id)
+ room_token = RoomStreamToken(
+ event.depth, event.internal_metadata.stream_ordering
+ )
token = await room_token.to_string(self.store)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 09726d52d6..f5304ff43d 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -138,7 +138,7 @@ class ListRoomRestServlet(RestServlet):
def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -273,7 +273,7 @@ class JoinRoomAliasServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
async def on_POST(self, request, room_identifier):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 20dc1d0e05..8efefbc0a0 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -45,7 +45,7 @@ class UsersRestServlet(RestServlet):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@@ -82,7 +82,7 @@ class UsersRestServletV2(RestServlet):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request):
await assert_requester_is_admin(self.auth, request)
@@ -135,7 +135,7 @@ class UserRestServletV2(RestServlet):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
@@ -448,7 +448,7 @@ class WhoisRestServlet(RestServlet):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
+ self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@@ -461,7 +461,7 @@ class WhoisRestServlet(RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user")
- ret = await self.handlers.admin_handler.get_whois(target_user)
+ ret = await self.admin_handler.get_whois(target_user)
return 200, ret
@@ -591,7 +591,6 @@ class SearchUsersRestServlet(RestServlet):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
async def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to
@@ -612,7 +611,7 @@ class SearchUsersRestServlet(RestServlet):
term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
- ret = await self.handlers.store.search_users(term)
+ ret = await self.store.search_users(term)
return 200, ret
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index faabeeb91c..e5af26b176 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -42,14 +42,13 @@ class ClientDirectoryServer(RestServlet):
def __init__(self, hs):
super().__init__()
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
+ self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
- dir_handler = self.handlers.directory_handler
- res = await dir_handler.get_association(room_alias)
+ res = await self.directory_handler.get_association(room_alias)
return 200, res
@@ -79,19 +78,19 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
- await self.handlers.directory_handler.create_association(
+ await self.directory_handler.create_association(
requester, room_alias, room_id, servers
)
return 200, {}
async def on_DELETE(self, request, room_alias):
- dir_handler = self.handlers.directory_handler
-
try:
service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
- await dir_handler.delete_appservice_association(service, room_alias)
+ await self.directory_handler.delete_appservice_association(
+ service, room_alias
+ )
logger.info(
"Application service at %s deleted alias %s",
service.url,
@@ -107,7 +106,7 @@ class ClientDirectoryServer(RestServlet):
room_alias = RoomAlias.from_string(room_alias)
- await dir_handler.delete_association(requester, room_alias)
+ await self.directory_handler.delete_association(requester, room_alias)
logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
@@ -122,7 +121,7 @@ class ClientDirectoryListServer(RestServlet):
def __init__(self, hs):
super().__init__()
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
+ self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id):
@@ -138,7 +137,7 @@ class ClientDirectoryListServer(RestServlet):
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- await self.handlers.directory_handler.edit_published_room_list(
+ await self.directory_handler.edit_published_room_list(
requester, room_id, visibility
)
@@ -147,7 +146,7 @@ class ClientDirectoryListServer(RestServlet):
async def on_DELETE(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
- await self.handlers.directory_handler.edit_published_room_list(
+ await self.directory_handler.edit_published_room_list(
requester, room_id, "private"
)
@@ -162,7 +161,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
def __init__(self, hs):
super().__init__()
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
+ self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id):
@@ -180,7 +179,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
403, "Only appservices can edit the appservice published room list"
)
- await self.handlers.directory_handler.edit_published_appservice_room_list(
+ await self.directory_handler.edit_published_appservice_room_list(
requester.app_service.id, network_id, room_id, visibility
)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 3d1693d7ac..d7deb9300d 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -67,7 +67,6 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
- self.handlers = hs.get_handlers()
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index b63389e5fe..00b4397082 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -112,7 +112,6 @@ class RoomCreateRestServlet(TransactionRestServlet):
class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super().__init__(hs)
- self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
@@ -798,7 +797,6 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super().__init__(hs)
- self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
@@ -903,7 +901,7 @@ class RoomAliasListServlet(RestServlet):
def __init__(self, hs: "synapse.server.HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- self.directory_handler = hs.get_handlers().directory_handler
+ self.directory_handler = hs.get_directory_handler()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
@@ -920,7 +918,7 @@ class SearchRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
- self.handlers = hs.get_handlers()
+ self.search_handler = hs.get_search_handler()
self.auth = hs.get_auth()
async def on_POST(self, request):
@@ -929,9 +927,7 @@ class SearchRestServlet(RestServlet):
content = parse_json_object_from_request(request)
batch = parse_string(request, "next_batch")
- results = await self.handlers.search_handler.search(
- requester.user, content, batch
- )
+ results = await self.search_handler.search(requester.user, content, batch)
return 200, results
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index ab5815e7f7..e857cff176 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -56,7 +56,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.hs = hs
self.datastore = hs.get_datastore()
self.config = hs.config
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
@@ -327,7 +327,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.config = hs.config
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.store = self.hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
@@ -424,7 +424,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request):
body = parse_json_object_from_request(request)
@@ -574,7 +574,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request):
if not self.config.account_threepid_delegate_msisdn:
@@ -604,7 +604,7 @@ class ThreepidRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
@@ -660,7 +660,7 @@ class ThreepidAddRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -711,7 +711,7 @@ class ThreepidBindRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
async def on_POST(self, request):
@@ -740,7 +740,7 @@ class ThreepidUnbindRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore()
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 7e174de692..af117cb27c 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +22,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from ._base import client_patterns, interactive_auth_handler
@@ -151,7 +153,139 @@ class DeviceRestServlet(RestServlet):
return 200, {}
+class DehydratedDeviceServlet(RestServlet):
+ """Retrieve or store a dehydrated device.
+
+ GET /org.matrix.msc2697.v2/dehydrated_device
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id",
+ "device_data": {
+ "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+ "account": "dehydrated_device"
+ }
+ }
+
+ PUT /org.matrix.msc2697/dehydrated_device
+ Content-Type: application/json
+
+ {
+ "device_data": {
+ "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+ "account": "dehydrated_device"
+ }
+ }
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id"
+ }
+
+ """
+
+ PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
+
+ def __init__(self, hs):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+
+ async def on_GET(self, request: SynapseRequest):
+ requester = await self.auth.get_user_by_req(request)
+ dehydrated_device = await self.device_handler.get_dehydrated_device(
+ requester.user.to_string()
+ )
+ if dehydrated_device is not None:
+ (device_id, device_data) = dehydrated_device
+ result = {"device_id": device_id, "device_data": device_data}
+ return (200, result)
+ else:
+ raise errors.NotFoundError("No dehydrated device available")
+
+ async def on_PUT(self, request: SynapseRequest):
+ submission = parse_json_object_from_request(request)
+ requester = await self.auth.get_user_by_req(request)
+
+ if "device_data" not in submission:
+ raise errors.SynapseError(
+ 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
+ )
+ elif not isinstance(submission["device_data"], dict):
+ raise errors.SynapseError(
+ 400,
+ "device_data must be an object",
+ errcode=errors.Codes.INVALID_PARAM,
+ )
+
+ device_id = await self.device_handler.store_dehydrated_device(
+ requester.user.to_string(),
+ submission["device_data"],
+ submission.get("initial_device_display_name", None),
+ )
+ return 200, {"device_id": device_id}
+
+
+class ClaimDehydratedDeviceServlet(RestServlet):
+ """Claim a dehydrated device.
+
+ POST /org.matrix.msc2697.v2/dehydrated_device/claim
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id"
+ }
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "success": true,
+ }
+
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
+ )
+
+ def __init__(self, hs):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+
+ async def on_POST(self, request: SynapseRequest):
+ requester = await self.auth.get_user_by_req(request)
+
+ submission = parse_json_object_from_request(request)
+
+ if "device_id" not in submission:
+ raise errors.SynapseError(
+ 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
+ )
+ elif not isinstance(submission["device_id"], str):
+ raise errors.SynapseError(
+ 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
+ )
+
+ result = await self.device_handler.rehydrate_device(
+ requester.user.to_string(),
+ self.auth.get_access_token_from_request(request),
+ submission["device_id"],
+ )
+
+ return (200, result)
+
+
def register_servlets(hs, http_server):
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
+ DehydratedDeviceServlet(hs).register(http_server)
+ ClaimDehydratedDeviceServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 55c4606569..b91996c738 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -67,6 +68,7 @@ class KeyUploadServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys")
async def on_POST(self, request, device_id):
@@ -75,23 +77,28 @@ class KeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
if device_id is not None:
- # passing the device_id here is deprecated; however, we allow it
- # for now for compatibility with older clients.
+ # Providing the device_id should only be done for setting keys
+ # for dehydrated devices; however, we allow it for any device for
+ # compatibility with older clients.
if requester.device_id is not None and device_id != requester.device_id:
- set_tag("error", True)
- log_kv(
- {
- "message": "Client uploading keys for a different device",
- "logged_in_id": requester.device_id,
- "key_being_uploaded": device_id,
- }
- )
- logger.warning(
- "Client uploading keys for a different device "
- "(logged in as %s, uploading for %s)",
- requester.device_id,
- device_id,
+ dehydrated_device = await self.device_handler.get_dehydrated_device(
+ user_id
)
+ if dehydrated_device is not None and device_id != dehydrated_device[0]:
+ set_tag("error", True)
+ log_kv(
+ {
+ "message": "Client uploading keys for a different device",
+ "logged_in_id": requester.device_id,
+ "key_being_uploaded": device_id,
+ }
+ )
+ logger.warning(
+ "Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id,
+ device_id,
+ )
else:
device_id = requester.device_id
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ffa2dfce42..395b6a82a9 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -78,7 +78,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
"""
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.config = hs.config
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
@@ -176,7 +176,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
"""
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request):
body = parse_json_object_from_request(request)
@@ -370,7 +370,7 @@ class RegisterRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 6779df952f..2b84eb89c0 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -236,6 +236,7 @@ class SyncRestServlet(RestServlet):
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
+ "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
"next_batch": await sync_result.next_batch.to_string(self.store),
}
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 6568e61829..67aa993f19 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -213,6 +213,12 @@ async def respond_with_responder(
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
"""
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
+
if not responder:
respond_404(request)
return
diff --git a/synapse/server.py b/synapse/server.py
index 5e3752c333..e793793cdc 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -54,19 +54,22 @@ from synapse.federation.sender import FederationSender
from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
-from synapse.handlers import Handlers
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler
+from synapse.handlers.admin import AdminHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.cas_handler import CasHandler
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
+from synapse.handlers.directory import DirectoryHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
+from synapse.handlers.federation import FederationHandler
from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
+from synapse.handlers.identity import IdentityHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
from synapse.handlers.pagination import PaginationHandler
@@ -84,6 +87,7 @@ from synapse.handlers.room import (
from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
+from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
@@ -91,6 +95,7 @@ from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.module_api import ModuleApi
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
@@ -185,7 +190,10 @@ class HomeServer(metaclass=abc.ABCMeta):
we are listening on to provide HTTP services.
"""
- REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
+ REQUIRED_ON_BACKGROUND_TASK_STARTUP = [
+ "auth",
+ "stats",
+ ]
# This is overridden in derived application classes
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
@@ -251,14 +259,20 @@ class HomeServer(metaclass=abc.ABCMeta):
self.datastores = Databases(self.DATASTORE_CLASS, self)
logger.info("Finished setting up.")
- def setup_master(self) -> None:
+ # Register background tasks required by this server. This must be done
+ # somewhat manually due to the background tasks not being registered
+ # unless handlers are instantiated.
+ if self.config.run_background_tasks:
+ self.setup_background_tasks()
+
+ def setup_background_tasks(self) -> None:
"""
Some handlers have side effects on instantiation (like registering
background updates). This function causes them to be fetched, and
therefore instantiated, to run those side effects.
"""
- for i in self.REQUIRED_ON_MASTER_STARTUP:
- getattr(self, "get_" + i)()
+ for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
+ getattr(self, "get_" + i + "_handler")()
def get_reactor(self) -> twisted.internet.base.ReactorBase:
"""
@@ -309,10 +323,6 @@ class HomeServer(metaclass=abc.ABCMeta):
return FederationServer(self)
@cache_in_self
- def get_handlers(self) -> Handlers:
- return Handlers(self)
-
- @cache_in_self
def get_notifier(self) -> Notifier:
return Notifier(self)
@@ -399,6 +409,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return DeviceMessageHandler(self)
@cache_in_self
+ def get_directory_handler(self) -> DirectoryHandler:
+ return DirectoryHandler(self)
+
+ @cache_in_self
def get_e2e_keys_handler(self) -> E2eKeysHandler:
return E2eKeysHandler(self)
@@ -411,6 +425,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return AcmeHandler(self)
@cache_in_self
+ def get_admin_handler(self) -> AdminHandler:
+ return AdminHandler(self)
+
+ @cache_in_self
def get_application_service_api(self) -> ApplicationServiceApi:
return ApplicationServiceApi(self)
@@ -431,6 +449,14 @@ class HomeServer(metaclass=abc.ABCMeta):
return EventStreamHandler(self)
@cache_in_self
+ def get_federation_handler(self) -> FederationHandler:
+ return FederationHandler(self)
+
+ @cache_in_self
+ def get_identity_handler(self) -> IdentityHandler:
+ return IdentityHandler(self)
+
+ @cache_in_self
def get_initial_sync_handler(self) -> InitialSyncHandler:
return InitialSyncHandler(self)
@@ -450,6 +476,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return DeactivateAccountHandler(self)
@cache_in_self
+ def get_search_handler(self) -> SearchHandler:
+ return SearchHandler(self)
+
+ @cache_in_self
def get_set_password_handler(self) -> SetPasswordHandler:
return SetPasswordHandler(self)
@@ -647,6 +677,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_federation_ratelimiter(self) -> FederationRateLimiter:
return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+ @cache_in_self
+ def get_module_api(self) -> ModuleApi:
+ return ModuleApi(self, self.get_auth_handler())
+
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 395ac5ab02..3ce25bb012 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -12,19 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
from enum import Enum
-from twisted.internet import defer
-
-from synapse.storage.state import StateFilter
-
-MYPY = False
-if MYPY:
- import synapse.server
-
-logger = logging.getLogger(__name__)
-
class RegistrationBehaviour(Enum):
"""
@@ -34,35 +23,3 @@ class RegistrationBehaviour(Enum):
ALLOW = "allow"
SHADOW_BAN = "shadow_ban"
DENY = "deny"
-
-
-class SpamCheckerApi:
- """A proxy object that gets passed to spam checkers so they can get
- access to rooms and other relevant information.
- """
-
- def __init__(self, hs: "synapse.server.HomeServer"):
- self.hs = hs
-
- self._store = hs.get_datastore()
-
- @defer.inlineCallbacks
- def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred:
- """Gets state events for the given room.
-
- Args:
- room_id: The room ID to get state events in.
- types: The event type and state key (using None
- to represent 'any') of the room state to acquire.
-
- Returns:
- twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
- The filtered state events in the room.
- """
- state_ids = yield defer.ensureDeferred(
- self._store.get_filtered_current_state_ids(
- room_id=room_id, state_filter=StateFilter.from_types(types)
- )
- )
- state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
- return state.values()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 31082bb16a..5b0900aa3c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -738,7 +738,7 @@ def _make_state_cache_entry(
# failing that, look for the closest match.
prev_group = None
- delta_ids = None
+ delta_ids = None # type: Optional[StateMap[str]]
for old_group, old_state in state_groups_ids.items():
n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6116191b16..0ba3a025cf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
overload,
)
+import attr
from prometheus_client import Histogram
from typing_extensions import Literal
@@ -90,13 +91,17 @@ def make_pool(
return adbapi.ConnectionPool(
db_config.config["name"],
cp_reactor=reactor,
- cp_openfun=engine.on_new_connection,
+ cp_openfun=lambda conn: engine.on_new_connection(
+ LoggingDatabaseConnection(conn, engine, "on_new_connection")
+ ),
**db_config.config.get("args", {})
)
def make_conn(
- db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ db_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
+ default_txn_name: str,
) -> Connection:
"""Make a new connection to the database and return it.
@@ -109,11 +114,60 @@ def make_conn(
for k, v in db_config.config.get("args", {}).items()
if not k.startswith("cp_")
}
- db_conn = engine.module.connect(**db_params)
+ native_db_conn = engine.module.connect(**db_params)
+ db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
engine.on_new_connection(db_conn)
return db_conn
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+ """A wrapper around a database connection that returns `LoggingTransaction`
+ as its cursor class.
+
+ This is mainly used on startup to ensure that queries get logged correctly
+ """
+
+ conn = attr.ib(type=Connection)
+ engine = attr.ib(type=BaseDatabaseEngine)
+ default_txn_name = attr.ib(type=str)
+
+ def cursor(
+ self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+ ) -> "LoggingTransaction":
+ if not txn_name:
+ txn_name = self.default_txn_name
+
+ return LoggingTransaction(
+ self.conn.cursor(),
+ name=txn_name,
+ database_engine=self.engine,
+ after_callbacks=after_callbacks,
+ exception_callbacks=exception_callbacks,
+ )
+
+ def close(self) -> None:
+ self.conn.close()
+
+ def commit(self) -> None:
+ self.conn.commit()
+
+ def rollback(self, *args, **kwargs) -> None:
+ self.conn.rollback(*args, **kwargs)
+
+ def __enter__(self) -> "Connection":
+ self.conn.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback) -> bool:
+ return self.conn.__exit__(exc_type, exc_value, traceback)
+
+ # Proxy through any unknown lookups to the DB conn class.
+ def __getattr__(self, name):
+ return getattr(self.conn, name)
+
+
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +301,12 @@ class LoggingTransaction:
def close(self) -> None:
self.txn.close()
+ def __enter__(self) -> "LoggingTransaction":
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
class PerformanceCounters:
def __init__(self):
@@ -395,7 +455,7 @@ class DatabasePool:
def new_transaction(
self,
- conn: Connection,
+ conn: LoggingDatabaseConnection,
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
@@ -436,12 +496,10 @@ class DatabasePool:
i = 0
N = 5
while True:
- cursor = LoggingTransaction(
- conn.cursor(),
- name,
- self.engine,
- after_callbacks,
- exception_callbacks,
+ cursor = conn.cursor(
+ txn_name=name,
+ after_callbacks=after_callbacks,
+ exception_callbacks=exception_callbacks,
)
try:
r = func(cursor, *args, **kwargs)
@@ -638,7 +696,10 @@ class DatabasePool:
if db_autocommit:
self.engine.attempt_to_set_autocommit(conn, True)
- return func(conn, *args, **kwargs)
+ db_conn = LoggingDatabaseConnection(
+ conn, self.engine, "runWithConnection"
+ )
+ return func(db_conn, *args, **kwargs)
finally:
if db_autocommit:
self.engine.attempt_to_set_autocommit(conn, False)
@@ -1678,7 +1739,7 @@ class DatabasePool:
def get_cache_dict(
self,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
table: str,
entity_column: str,
stream_column: str,
@@ -1699,9 +1760,7 @@ class DatabasePool:
"limit": limit,
}
- sql = self.engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
+ txn = db_conn.cursor(txn_name="get_cache_dict")
txn.execute(sql, (int(max_value),))
cache = {row[0]: int(row[1]) for row in txn}
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index aa5d490624..0c24325011 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -46,7 +46,7 @@ class Databases:
db_name = database_config.name
engine = create_engine(database_config.config)
- with make_conn(database_config, engine) as db_conn:
+ with make_conn(database_config, engine, "startup") as db_conn:
logger.info("[database config %r]: Checking database server", db_name)
engine.check_database(db_conn)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0cb12f4c61..9b16f45f3e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import calendar
import logging
-import time
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
@@ -268,9 +266,6 @@ class DataStore(
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
- # Used in _generate_user_daily_visits to keep track of progress
- self._last_user_visit_update = self._get_start_of_day()
-
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@@ -289,7 +284,6 @@ class DataStore(
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
- sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
@@ -301,192 +295,6 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- async def count_daily_users(self) -> int:
- """
- Counts the number of users who used this homeserver in the last 24 hours.
- """
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return await self.db_pool.runInteraction(
- "count_daily_users", self._count_users, yesterday
- )
-
- async def count_monthly_users(self) -> int:
- """
- Counts the number of users who used this homeserver in the last 30 days.
- Note this method is intended for phonehome metrics only and is different
- from the mau figure in synapse.storage.monthly_active_users which,
- amongst other things, includes a 3 day grace period before a user counts.
- """
- thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return await self.db_pool.runInteraction(
- "count_monthly_users", self._count_users, thirty_days_ago
- )
-
- def _count_users(self, txn, time_from):
- """
- Returns number of users seen in the past time_from period
- """
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
- txn.execute(sql, (time_from,))
- (count,) = txn.fetchone()
- return count
-
- async def count_r30_users(self) -> Dict[str, int]:
- """
- Counts the number of 30 day retained users, defined as:-
- * Users who have created their accounts more than 30 days ago
- * Where last seen at most 30 days ago
- * Where account creation and last_seen are > 30 days apart
-
- Returns:
- A mapping of counts globally as well as broken out by platform.
- """
-
- def _count_r30_users(txn):
- thirty_days_in_secs = 86400 * 30
- now = int(self._clock.time())
- thirty_days_ago_in_secs = now - thirty_days_in_secs
-
- sql = """
- SELECT platform, COALESCE(count(*), 0) FROM (
- SELECT
- users.name, platform, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen,
- CASE
- WHEN user_agent LIKE '%%Android%%' THEN 'android'
- WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
- WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
- WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
- WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
- ELSE 'unknown'
- END
- AS platform
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND users.appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, platform, users.creation_ts
- ) u GROUP BY platform
- """
-
- results = {}
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- for row in txn:
- if row[0] == "unknown":
- pass
- results[row[0]] = row[1]
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT users.name, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, users.creation_ts
- ) u
- """
-
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- (count,) = txn.fetchone()
- results["all"] = count
-
- return results
-
- return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
-
- def _get_start_of_day(self):
- """
- Returns millisecond unixtime for start of UTC day.
- """
- now = time.gmtime()
- today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
- return today_start * 1000
-
- async def generate_user_daily_visits(self) -> None:
- """
- Generates daily visit data for use in cohort/ retention analysis
- """
-
- def _generate_user_daily_visits(txn):
- logger.info("Calling _generate_user_daily_visits")
- today_start = self._get_start_of_day()
- a_day_in_milliseconds = 24 * 60 * 60 * 1000
- now = self.clock.time_msec()
-
- sql = """
- INSERT INTO user_daily_visits (user_id, device_id, timestamp)
- SELECT u.user_id, u.device_id, ?
- FROM user_ips AS u
- LEFT JOIN (
- SELECT user_id, device_id, timestamp FROM user_daily_visits
- WHERE timestamp = ?
- ) udv
- ON u.user_id = udv.user_id AND u.device_id=udv.device_id
- INNER JOIN users ON users.name=u.user_id
- WHERE last_seen > ? AND last_seen <= ?
- AND udv.timestamp IS NULL AND users.is_guest=0
- AND users.appservice_id IS NULL
- GROUP BY u.user_id, u.device_id
- """
-
- # This means that the day has rolled over but there could still
- # be entries from the previous day. There is an edge case
- # where if the user logs in at 23:59 and overwrites their
- # last_seen at 00:01 then they will not be counted in the
- # previous day's stats - it is important that the query is run
- # often to minimise this case.
- if today_start > self._last_user_visit_update:
- yesterday_start = today_start - a_day_in_milliseconds
- txn.execute(
- sql,
- (
- yesterday_start,
- yesterday_start,
- self._last_user_visit_update,
- today_start,
- ),
- )
- self._last_user_visit_update = today_start
-
- txn.execute(
- sql, (today_start, today_start, self._last_user_visit_update, now)
- )
- # Update _last_user_visit_update to now. The reason to do this
- # rather just clamping to the beginning of the day is to limit
- # the size of the join - meaning that the query can be run more
- # frequently
- self._last_user_visit_update = now
-
- await self.db_pool.runInteraction(
- "generate_user_daily_visits", _generate_user_daily_visits
- )
-
async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef81d73573..49ee23470d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -18,6 +18,7 @@ import abc
import logging
from typing import Dict, List, Optional, Tuple
+from synapse.api.constants import AccountDataTypes
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
) -> bool:
ignored_account_data = await self.get_global_account_data_by_type_for_user(
- "m.ignored_user_list",
+ AccountDataTypes.IGNORED_USER_LIST,
ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
return False
- return ignored_user_id in ignored_account_data.get("ignored_users", {})
+ try:
+ return ignored_user_id in ignored_account_data.get("ignored_users", {})
+ except TypeError:
+ # The type of the ignored_users field is invalid.
+ return False
class AccountDataStore(AccountDataWorkerStore):
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..849bd5ba7a 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -17,12 +17,12 @@ import logging
from typing import TYPE_CHECKING
from synapse.events.utils import prune_event_dict
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.frozenutils import frozendict_json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -35,14 +35,13 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- def _censor_redactions():
- return run_as_background_process(
- "_censor_redactions", self._censor_redactions
- )
-
- if self.hs.config.redaction_retention_period is not None:
- hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+ if (
+ hs.config.run_background_tasks
+ and self.hs.config.redaction_retention_period is not None
+ ):
+ hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
+ @wrap_as_background_process("_censor_redactions")
async def _censor_redactions(self):
"""Censors all redactions older than the configured period that haven't
been censored yet.
@@ -105,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
- pruned_json = encode_json(
+ pruned_json = frozendict_json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -171,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
# Prune the event's dict then convert it to JSON.
- pruned_json = encode_json(
+ pruned_json = frozendict_json_encoder.encode(
prune_event_dict(event.room_version, event.get_dict())
)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 239c7a949c..a25a888443 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -351,7 +351,63 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
-class ClientIpStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.user_ips_max_age = hs.config.user_ips_max_age
+
+ if hs.config.run_background_tasks and self.user_ips_max_age:
+ self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+ @wrap_as_background_process("prune_old_user_ips")
+ async def _prune_old_user_ips(self):
+ """Removes entries in user IPs older than the configured period.
+ """
+
+ if self.user_ips_max_age is None:
+ # Nothing to do
+ return
+
+ if not await self.db_pool.updates.has_completed_background_update(
+ "devices_last_seen"
+ ):
+ # Only start pruning if we have finished populating the devices
+ # last seen info.
+ return
+
+ # We do a slightly funky SQL delete to ensure we don't try and delete
+ # too much at once (as the table may be very large from before we
+ # started pruning).
+ #
+ # This works by finding the max last_seen that is less than the given
+ # time, but has no more than N rows before it, deleting all rows with
+ # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+ # returns exactly one row).
+ sql = """
+ DELETE FROM user_ips
+ WHERE last_seen <= (
+ SELECT COALESCE(MAX(last_seen), -1)
+ FROM (
+ SELECT last_seen FROM user_ips
+ WHERE last_seen <= ?
+ ORDER BY last_seen ASC
+ LIMIT 5000
+ ) AS u
+ )
+ """
+
+ timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+ def _prune_old_user_ips_txn(txn):
+ txn.execute(sql, (timestamp,))
+
+ await self.db_pool.runInteraction(
+ "_prune_old_user_ips", _prune_old_user_ips_txn
+ )
+
+
+class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache(
@@ -360,8 +416,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
super().__init__(database, db_conn, hs)
- self.user_ips_max_age = hs.config.user_ips_max_age
-
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
@@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
- if self.user_ips_max_age:
- self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
-
async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
@@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
}
for (access_token, ip), (user_agent, last_seen) in results.items()
]
-
- @wrap_as_background_process("prune_old_user_ips")
- async def _prune_old_user_ips(self):
- """Removes entries in user IPs older than the configured period.
- """
-
- if self.user_ips_max_age is None:
- # Nothing to do
- return
-
- if not await self.db_pool.updates.has_completed_background_update(
- "devices_last_seen"
- ):
- # Only start pruning if we have finished populating the devices
- # last seen info.
- return
-
- # We do a slightly funky SQL delete to ensure we don't try and delete
- # too much at once (as the table may be very large from before we
- # started pruning).
- #
- # This works by finding the max last_seen that is less than the given
- # time, but has no more than N rows before it, deleting all rows with
- # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
- # returns exactly one row).
- sql = """
- DELETE FROM user_ips
- WHERE last_seen <= (
- SELECT COALESCE(MAX(last_seen), -1)
- FROM (
- SELECT last_seen FROM user_ips
- WHERE last_seen <= ?
- ORDER BY last_seen ASC
- LIMIT 5000
- ) AS u
- )
- """
-
- timestamp = self.clock.time_msec() - self.user_ips_max_age
-
- def _prune_old_user_ips_txn(txn):
- txn.execute(sql, (timestamp,))
-
- await self.db_pool.runInteraction(
- "_prune_old_user_ips", _prune_old_user_ips_txn
- )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fdf394c612..88fd97e1df 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ from synapse.logging.opentracing import (
trace,
whitelisted_homeserver,
)
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -33,7 +33,7 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import Cache, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -48,6 +48,14 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(
+ self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+ )
+
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -698,6 +706,172 @@ class DeviceWorkerStore(SQLBaseStore):
_mark_remote_user_device_list_as_unsubscribed_txn,
)
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ # FIXME: make sure device ID still exists in devices table
+ row = await self.db_pool.simple_select_one(
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ retcols=["device_id", "device_data"],
+ allow_none=True,
+ )
+ return (
+ (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
+ )
+
+ def _store_dehydrated_device_txn(
+ self, txn, user_id: str, device_id: str, device_data: str
+ ) -> Optional[str]:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ retcol="device_id",
+ allow_none=True,
+ )
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ values={"device_id": device_id, "device_data": device_data},
+ )
+ return old_device_id
+
+ async def store_dehydrated_device(
+ self, user_id: str, device_id: str, device_data: JsonDict
+ ) -> Optional[str]:
+ """Store a dehydrated device for a user.
+
+ Args:
+ user_id: the user that we are storing the device for
+ device_id: the ID of the dehydrated device
+ device_data: the dehydrated device information
+ Returns:
+ device id of the user's previous dehydrated device, if any
+ """
+ return await self.db_pool.runInteraction(
+ "store_dehydrated_device_txn",
+ self._store_dehydrated_device_txn,
+ user_id,
+ device_id,
+ json_encoder.encode(device_data),
+ )
+
+ async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
+ """Remove a dehydrated device.
+
+ Args:
+ user_id: the user that the dehydrated device belongs to
+ device_id: the ID of the dehydrated device
+ """
+ count = await self.db_pool.simple_delete(
+ "dehydrated_devices",
+ {"user_id": user_id, "device_id": device_id},
+ desc="remove_dehydrated_device",
+ )
+ return count >= 1
+
+ @wrap_as_background_process("prune_old_outbound_device_pokes")
+ async def _prune_old_outbound_device_pokes(
+ self, prune_age: int = 24 * 60 * 60 * 1000
+ ) -> None:
+ """Delete old entries out of the device_lists_outbound_pokes to ensure
+ that we don't fill up due to dead servers.
+
+ Normally, we try to send device updates as a delta since a previous known point:
+ this is done by setting the prev_id in the m.device_list_update EDU. However,
+ for that to work, we have to have a complete record of each change to
+ each device, which can add up to quite a lot of data.
+
+ An alternative mechanism is that, if the remote server sees that it has missed
+ an entry in the stream_id sequence for a given user, it will request a full
+ list of that user's devices. Hence, we can reduce the amount of data we have to
+ store (and transmit in some future transaction), by clearing almost everything
+ for a given destination out of the database, and having the remote server
+ resync.
+
+ All we need to do is make sure we keep at least one row for each
+ (user, destination) pair, to remind us to send a m.device_list_update EDU for
+ that user when the destination comes back. It doesn't matter which device
+ we keep.
+ """
+ yesterday = self._clock.time_msec() - prune_age
+
+ def _prune_txn(txn):
+ # look for (user, destination) pairs which have an update older than
+ # the cutoff.
+ #
+ # For each pair, we also need to know the most recent stream_id, and
+ # an arbitrary device_id at that stream_id.
+ select_sql = """
+ SELECT
+ dlop1.destination,
+ dlop1.user_id,
+ MAX(dlop1.stream_id) AS stream_id,
+ (SELECT MIN(dlop2.device_id) AS device_id FROM
+ device_lists_outbound_pokes dlop2
+ WHERE dlop2.destination = dlop1.destination AND
+ dlop2.user_id=dlop1.user_id AND
+ dlop2.stream_id=MAX(dlop1.stream_id)
+ )
+ FROM device_lists_outbound_pokes dlop1
+ GROUP BY destination, user_id
+ HAVING min(ts) < ? AND count(*) > 1
+ """
+
+ txn.execute(select_sql, (yesterday,))
+ rows = txn.fetchall()
+
+ if not rows:
+ return
+
+ logger.info(
+ "Pruning old outbound device list updates for %i users/destinations: %s",
+ len(rows),
+ shortstr((row[0], row[1]) for row in rows),
+ )
+
+ # we want to keep the update with the highest stream_id for each user.
+ #
+ # there might be more than one update (with different device_ids) with the
+ # same stream_id, so we also delete all but one rows with the max stream id.
+ delete_sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND user_id = ? AND (
+ stream_id < ? OR
+ (stream_id = ? AND device_id != ?)
+ )
+ """
+ count = 0
+ for (destination, user_id, stream_id, device_id) in rows:
+ txn.execute(
+ delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+ )
+ count += txn.rowcount
+
+ # Since we've deleted unsent deltas, we need to remove the entry
+ # of last successful sent so that the prev_ids are correctly set.
+ sql = """
+ DELETE FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
+ logger.info("Pruned %d device list outbound pokes", count)
+
+ await self.db_pool.runInteraction(
+ "_prune_old_outbound_device_pokes", _prune_txn,
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -834,10 +1008,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
name="device_id_exists", keylen=2, max_entries=10000
)
- self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
-
async def store_device(
- self, user_id: str, device_id: str, initial_device_display_name: str
+ self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -955,7 +1127,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def update_remote_device_list_cache_entry(
- self, user_id: str, device_id: str, content: JsonDict, stream_id: int
+ self, user_id: str, device_id: str, content: JsonDict, stream_id: str
) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
@@ -983,7 +1155,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
content: JsonDict,
- stream_id: int,
+ stream_id: str,
) -> None:
if content.get("deleted"):
self.db_pool.simple_delete_txn(
@@ -1193,95 +1365,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids
],
)
-
- def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
- """Delete old entries out of the device_lists_outbound_pokes to ensure
- that we don't fill up due to dead servers.
-
- Normally, we try to send device updates as a delta since a previous known point:
- this is done by setting the prev_id in the m.device_list_update EDU. However,
- for that to work, we have to have a complete record of each change to
- each device, which can add up to quite a lot of data.
-
- An alternative mechanism is that, if the remote server sees that it has missed
- an entry in the stream_id sequence for a given user, it will request a full
- list of that user's devices. Hence, we can reduce the amount of data we have to
- store (and transmit in some future transaction), by clearing almost everything
- for a given destination out of the database, and having the remote server
- resync.
-
- All we need to do is make sure we keep at least one row for each
- (user, destination) pair, to remind us to send a m.device_list_update EDU for
- that user when the destination comes back. It doesn't matter which device
- we keep.
- """
- yesterday = self._clock.time_msec() - prune_age
-
- def _prune_txn(txn):
- # look for (user, destination) pairs which have an update older than
- # the cutoff.
- #
- # For each pair, we also need to know the most recent stream_id, and
- # an arbitrary device_id at that stream_id.
- select_sql = """
- SELECT
- dlop1.destination,
- dlop1.user_id,
- MAX(dlop1.stream_id) AS stream_id,
- (SELECT MIN(dlop2.device_id) AS device_id FROM
- device_lists_outbound_pokes dlop2
- WHERE dlop2.destination = dlop1.destination AND
- dlop2.user_id=dlop1.user_id AND
- dlop2.stream_id=MAX(dlop1.stream_id)
- )
- FROM device_lists_outbound_pokes dlop1
- GROUP BY destination, user_id
- HAVING min(ts) < ? AND count(*) > 1
- """
-
- txn.execute(select_sql, (yesterday,))
- rows = txn.fetchall()
-
- if not rows:
- return
-
- logger.info(
- "Pruning old outbound device list updates for %i users/destinations: %s",
- len(rows),
- shortstr((row[0], row[1]) for row in rows),
- )
-
- # we want to keep the update with the highest stream_id for each user.
- #
- # there might be more than one update (with different device_ids) with the
- # same stream_id, so we also delete all but one rows with the max stream id.
- delete_sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND user_id = ? AND (
- stream_id < ? OR
- (stream_id = ? AND device_id != ?)
- )
- """
- count = 0
- for (destination, user_id, stream_id, device_id) in rows:
- txn.execute(
- delete_sql, (destination, user_id, stream_id, stream_id, device_id)
- )
- count += txn.rowcount
-
- # Since we've deleted unsent deltas, we need to remove the entry
- # of last successful sent so that the prev_ids are correctly set.
- sql = """
- DELETE FROM device_lists_outbound_last_success
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(sql, ((row[0], row[1]) for row in rows))
-
- logger.info("Pruned %d device list outbound pokes", count)
-
- return run_as_background_process(
- "prune_old_outbound_device_pokes",
- self.db_pool.runInteraction,
- "_prune_old_outbound_device_pokes",
- _prune_txn,
- )
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 22e1ed15d0..4415909414 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -367,6 +367,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
+ async def set_e2e_fallback_keys(
+ self, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
+ # fallback_keys will usually only have one item in it, so using a for
+ # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+ # FIXME: make sure that only one key per algorithm is uploaded
+ for key_id, fallback_key in fallback_keys.items():
+ algorithm, key_id = key_id.split(":", 1)
+ await self.db_pool.simple_upsert(
+ "e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ desc="set_e2e_fallback_key",
+ )
+
+ await self.invalidate_cache_and_stream(
+ "get_e2e_unused_fallback_key_types", (user_id, device_id)
+ )
+
+ @cached(max_entries=10000)
+ async def get_e2e_unused_fallback_key_types(
+ self, user_id: str, device_id: str
+ ) -> List[str]:
+ """Returns the fallback key types that have an unused key.
+
+ Args:
+ user_id: the user whose keys are being queried
+ device_id: the device whose keys are being queried
+
+ Returns:
+ a list of key types
+ """
+ return await self.db_pool.simple_select_onecol(
+ "e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+ retcol="algorithm",
+ desc="get_e2e_unused_fallback_key_types",
+ )
+
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]:
@@ -701,15 +756,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
+ fallback_sql = (
+ "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
result = {}
delete = []
+ used_fallbacks = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn:
+ otk_row = txn.fetchone()
+ if otk_row is not None:
+ key_id, key_json = otk_row
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
+ else:
+ # no one-time key available, so see if there's a fallback
+ # key
+ txn.execute(fallback_sql, (user_id, device_id, algorithm))
+ fallback_row = txn.fetchone()
+ if fallback_row is not None:
+ key_id, key_json, used = fallback_row
+ device_result[algorithm + ":" + key_id] = key_json
+ if not used:
+ used_fallbacks.append(
+ (user_id, device_id, algorithm, key_id)
+ )
+
+ # drop any one-time keys that were claimed
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +803,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ # mark fallback keys as used
+ for user_id, device_id, algorithm, key_id in used_fallbacks:
+ self.db_pool.simple_update_txn(
+ txn,
+ "e2e_fallback_keys_json",
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ },
+ {"used": True},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
return result
return await self.db_pool.runInteraction(
@@ -754,6 +848,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d3689c09e..a6279a6c13 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -19,7 +19,7 @@ from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
from synapse.events import EventBase
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
@@ -32,6 +32,14 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ hs.get_clock().looping_call(
+ self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+ )
+
async def get_auth_chain(
self, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
@@ -586,6 +594,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return [row["event_id"] for row in rows]
+ @wrap_as_background_process("delete_old_forward_extrem_cache")
+ async def _delete_old_forward_extrem_cache(self) -> None:
+ def _delete_old_forward_extrem_cache_txn(txn):
+ # Delete entries older than a month, while making sure we don't delete
+ # the only entries for a room.
+ sql = """
+ DELETE FROM stream_ordering_to_exterm
+ WHERE
+ room_id IN (
+ SELECT room_id
+ FROM stream_ordering_to_exterm
+ WHERE stream_ordering > ?
+ ) AND stream_ordering < ?
+ """
+ txn.execute(
+ sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+ )
+
+ await self.db_pool.runInteraction(
+ "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+ )
+
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated
@@ -606,34 +636,6 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
- hs.get_clock().looping_call(
- self._delete_old_forward_extrem_cache, 60 * 60 * 1000
- )
-
- def _delete_old_forward_extrem_cache(self):
- def _delete_old_forward_extrem_cache_txn(txn):
- # Delete entries older than a month, while making sure we don't delete
- # the only entries for a room.
- sql = """
- DELETE FROM stream_ordering_to_exterm
- WHERE
- room_id IN (
- SELECT room_id
- FROM stream_ordering_to_exterm
- WHERE stream_ordering > ?
- ) AND stream_ordering < ?
- """
- txn.execute(
- sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
- )
-
- return run_as_background_process(
- "delete_old_forward_extrem_cache",
- self.db_pool.runInteraction,
- "_delete_old_forward_extrem_cache",
- _delete_old_forward_extrem_cache_txn,
- )
-
async def clean_room_for_join(self, room_id):
return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 62f1738732..2e56dfaf31 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,15 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Dict, List, Optional, Tuple, Union
import attr
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -74,19 +73,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self.stream_ordering_month_ago = None
self.stream_ordering_day_ago = None
- cur = LoggingTransaction(
- db_conn.cursor(),
- name="_find_stream_orderings_for_times_txn",
- database_engine=self.database_engine,
- )
+ cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
self._find_stream_orderings_for_times_txn(cur)
cur.close()
self.find_stream_orderings_looping_call = self._clock.looping_call(
self._find_stream_orderings_for_times, 10 * 60 * 1000
)
+
self._rotate_delay = 3
self._rotate_count = 10000
+ self._doing_notif_rotation = False
+ if hs.config.run_background_tasks:
+ self._rotate_notif_loop = self._clock.looping_call(
+ self._rotate_notifs, 30 * 60 * 1000
+ )
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
@@ -518,15 +519,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Error removing push actions after event persistence failure"
)
- def _find_stream_orderings_for_times(self):
- return run_as_background_process(
- "event_push_action_stream_orderings",
- self.db_pool.runInteraction,
+ @wrap_as_background_process("event_push_action_stream_orderings")
+ async def _find_stream_orderings_for_times(self) -> None:
+ await self.db_pool.runInteraction(
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
- def _find_stream_orderings_for_times_txn(self, txn):
+ def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None:
logger.info("Searching for stream ordering 1 month ago")
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
@@ -656,129 +656,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
return result[0] if result else None
-
-class EventPushActionsStore(EventPushActionsWorkerStore):
- EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self.db_pool.updates.register_background_index_update(
- self.EPA_HIGHLIGHT_INDEX,
- index_name="event_push_actions_u_highlight",
- table="event_push_actions",
- columns=["user_id", "stream_ordering"],
- )
-
- self.db_pool.updates.register_background_index_update(
- "event_push_actions_highlights_index",
- index_name="event_push_actions_highlights_index",
- table="event_push_actions",
- columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
- where_clause="highlight=1",
- )
-
- self._doing_notif_rotation = False
- self._rotate_notif_loop = self._clock.looping_call(
- self._start_rotate_notifs, 30 * 60 * 1000
- )
-
- async def get_push_actions_for_user(
- self, user_id, before=None, limit=50, only_highlight=False
- ):
- def f(txn):
- before_clause = ""
- if before:
- before_clause = "AND epa.stream_ordering < ?"
- args = [user_id, before, limit]
- else:
- args = [user_id, limit]
-
- if only_highlight:
- if len(before_clause) > 0:
- before_clause += " "
- before_clause += "AND epa.highlight = 1"
-
- # NB. This assumes event_ids are globally unique since
- # it makes the query easier to index
- sql = (
- "SELECT epa.event_id, epa.room_id,"
- " epa.stream_ordering, epa.topological_ordering,"
- " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
- " FROM event_push_actions epa, events e"
- " WHERE epa.event_id = e.event_id"
- " AND epa.user_id = ? %s"
- " AND epa.notif = 1"
- " ORDER BY epa.stream_ordering DESC"
- " LIMIT ?" % (before_clause,)
- )
- txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
-
- push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
- for pa in push_actions:
- pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- return push_actions
-
- async def get_latest_push_action_stream_ordering(self):
- def f(txn):
- txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
- return txn.fetchone()
-
- result = await self.db_pool.runInteraction(
- "get_latest_push_action_stream_ordering", f
- )
- return result[0] or 0
-
- def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
- """
- Purges old push actions for a user and room before a given
- stream_ordering.
-
- We however keep a months worth of highlighted notifications, so that
- users can still get a list of recent highlights.
-
- Args:
- txn: The transcation
- room_id: Room ID to delete from
- user_id: user ID to delete for
- stream_ordering: The lowest stream ordering which will
- not be deleted.
- """
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id, user_id),
- )
-
- # We need to join on the events table to get the received_ts for
- # event_push_actions and sqlite won't let us use a join in a delete so
- # we can't just delete where received_ts < x. Furthermore we can
- # only identify event_push_actions by a tuple of room_id, event_id
- # we we can't use a subquery.
- # Instead, we look up the stream ordering for the last event in that
- # room received before the threshold time and delete event_push_actions
- # in the room with a stream_odering before that.
- txn.execute(
- "DELETE FROM event_push_actions "
- " WHERE user_id = ? AND room_id = ? AND "
- " stream_ordering <= ?"
- " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
- )
-
- txn.execute(
- """
- DELETE FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
- """,
- (room_id, user_id, stream_ordering),
- )
-
- def _start_rotate_notifs(self):
- return run_as_background_process("rotate_notifs", self._rotate_notifs)
-
+ @wrap_as_background_process("rotate_notifs")
async def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
@@ -958,6 +836,121 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
+class EventPushActionsStore(EventPushActionsWorkerStore):
+ EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ self.EPA_HIGHLIGHT_INDEX,
+ index_name="event_push_actions_u_highlight",
+ table="event_push_actions",
+ columns=["user_id", "stream_ordering"],
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ "event_push_actions_highlights_index",
+ index_name="event_push_actions_highlights_index",
+ table="event_push_actions",
+ columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+ where_clause="highlight=1",
+ )
+
+ async def get_push_actions_for_user(
+ self, user_id, before=None, limit=50, only_highlight=False
+ ):
+ def f(txn):
+ before_clause = ""
+ if before:
+ before_clause = "AND epa.stream_ordering < ?"
+ args = [user_id, before, limit]
+ else:
+ args = [user_id, limit]
+
+ if only_highlight:
+ if len(before_clause) > 0:
+ before_clause += " "
+ before_clause += "AND epa.highlight = 1"
+
+ # NB. This assumes event_ids are globally unique since
+ # it makes the query easier to index
+ sql = (
+ "SELECT epa.event_id, epa.room_id,"
+ " epa.stream_ordering, epa.topological_ordering,"
+ " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
+ " FROM event_push_actions epa, events e"
+ " WHERE epa.event_id = e.event_id"
+ " AND epa.user_id = ? %s"
+ " AND epa.notif = 1"
+ " ORDER BY epa.stream_ordering DESC"
+ " LIMIT ?" % (before_clause,)
+ )
+ txn.execute(sql, args)
+ return self.db_pool.cursor_to_dict(txn)
+
+ push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
+ for pa in push_actions:
+ pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
+ return push_actions
+
+ async def get_latest_push_action_stream_ordering(self):
+ def f(txn):
+ txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
+ return txn.fetchone()
+
+ result = await self.db_pool.runInteraction(
+ "get_latest_push_action_stream_ordering", f
+ )
+ return result[0] or 0
+
+ def _remove_old_push_actions_before_txn(
+ self, txn, room_id, user_id, stream_ordering
+ ):
+ """
+ Purges old push actions for a user and room before a given
+ stream_ordering.
+
+ We however keep a months worth of highlighted notifications, so that
+ users can still get a list of recent highlights.
+
+ Args:
+ txn: The transcation
+ room_id: Room ID to delete from
+ user_id: user ID to delete for
+ stream_ordering: The lowest stream ordering which will
+ not be deleted.
+ """
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id, user_id),
+ )
+
+ # We need to join on the events table to get the received_ts for
+ # event_push_actions and sqlite won't let us use a join in a delete so
+ # we can't just delete where received_ts < x. Furthermore we can
+ # only identify event_push_actions by a tuple of room_id, event_id
+ # we we can't use a subquery.
+ # Instead, we look up the stream ordering for the last event in that
+ # room received before the threshold time and delete event_push_actions
+ # in the room with a stream_odering before that.
+ txn.execute(
+ "DELETE FROM event_push_actions "
+ " WHERE user_id = ? AND room_id = ? AND "
+ " stream_ordering <= ?"
+ " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+ )
+
+ txn.execute(
+ """
+ DELETE FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+ """,
+ (room_id, user_id, stream_ordering),
+ )
+
+
def _action_has_highlight(actions):
for action in actions:
try:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..b19c424ba9 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -52,16 +52,6 @@ event_counter = Counter(
)
-def encode_json(json_object):
- """
- Encode a Python object as JSON and return it in a Unicode string.
- """
- out = frozendict_json_encoder.encode(json_object)
- if isinstance(out, bytes):
- out = out.decode("utf8")
- return out
-
-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -341,6 +331,10 @@ class PersistEventsStore:
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+ # stream orderings should have been assigned by now
+ assert min_stream_order
+ assert max_stream_order
+
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -432,12 +426,12 @@ class PersistEventsStore:
# so that async background tasks get told what happened.
sql = """
INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, room_id, type, state_key, null, event_id
+ (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, room_id, type, state_key, null, event_id
FROM current_state_events
WHERE room_id = ?
"""
- txn.execute(sql, (stream_id, room_id))
+ txn.execute(sql, (stream_id, self._instance_name, room_id))
self.db_pool.simple_delete_txn(
txn, table="current_state_events", keyvalues={"room_id": room_id},
@@ -458,8 +452,8 @@ class PersistEventsStore:
#
sql = """
INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, ?, ?, ?, ?, (
+ (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, ?, (
SELECT event_id FROM current_state_events
WHERE room_id = ? AND type = ? AND state_key = ?
)
@@ -469,6 +463,7 @@ class PersistEventsStore:
(
(
stream_id,
+ self._instance_name,
room_id,
etype,
state_key,
@@ -743,7 +738,9 @@ class PersistEventsStore:
logger.exception("")
raise
- metadata_json = encode_json(event.internal_metadata.get_dict())
+ metadata_json = frozendict_json_encoder.encode(
+ event.internal_metadata.get_dict()
+ )
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -759,6 +756,7 @@ class PersistEventsStore:
"event_stream_ordering": stream_order,
"event_id": event.event_id,
"state_group": state_group_id,
+ "instance_name": self._instance_name,
},
)
@@ -797,10 +795,10 @@ class PersistEventsStore:
{
"event_id": event.event_id,
"room_id": event.room_id,
- "internal_metadata": encode_json(
+ "internal_metadata": frozendict_json_encoder.encode(
event.internal_metadata.get_dict()
),
- "json": encode_json(event_dict(event)),
+ "json": frozendict_json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
}
for event, _ in events_and_contexts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f95679ebc4..4e74fafe43 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -74,6 +74,13 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore):
+ # Whether to use dedicated DB threads for event fetching. This is only used
+ # if there are multiple DB threads available. When used will lock the DB
+ # thread for periods of time (so unit tests want to disable this when they
+ # run DB transactions on the main thread). See EVENT_QUEUE_* for more
+ # options controlling this.
+ USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
+
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -522,7 +529,11 @@ class EventsWorkerStore(SQLBaseStore):
if not event_list:
single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
self._event_fetch_ongoing -= 1
return
else:
@@ -712,6 +723,7 @@ class EventsWorkerStore(SQLBaseStore):
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
+ original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
event_map[event_id] = original_ev
@@ -779,6 +791,8 @@ class EventsWorkerStore(SQLBaseStore):
* event_id (str)
+ * stream_ordering (int): stream ordering for this event
+
* json (str): json-encoded event structure
* internal_metadata (str): json-encoded internal metadata dict
@@ -811,13 +825,15 @@ class EventsWorkerStore(SQLBaseStore):
sql = """\
SELECT
e.event_id,
- e.internal_metadata,
- e.json,
- e.format_version,
+ e.stream_ordering,
+ ej.internal_metadata,
+ ej.json,
+ ej.format_version,
r.room_version,
rej.reason
- FROM event_json as e
- LEFT JOIN rooms r USING (room_id)
+ FROM events AS e
+ JOIN event_json AS ej USING (event_id)
+ LEFT JOIN rooms r ON r.room_id = e.room_id
LEFT JOIN rejections as rej USING (event_id)
WHERE """
@@ -831,11 +847,12 @@ class EventsWorkerStore(SQLBaseStore):
event_id = row[0]
event_dict[event_id] = {
"event_id": event_id,
- "internal_metadata": row[1],
- "json": row[2],
- "format_version": row[3],
- "room_version_id": row[4],
- "rejected_reason": row[5],
+ "stream_ordering": row[1],
+ "internal_metadata": row[2],
+ "json": row[3],
+ "format_version": row[4],
+ "room_version_id": row[5],
+ "rejected_reason": row[6],
"redactions": [],
}
@@ -1017,16 +1034,12 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_backfill_token(self):
- """The current minimum token that backfilled events have reached"""
- return -self._backfill_id_gen.get_current_token()
-
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
- self, last_id: int, current_id: int, limit: int
+ self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
"""Returns new events, for the Events replication stream
@@ -1050,10 +1063,11 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " AND instance_name = ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
return txn.fetchall()
return await self.db_pool.runInteraction(
@@ -1061,7 +1075,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_ex_outlier_stream_rows(
- self, last_id: int, current_id: int
+ self, instance_name: str, last_id: int, current_id: int
) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
@@ -1080,16 +1094,17 @@ class EventsWorkerStore(SQLBaseStore):
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
+ " INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
+ " AND out.instance_name = ?"
" ORDER BY event_stream_ordering ASC"
)
- txn.execute(sql, (last_id, current_id))
+ txn.execute(sql, (last_id, current_id, instance_name))
return txn.fetchall()
return await self.db_pool.runInteraction(
@@ -1102,6 +1117,9 @@ class EventsWorkerStore(SQLBaseStore):
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
+ NOTE: The IDs given here are from replication, and so should be
+ *positive*.
+
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
@@ -1132,10 +1150,11 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " AND instance_name = ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
- txn.execute(sql, (-last_id, -current_id, limit))
+ txn.execute(sql, (-last_id, -current_id, instance_name, limit))
new_event_updates = [(row[0], row[1:]) for row in txn]
limited = False
@@ -1149,15 +1168,16 @@ class EventsWorkerStore(SQLBaseStore):
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
+ " INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
+ " AND out.instance_name = ?"
" ORDER BY event_stream_ordering DESC"
)
- txn.execute(sql, (-last_id, -upper_bound))
+ txn.execute(sql, (-last_id, -upper_bound, instance_name))
new_event_updates.extend((row[0], row[1:]) for row in txn)
if len(new_event_updates) >= limit:
@@ -1171,7 +1191,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_all_updated_current_state_deltas(
- self, from_token: int, to_token: int, target_row_count: int
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
"""Fetch updates from current_state_delta_stream
@@ -1197,9 +1217,10 @@ class EventsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, target_row_count))
+ txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
return txn.fetchall()
def get_deltas_for_stream_id_txn(txn, stream_id):
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 92099f95ce..0acf0617ca 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -12,15 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import calendar
+import logging
+import time
+from typing import Dict
from synapse.metrics import GaugeBucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
+logger = logging.getLogger(__name__)
+
# Collect metrics on the number of forward extremities that exist.
_extremities_collecter = GaugeBucketCollector(
"synapse_forward_extremities",
@@ -51,15 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes
- def read_forward_extremities():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "read_forward_extremities", self._read_forward_extremities
- )
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
- hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+ @wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self):
def fetch(txn):
txn.execute(
@@ -137,3 +141,190 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
+
+ async def count_daily_users(self) -> int:
+ """
+ Counts the number of users who used this homeserver in the last 24 hours.
+ """
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+ return await self.db_pool.runInteraction(
+ "count_daily_users", self._count_users, yesterday
+ )
+
+ async def count_monthly_users(self) -> int:
+ """
+ Counts the number of users who used this homeserver in the last 30 days.
+ Note this method is intended for phonehome metrics only and is different
+ from the mau figure in synapse.storage.monthly_active_users which,
+ amongst other things, includes a 3 day grace period before a user counts.
+ """
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ return await self.db_pool.runInteraction(
+ "count_monthly_users", self._count_users, thirty_days_ago
+ )
+
+ def _count_users(self, txn, time_from):
+ """
+ Returns number of users seen in the past time_from period
+ """
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
+ txn.execute(sql, (time_from,))
+ (count,) = txn.fetchone()
+ return count
+
+ async def count_r30_users(self) -> Dict[str, int]:
+ """
+ Counts the number of 30 day retained users, defined as:-
+ * Users who have created their accounts more than 30 days ago
+ * Where last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns:
+ A mapping of counts globally as well as broken out by platform.
+ """
+
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] == "unknown":
+ pass
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ (count,) = txn.fetchone()
+ results["all"] = count
+
+ return results
+
+ return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = time.gmtime()
+ today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+ return today_start * 1000
+
+ @wrap_as_background_process("generate_user_daily_visits")
+ async def generate_user_daily_visits(self) -> None:
+ """
+ Generates daily visit data for use in cohort/ retention analysis
+ """
+
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self._clock.time_msec()
+
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+ SELECT u.user_id, u.device_id, ?
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(
+ sql,
+ (
+ yesterday_start,
+ yesterday_start,
+ self._last_user_visit_update,
+ today_start,
+ ),
+ )
+ self._last_user_visit_update = today_start
+
+ txn.execute(
+ sql, (today_start, today_start, self._last_user_visit_update, now)
+ )
+ # Update _last_user_visit_update to now. The reason to do this
+ # rather just clamping to the beginning of the day is to limit
+ # the size of the join - meaning that the query can be run more
+ # frequently
+ self._last_user_visit_update = now
+
+ await self.db_pool.runInteraction(
+ "generate_user_daily_visits", _generate_user_daily_visits
+ )
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e93aad33cd..d788dc0fc6 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,6 +15,7 @@
import logging
from typing import Dict, List
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
@@ -32,6 +33,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
+ self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+ self._max_mau_value = hs.config.max_mau_value
+
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@@ -124,60 +128,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
desc="user_last_seen_monthly_active",
)
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._mau_stats_only = hs.config.mau_stats_only
- self._max_mau_value = hs.config.max_mau_value
-
- # Do not add more reserved users than the total allowable number
- # cur = LoggingTransaction(
- self.db_pool.new_transaction(
- db_conn,
- "initialise_mau_threepids",
- [],
- [],
- self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
- )
-
- def _initialise_reserved_users(self, txn, threepids):
- """Ensures that reserved threepids are accounted for in the MAU table, should
- be called on start up.
-
- Args:
- txn (cursor):
- threepids (list[dict]): List of threepid dicts to reserve
- """
-
- # XXX what is this function trying to achieve? It upserts into
- # monthly_active_users for each *registered* reserved mau user, but why?
- #
- # - shouldn't there already be an entry for each reserved user (at least
- # if they have been active recently)?
- #
- # - if it's important that the timestamp is kept up to date, why do we only
- # run this at startup?
-
- for tp in threepids:
- user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
-
- if user_id:
- is_support = self.is_support_user_txn(txn, user_id)
- if not is_support:
- # We do this manually here to avoid hitting #6791
- self.db_pool.simple_upsert_txn(
- txn,
- table="monthly_active_users",
- keyvalues={"user_id": user_id},
- values={"timestamp": int(self._clock.time_msec())},
- )
- else:
- logger.warning("mau limit reserved threepid %s not found in db" % tp)
-
+ @wrap_as_background_process("reap_monthly_active_users")
async def reap_monthly_active_users(self):
"""Cleans out monthly active user table to ensure that no stale
entries exist.
@@ -257,6 +208,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._mau_stats_only = hs.config.mau_stats_only
+
+ # Do not add more reserved users than the total allowable number
+ self.db_pool.new_transaction(
+ db_conn,
+ "initialise_mau_threepids",
+ [],
+ [],
+ self._initialise_reserved_users,
+ hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+ )
+
+ def _initialise_reserved_users(self, txn, threepids):
+ """Ensures that reserved threepids are accounted for in the MAU table, should
+ be called on start up.
+
+ Args:
+ txn (cursor):
+ threepids (list[dict]): List of threepid dicts to reserve
+ """
+
+ # XXX what is this function trying to achieve? It upserts into
+ # monthly_active_users for each *registered* reserved mau user, but why?
+ #
+ # - shouldn't there already be an entry for each reserved user (at least
+ # if they have been active recently)?
+ #
+ # - if it's important that the timestamp is kept up to date, why do we only
+ # run this at startup?
+
+ for tp in threepids:
+ user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
+
+ if user_id:
+ is_support = self.is_support_user_txn(txn, user_id)
+ if not is_support:
+ # We do this manually here to avoid hitting #6791
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ values={"timestamp": int(self._clock.time_msec())},
+ )
+ else:
+ logger.warning("mau limit reserved threepid %s not found in db" % tp)
+
async def upsert_monthly_active_user(self, user_id: str) -> None:
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..236d3cdbe3 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +14,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
@@ -48,6 +47,18 @@ class RegistrationWorkerStore(SQLBaseStore):
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
+ self._account_validity = hs.config.account_validity
+ if hs.config.run_background_tasks and self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0, self._set_expiration_date_when_missing,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ if hs.config.run_background_tasks:
+ self.clock.looping_call(
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+ )
+
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
@@ -778,6 +789,79 @@ class RegistrationWorkerStore(SQLBaseStore):
"delete_threepid_session", delete_threepid_session_txn
)
+ @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+ async def cull_expired_threepid_validation_tokens(self) -> None:
+ """Remove threepid validation tokens with expiry dates that have passed"""
+
+ def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ sql = """
+ DELETE FROM threepid_validation_token WHERE
+ expires < ?
+ """
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "cull_expired_threepid_validation_tokens",
+ cull_expired_threepid_validation_tokens_txn,
+ self.clock.time_msec(),
+ )
+
+ @wrap_as_background_process("account_validity_set_expiration_dates")
+ async def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.db_pool.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ await self.db_pool.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -911,28 +995,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
-
- # Create a background job for culling expired 3PID validity tokens
- def start_cull():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "cull_expired_threepid_validation_tokens",
- self.cull_expired_threepid_validation_tokens,
- )
-
- hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
async def add_access_token_to_user(
self,
user_id: str,
@@ -964,6 +1028,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
+ def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn, "access_tokens", {"token": token}, "device_id"
+ )
+
+ self.db_pool.simple_update_txn(
+ txn, "access_tokens", {"token": token}, {"device_id": device_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+ return old_device_id
+
+ async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+ """Sets the device ID associated with an access token.
+
+ Args:
+ token: The access token to modify.
+ device_id: The new device ID.
+ Returns:
+ The old device ID associated with the access token.
+ """
+
+ return await self.db_pool.runInteraction(
+ "set_device_for_access_token",
+ self._set_device_for_access_token_txn,
+ token,
+ device_id,
+ )
+
async def register_user(
self,
user_id: str,
@@ -1121,7 +1215,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
+ async def user_set_password_hash(
+ self, user_id: str, password_hash: Optional[str]
+ ) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1447,22 +1543,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
- async def cull_expired_threepid_validation_tokens(self) -> None:
- """Remove threepid validation tokens with expiry dates that have passed"""
-
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
- sql = """
- DELETE FROM threepid_validation_token WHERE
- expires < ?
- """
- txn.execute(sql, (ts,))
-
- await self.db_pool.runInteraction(
- "cull_expired_threepid_validation_tokens",
- cull_expired_threepid_validation_tokens_txn,
- self.clock.time_msec(),
- )
-
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
@@ -1492,61 +1572,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.call_after(self.is_guest.invalidate, (user_id,))
- async def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
-
- await self.db_pool.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self.db_pool.simple_upsert_txn(
- txn,
- "account_validity",
- keyvalues={"user_id": user_id},
- values={"expiration_ts_ms": expiration_ts, "email_sent": False},
- )
-
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3c7630857f..c0f2af0785 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore):
"count_public_rooms", _count_public_rooms_txn
)
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
+ """
+
+ def f(txn):
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
+
+ return await self.db_pool.runInteraction("get_rooms", f)
+
async def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
@@ -1292,18 +1304,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- async def get_room_count(self) -> int:
- """Retrieve the total number of rooms.
- """
-
- def f(txn):
- sql = "SELECT count(*) FROM rooms"
- txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
-
- return await self.db_pool.runInteraction("get_rooms", f)
-
async def add_event_report(
self,
room_id: str,
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 86ffe2479e..20fcdaa529 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,12 +21,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
- LoggingTransaction,
- SQLBaseStore,
- db_to_json,
- make_in_list_sql_clause,
-)
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
@@ -60,15 +55,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# background update still running?
self._current_state_events_membership_up_to_date = False
- txn = LoggingTransaction(
- db_conn.cursor(),
- name="_check_safe_current_state_events_membership_updated",
- database_engine=self.database_engine,
+ txn = db_conn.cursor(
+ txn_name="_check_safe_current_state_events_membership_updated"
)
self._check_safe_current_state_events_membership_updated_txn(txn)
txn.close()
- if self.hs.config.metrics_flags.known_servers:
+ if (
+ self.hs.config.run_background_tasks
+ and self.hs.config.metrics_flags.known_servers
+ ):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
run_as_background_process,
diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644
--- a/synapse/storage/databases/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
cur.execute(
- database_engine.convert_param_style(
- """
- INSERT into pushers2 (
- id, user_name, access_token, profile_tag, kind,
- app_id, app_display_name, device_display_name,
- pushkey, ts, lang, data, last_token, last_success,
- failing_since
- ) values (%s)"""
- % (",".join(["?" for _ in range(len(row))]))
- ),
+ """
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ ) values (%s)
+ """
+ % (",".join(["?" for _ in range(len(row))])),
row,
)
count += 1
diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644
--- a/synapse/storage/databases/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_search", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644
--- a/synapse/storage/databases/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_origin_server_ts", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644
--- a/synapse/storage/databases/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
- database_engine.convert_param_style(
- "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
- % (",".join("?" for _ in chunk),)
- ),
+ "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+ % (",".join("?" for _ in chunk),),
[as_id] + chunk,
)
diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644
--- a/synapse/storage/databases/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row = list(row)
row[12] = token_to_stream_ordering(row[12])
cur.execute(
- database_engine.convert_param_style(
- """
- INSERT into pushers2 (
- id, user_name, access_token, profile_tag, kind,
- app_id, app_display_name, device_display_name,
- pushkey, ts, lang, data, last_stream_ordering, last_success,
- failing_since
- ) values (%s)"""
- % (",".join(["?" for _ in range(len(row))]))
- ),
+ """
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_stream_ordering, last_success,
+ failing_since
+ ) values (%s)
+ """
+ % (",".join(["?" for _ in range(len(row))])),
row,
)
count += 1
diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644
--- a/synapse/storage/databases/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_search_order", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644
--- a/synapse/storage/databases/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_fields_sender_url", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute(
- database_engine.convert_param_style(
- "UPDATE remote_media_cache SET last_access_ts = ?"
- ),
- (int(time.time() * 1000),),
+ "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
)
diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644
--- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@
import logging
+from io import StringIO
from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
logger = logging.getLogger(__name__)
@@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
select_clause,
)
- if isinstance(database_engine, PostgresEngine):
- cur.execute(sql)
- else:
- cur.executescript(sql)
+ execute_statements_from_stream(cur, StringIO(sql))
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644
--- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
INNER JOIN room_memberships AS r USING (event_id)
WHERE type = 'm.room.member' AND state_key LIKE ?
"""
- sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("%:" + config.server_name,))
cur.execute(
diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
new file mode 100644
index 0000000000..7851a0a825
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,20 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS dehydrated_devices(
+ user_id TEXT NOT NULL PRIMARY KEY,
+ device_id TEXT NOT NULL,
+ device_data TEXT NOT NULL -- JSON-encoded client-defined data
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+ user_id TEXT NOT NULL, -- The user this fallback key is for.
+ device_id TEXT NOT NULL, -- The device this fallback key is for.
+ algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+ key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+ key_json TEXT NOT NULL, -- The key as a JSON blob.
+ used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+ CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+ instance_id SERIAL PRIMARY KEY,
+ instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
diff --git a/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
new file mode 100644
index 0000000000..ad1f481428
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 37249f1e3f..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -53,7 +53,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
)
+def _filter_results(
+ lower_token: Optional[RoomStreamToken],
+ upper_token: Optional[RoomStreamToken],
+ instance_name: str,
+ topological_ordering: int,
+ stream_ordering: int,
+) -> bool:
+ """Returns True if the event persisted by the given instance at the given
+ topological/stream_ordering falls between the two tokens (taking a None
+ token to mean unbounded).
+
+ Used to filter results from fetching events in the DB against the given
+ tokens. This is necessary to handle the case where the tokens include
+ position maps, which we handle by fetching more than necessary from the DB
+ and then filtering (rather than attempting to construct a complicated SQL
+ query).
+ """
+
+ event_historical_tuple = (
+ topological_ordering,
+ stream_ordering,
+ )
+
+ if lower_token:
+ if lower_token.topological is not None:
+ # If these are historical tokens we compare the `(topological, stream)`
+ # tuples.
+ if event_historical_tuple <= lower_token.as_historical_tuple():
+ return False
+
+ else:
+ # If these are live tokens we compare the stream ordering against the
+ # writers stream position.
+ if stream_ordering <= lower_token.get_stream_pos_for_instance(
+ instance_name
+ ):
+ return False
+
+ if upper_token:
+ if upper_token.topological is not None:
+ if upper_token.as_historical_tuple() < event_historical_tuple:
+ return False
+ else:
+ if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+ return False
+
+ return True
+
+
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
raise NotImplementedError()
def get_room_max_token(self) -> RoomStreamToken:
- return RoomStreamToken(None, self.get_room_max_stream_ordering())
+ """Get a `RoomStreamToken` that marks the current maximum persisted
+ position of the events stream. Useful to get a token that represents
+ "now".
+
+ The token returned is a "live" token that may have an instance_map
+ component.
+ """
+
+ min_pos = self._stream_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._stream_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return RoomStreamToken(None, min_pos, positions)
async def get_room_events_stream_for_rooms(
self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if from_key == to_key:
return [], from_key
- from_id = from_key.stream
- to_id = to_key.stream
-
- has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+ has_changed = self._events_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ )
if not has_changed:
return [], from_key
def f(txn):
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT event_id, instance_name, topological_ordering, stream_ordering
+ FROM events
+ WHERE
+ room_id = ?
+ AND not outlier
+ AND stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering %s LIMIT ?
+ """ % (
+ order,
+ )
+ txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ][:limit]
return rows
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
[r.event_id for r in rows], get_prev_content=True
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._set_before_and_after(ret, rows, topo_order=False)
if order.lower() == "desc":
ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = from_key.stream
- to_id = to_key.stream
-
if from_key == to_key:
return []
- if from_id:
+ if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
- user_id, int(from_id)
+ user_id, int(from_key.stream)
)
if not has_changed:
return []
def f(txn):
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+ FROM events AS e, room_memberships AS m
+ WHERE e.event_id = m.event_id
+ AND m.user_id = ?
+ AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ ORDER BY e.stream_ordering ASC
+ """
+ txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ]
return rows
@@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_room_event_before_stream_ordering(
self, room_id: str, stream_ordering: int
- ) -> Tuple[int, int, str]:
+ ) -> Optional[Tuple[int, int, str]]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
@@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
)
return "t%d-%d" % (topo, token)
- async def get_stream_id_for_event(self, event_id: str) -> int:
- """The stream ID for an event
- Args:
- event_id: The id of the event to look up a stream token for.
- Raises:
- StoreError if the event wasn't in the database.
- Returns:
- A stream ID.
- """
- return await self.db_pool.runInteraction(
- "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
- )
-
def get_stream_id_for_event_txn(
self, txn: LoggingTransaction, event_id: str, allow_none=False,
) -> int:
@@ -979,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
else:
order = "ASC"
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
+ if from_token.topological is not None:
+ from_bound = (
+ from_token.as_historical_tuple()
+ ) # type: Tuple[Optional[int], int]
+ elif direction == "b":
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound = None # type: Optional[Tuple[Optional[int], int]]
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == "b":
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
+
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.as_tuple(),
- to_token=to_token.as_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
@@ -993,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds += " AND " + filter_clause
args.extend(filter_args)
- args.append(int(limit))
+ # We fetch more events as we'll filter the result set
+ args.append(int(limit) * 2)
select_keywords = "SELECT"
join_clause = ""
@@ -1015,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords += "DISTINCT"
sql = """
- %(select_keywords)s event_id, topological_ordering, stream_ordering
+ %(select_keywords)s
+ event_id, instance_name,
+ topological_ordering, stream_ordering
FROM events
%(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1030,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
txn.execute(sql, args)
- rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+ # Filter the result set.
+ rows = [
+ _EventDictReturn(event_id, topological_ordering, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ lower_token=to_token if direction == "b" else from_token,
+ upper_token=from_token if direction == "b" else to_token,
+ instance_name=instance_name,
+ topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
+ )
+ ][:limit]
if rows:
topo = rows[-1].topological_ordering
@@ -1095,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
return (events, token)
+ @cached()
+ async def get_id_for_instance(self, instance_name: str) -> int:
+ """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+ """
+
+ def _get_id_for_instance_txn(txn):
+ instance_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ allow_none=True,
+ )
+ if instance_id is not None:
+ return instance_id
+
+ # If we don't have an entry upsert one.
+ #
+ # We could do this before the first check, and rely on the cache for
+ # efficiency, but each UPSERT causes the next ID to increment which
+ # can quickly bloat the size of the generated IDs for new instances.
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ values={},
+ )
+
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ )
+
+ return await self.db_pool.runInteraction(
+ "get_id_for_instance", _get_id_for_instance_txn
+ )
+
+ @cached()
+ async def get_name_from_instance_id(self, instance_id: int) -> str:
+ """Get the instance name from an ID previously returned by
+ `get_id_for_instance`.
+ """
+
+ return await self.db_pool.simple_select_one_onecol(
+ table="instance_map",
+ keyvalues={"instance_id": instance_id},
+ retcol="instance_name",
+ desc="get_name_from_instance_id",
+ )
+
class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self) -> int:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 97aed1500e..7d46090267 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple(
SENTINEL = object()
-class TransactionStore(SQLBaseStore):
+class TransactionWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+
+ @wrap_as_background_process("cleanup_transactions")
+ async def _cleanup_transactions(self) -> None:
+ now = self._clock.time_msec()
+ month_ago = now - 30 * 24 * 60 * 60 * 1000
+
+ def _cleanup_transactions_txn(txn):
+ txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+
+ await self.db_pool.runInteraction(
+ "_cleanup_transactions", _cleanup_transactions_txn
+ )
+
+
+class TransactionStore(TransactionWorkerStore):
"""A collection of queries for handling PDUs.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
-
self._destination_retry_cache = ExpiringCache(
cache_name="get_destination_retry_timings",
clock=self._clock,
@@ -266,22 +284,6 @@ class TransactionStore(SQLBaseStore):
},
)
- def _start_cleanup_transactions(self):
- return run_as_background_process(
- "cleanup_transactions", self._cleanup_transactions
- )
-
- async def _cleanup_transactions(self) -> None:
- now = self._clock.time_msec()
- month_ago = now - 30 * 24 * 60 * 60 * 1000
-
- def _cleanup_transactions_txn(txn):
- txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
-
- await self.db_pool.runInteraction(
- "_cleanup_transactions", _cleanup_transactions_txn
- )
-
async def store_destination_rooms_entries(
self, destinations: Iterable[str], room_id: str, stream_ordering: int,
) -> None:
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 3b9211a6d2..79b7ece330 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore):
)
return [(row["user_agent"], row["ip"]) for row in rows]
-
-class UIAuthStore(UIAuthWorkerStore):
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore):
iterable=session_ids,
keyvalues={},
)
+
+
+class UIAuthStore(UIAuthWorkerStore):
+ pass
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 72939f3984..4d2d88d1f0 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -248,6 +248,8 @@ class EventsPersistenceStorage:
await make_deferred_yieldable(deferred)
event_stream_id = event.internal_metadata.stream_ordering
+ # stream ordering should have been assigned by now
+ assert event_stream_id
pos = PersistedEventPosition(self._instance_name, event_stream_id)
return pos, self.main_store.get_room_max_token()
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 4957e77f4c..459754feab 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import imp
import logging
import os
@@ -24,9 +23,10 @@ from typing import Optional, TextIO
import attr
from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
def prepare_database(
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str] = ["main", "state"],
@@ -89,7 +89,7 @@ def prepare_database(
"""
try:
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="prepare_database")
# sqlite does not automatically start transactions for DDL / SELECT statements,
# so we start one before running anything. This ensures that any upgrades
@@ -258,9 +258,7 @@ def _setup_new_database(cur, database_engine, databases):
executescript(cur, entry.absolute_path)
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
- ),
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
(max_current_ver, False),
)
@@ -486,17 +484,13 @@ def _upgrade_existing_database(
# Mark as done.
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
- ),
+ "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
(v, relative_path),
)
cur.execute("DELETE FROM schema_version")
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
- ),
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
(v, True),
)
@@ -532,10 +526,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
schemas to be applied
"""
cur.execute(
- database_engine.convert_param_style(
- "SELECT file FROM applied_module_schemas WHERE module_name = ?"
- ),
- (modname,),
+ "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
)
applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
@@ -553,9 +544,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
- ),
+ "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
(modname, name),
)
@@ -627,9 +616,7 @@ def _get_or_create_schema_state(txn, database_engine):
if current_version:
txn.execute(
- database_engine.convert_param_style(
- "SELECT file FROM applied_schema_deltas WHERE version >= ?"
- ),
+ "SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
)
applied_deltas = [d for d, in txn]
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 2d2b560e74..970bb1b9da 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -61,3 +61,9 @@ class Connection(Protocol):
def rollback(self, *args, **kwargs) -> None:
...
+
+ def __enter__(self) -> "Connection":
+ ...
+
+ def __exit__(self, exc_type, exc_value, traceback) -> bool:
+ ...
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ad017207aa..3d8da48f2d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -55,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1):
"""
# debug logging for https://github.com/matrix-org/synapse/issues/7968
logger.info("initialising stream generator for %s(%s)", table, column)
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="_load_current_id")
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
@@ -270,7 +270,7 @@ class MultiWriterIdGenerator:
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
):
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="_load_current_ids")
# Load the current positions of all writers for the stream.
if self._writers:
@@ -284,15 +284,12 @@ class MultiWriterIdGenerator:
stream_name = ?
AND instance_name != ALL(?)
"""
- sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (self._stream_name, self._writers))
sql = """
SELECT instance_name, stream_id FROM stream_positions
WHERE stream_name = ?
"""
- sql = self._db.engine.convert_param_style(sql)
-
cur.execute(sql, (self._stream_name,))
self._current_positions = {
@@ -341,7 +338,6 @@ class MultiWriterIdGenerator:
"instance": instance_column,
"cmp": "<=" if self._positive else ">=",
}
- sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (min_stream_id * self._return_factor,))
self._persisted_upto_position = min_stream_id
@@ -422,7 +418,7 @@ class MultiWriterIdGenerator:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
- new_cur = None
+ new_cur = None # type: Optional[int]
if self._unfinished_ids:
# If there are unfinished IDs then the new position will be the
@@ -528,6 +524,16 @@ class MultiWriterIdGenerator:
heapq.heappush(self._known_persisted_positions, new_id)
+ # If we're a writer and we don't have any active writes we update our
+ # current position to the latest position seen. This allows the instance
+ # to report a recent position when asked, rather than a potentially old
+ # one (if this instance hasn't written anything for a while).
+ our_current_position = self._current_positions.get(self._instance_name)
+ if our_current_position and not self._unfinished_ids:
+ self._current_positions[self._instance_name] = max(
+ our_current_position, new_id
+ )
+
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 2dd95e2709..4386b6101e 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -17,6 +17,7 @@ import logging
import threading
from typing import Callable, List, Optional
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import (
BaseDatabaseEngine,
IncorrectDatabaseSetup,
@@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ id_column: str,
+ positive: bool = True,
):
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
@@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator):
return [i for (i,) in txn]
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ id_column: str,
+ positive: bool = True,
):
- txn = db_conn.cursor()
+ txn = db_conn.cursor(txn_name="sequence.check_consistency")
# First we get the current max ID from the table.
table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
@@ -117,6 +126,8 @@ class PostgresSequenceGenerator(SequenceGenerator):
if max_stream_id > last_value:
logger.warning(
"Postgres sequence %s is behind table %s: %d < %d",
+ self._sequence_name,
+ table,
last_value,
max_stream_id,
)
diff --git a/synapse/types.py b/synapse/types.py
index bd271f9f16..5bde67cc07 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -22,6 +22,7 @@ from typing import (
TYPE_CHECKING,
Any,
Dict,
+ Iterable,
Mapping,
MutableMapping,
Optional,
@@ -43,7 +44,7 @@ if TYPE_CHECKING:
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
- from typing import Container, Iterable, Sized
+ from typing import Container, Sized
T_co = TypeVar("T_co", covariant=True)
@@ -375,7 +376,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, cmp=False)
class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.
@@ -397,6 +398,31 @@ class RoomStreamToken:
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
+
+ There is also a third mode for live tokens where the token starts with "m",
+ which is sometimes used when using sharded event persisters. In this case
+ the events stream is considered to be a set of streams (one for each writer)
+ and the token encodes the vector clock of positions of each writer in their
+ respective streams.
+
+ The format of the token in such case is an initial integer min position,
+ followed by the mapping of instance ID to position separated by '.' and '~':
+
+ m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ...
+
+ The `min_pos` corresponds to the minimum position all writers have persisted
+ up to, and then only writers that are ahead of that position need to be
+ encoded. An example token is:
+
+ m56~2.58~3.59
+
+ Which corresponds to a set of three (or more writers) where instances 2 and
+ 3 (these are instance IDs that can be looked up in the DB to fetch the more
+ commonly used instance names) are at positions 58 and 59 respectively, and
+ all other instances are at position 56.
+
+ Note: The `RoomStreamToken` cannot have both a topological part and an
+ instance map.
"""
topological = attr.ib(
@@ -405,6 +431,25 @@ class RoomStreamToken:
)
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+ instance_map = attr.ib(
+ type=Dict[str, int],
+ factory=dict,
+ validator=attr.validators.deep_mapping(
+ key_validator=attr.validators.instance_of(str),
+ value_validator=attr.validators.instance_of(int),
+ mapping_validator=attr.validators.instance_of(dict),
+ ),
+ )
+
+ def __attrs_post_init__(self):
+ """Validates that both `topological` and `instance_map` aren't set.
+ """
+
+ if self.instance_map and self.topological:
+ raise ValueError(
+ "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
+ )
+
@classmethod
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
try:
@@ -413,6 +458,20 @@ class RoomStreamToken:
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
+ if string[0] == "m":
+ parts = string[1:].split("~")
+ stream = int(parts[0])
+
+ instance_map = {}
+ for part in parts[1:]:
+ key, value = part.split(".")
+ instance_id = int(key)
+ pos = int(value)
+
+ instance_name = await store.get_name_from_instance_id(instance_id)
+ instance_map[instance_name] = pos
+
+ return cls(topological=None, stream=stream, instance_map=instance_map,)
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -436,14 +495,61 @@ class RoomStreamToken:
max_stream = max(self.stream, other.stream)
- return RoomStreamToken(None, max_stream)
+ instance_map = {
+ instance: max(
+ self.instance_map.get(instance, self.stream),
+ other.instance_map.get(instance, other.stream),
+ )
+ for instance in set(self.instance_map).union(other.instance_map)
+ }
+
+ return RoomStreamToken(None, max_stream, instance_map)
+
+ def as_historical_tuple(self) -> Tuple[int, int]:
+ """Returns a tuple of `(topological, stream)` for historical tokens.
+
+ Raises if not an historical token (i.e. doesn't have a topological part).
+ """
+ if self.topological is None:
+ raise Exception(
+ "Cannot call `RoomStreamToken.as_historical_tuple` on live token"
+ )
- def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream)
+ def get_stream_pos_for_instance(self, instance_name: str) -> int:
+ """Get the stream position that the given writer was at at this token.
+
+ This only makes sense for "live" tokens that may have a vector clock
+ component, and so asserts that this is a "live" token.
+ """
+ assert self.topological is None
+
+ # If we don't have an entry for the instance we can assume that it was
+ # at `self.stream`.
+ return self.instance_map.get(instance_name, self.stream)
+
+ def get_max_stream_pos(self) -> int:
+ """Get the maximum stream position referenced in this token.
+
+ The corresponding "min" position is, by definition just `self.stream`.
+
+ This is used to handle tokens that have non-empty `instance_map`, and so
+ reference stream positions after the `self.stream` position.
+ """
+ return max(self.instance_map.values(), default=self.stream)
+
async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
+ elif self.instance_map:
+ entries = []
+ for name, pos in self.instance_map.items():
+ instance_id = await store.get_id_for_instance(name)
+ entries.append("{}.{}".format(instance_id, pos))
+
+ encoded_map = "~".join(entries)
+ return "m{}~{}".format(self.stream, encoded_map)
else:
return "s%d" % (self.stream,)
@@ -535,7 +641,7 @@ class PersistedEventPosition:
stream = attr.ib(type=int)
def persisted_after(self, token: RoomStreamToken) -> bool:
- return token.stream < self.stream
+ return token.get_stream_pos_for_instance(self.instance_name) < self.stream
def to_room_stream_token(self) -> RoomStreamToken:
"""Converts the position to a room stream token such that events
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index df1a721add..32228f42ee 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
@@ -20,10 +21,15 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
+T = TypeVar("T")
+
-class ResponseCache:
+class ResponseCache(Generic[T]):
"""
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
@@ -31,8 +37,9 @@ class ResponseCache:
used rather than trying to compute a new response.
"""
- def __init__(self, hs, name, timeout_ms=0):
- self.pending_result_cache = {} # Requests that haven't finished yet.
+ def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+ # Requests that haven't finished yet.
+ self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0
@@ -40,13 +47,13 @@ class ResponseCache:
self._name = name
self._metrics = register_cache("response_cache", name, self, resizable=False)
- def size(self):
+ def size(self) -> int:
return len(self.pending_result_cache)
- def __len__(self):
+ def __len__(self) -> int:
return self.size()
- def get(self, key):
+ def get(self, key: T) -> Optional[defer.Deferred]:
"""Look up the given key.
Can return either a new Deferred (which also doesn't follow the synapse
@@ -58,12 +65,11 @@ class ResponseCache:
from an absent cache entry.
Args:
- key (hashable):
+ key: key to get/set in the cache
Returns:
- twisted.internet.defer.Deferred|None|E: None if there is no entry
- for this key; otherwise either a deferred result or the result
- itself.
+ None if there is no entry for this key; otherwise a deferred which
+ resolves to the result.
"""
result = self.pending_result_cache.get(key)
if result is not None:
@@ -73,7 +79,7 @@ class ResponseCache:
self._metrics.inc_misses()
return None
- def set(self, key, deferred):
+ def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -85,12 +91,11 @@ class ResponseCache:
result. You will probably want to make_deferred_yieldable the result.
Args:
- key (hashable):
- deferred (twisted.internet.defer.Deferred[T):
+ key: key to get/set in the cache
+ deferred: The deferred which resolves to the result.
Returns:
- twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
- result.
+ A new deferred which resolves to the actual result.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
@@ -107,7 +112,9 @@ class ResponseCache:
result.addBoth(remove)
return result.observe()
- def wrap(self, key, callback, *args, **kwargs):
+ def wrap(
+ self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
+ ) -> defer.Deferred:
"""Wrap together a *get* and *set* call, taking care of logcontexts
First looks up the key in the cache, and if it is present makes it
@@ -118,21 +125,20 @@ class ResponseCache:
Example usage:
- @defer.inlineCallbacks
- def handle_request(request):
+ async def handle_request(request):
# etc
return result
- result = yield response_cache.wrap(
+ result = await response_cache.wrap(
key,
handle_request,
request,
)
Args:
- key (hashable): key to get/set in the cache
+ key: key to get/set in the cache
- callback (callable): function to call if the key is not found in
+ callback: function to call if the key is not found in
the cache
*args: positional parameters to pass to the callback, if it is used
@@ -140,7 +146,7 @@ class ResponseCache:
**kwargs: named parameters to pass to the callback, if it is used
Returns:
- twisted.internet.defer.Deferred: yieldable result
+ Deferred which resolves to the result
"""
result = self.get(key)
if not result:
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index bb62db4637..94b59afb38 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -36,7 +36,7 @@ def load_module(provider):
try:
provider_config = provider_class.parse_config(provider.get("config"))
except Exception as e:
- raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
+ raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
return provider_class, provider_config
diff --git a/synapse/visibility.py b/synapse/visibility.py
index e3da7744d2..527365498e 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -16,7 +16,7 @@
import logging
import operator
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.state import StateFilter
@@ -77,15 +77,14 @@ async def filter_events_for_client(
)
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id
+ AccountDataTypes.IGNORED_USER_LIST, user_id
)
- # FIXME: This will explode if people upload something incorrect.
- ignore_list = frozenset(
- ignore_dict_content.get("ignored_users", {}).keys()
- if ignore_dict_content
- else []
- )
+ ignore_list = frozenset()
+ if ignore_dict_content:
+ ignored_users_dict = ignore_dict_content.get("ignored_users", {})
+ if isinstance(ignored_users_dict, dict):
+ ignore_list = frozenset(ignored_users_dict.keys())
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
diff --git a/sytest-blacklist b/sytest-blacklist
index b563448016..de9986357b 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -34,9 +34,6 @@ New federated private chats get full presence information (SYN-115)
# this requirement from the spec
Inbound federation of state requires event_id as a mandatory paramater
-# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands
-Can upload self-signing keys
-
# Blacklisted until MSC2753 is implemented
Local users can peek into world_readable rooms by room ID
We can't peek into rooms with shared history_visibility
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 8ab56ec94c..cb6f29d670 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,7 +19,6 @@ import pymacaroons
from twisted.internet import defer
-import synapse.handlers.auth
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -36,20 +35,15 @@ from tests import unittest
from tests.utils import mock_getRawHeaders, setup_test_homeserver
-class TestHandlers:
- def __init__(self, hs):
- self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
-
-
class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.state_handler = Mock()
self.store = Mock()
- self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
+ self.hs = yield setup_test_homeserver(self.addCleanup)
self.hs.get_datastore = Mock(return_value=self.store)
- self.hs.handlers = TestHandlers(self.hs)
+ self.hs.get_auth_handler().store = self.store
self.auth = Auth(self.hs)
# AuthBlocking reads from the hs' config on initialization. We need to
@@ -283,7 +277,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_device = Mock(return_value=defer.succeed(None))
token = yield defer.ensureDeferred(
- self.hs.handlers.auth_handler.get_access_token_for_user_id(
+ self.hs.get_auth_handler().get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index d2d535d23c..c98ae75974 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -50,10 +50,7 @@ class FilteringTestCase(unittest.TestCase):
self.mock_http_client.put_json = DeferredMockCallable()
hs = yield setup_test_homeserver(
- self.addCleanup,
- handlers=None,
- http_client=self.mock_http_client,
- keyring=Mock(),
+ self.addCleanup, http_client=self.mock_http_client, keyring=Mock(),
)
self.filtering = hs.get_filtering()
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 8ff1460c0d..697916a019 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
- hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ hs = self.setup_test_homeserver(http_client=self.http_client)
return hs
def test_get_keys_from_server(self):
@@ -395,9 +395,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
}
]
- return self.setup_test_homeserver(
- handlers=None, http_client=self.http_client, config=config
- )
+ return self.setup_test_homeserver(http_client=self.http_client, config=config)
def build_perspectives_response(
self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index fc37c4328c..5c2b4de1a6 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -35,7 +35,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
self.user1 = self.register_user("user1", "password")
self.token1 = self.login("user1", "password")
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 97877c2e42..b5055e018c 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -21,24 +21,17 @@ from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.api.errors import ResourceLimitError
-from synapse.handlers.auth import AuthHandler
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class AuthHandlers:
- def __init__(self, hs):
- self.auth_handler = AuthHandler(hs)
-
-
class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
- self.hs.handlers = AuthHandlers(self.hs)
- self.auth_handler = self.hs.handlers.auth_handler
+ self.hs = yield setup_test_homeserver(self.addCleanup)
+ self.auth_handler = self.hs.get_auth_handler()
self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 969d44c787..4512c51311 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -224,3 +225,84 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
)
self.reactor.advance(1000)
+
+
+class DehydrationTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ self.handler = hs.get_device_handler()
+ self.registration = hs.get_registration_handler()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_dehydrate_and_rehydrate_device(self):
+ user_id = "@boris:dehydration"
+
+ self.get_success(self.store.register_user(user_id, "foobar"))
+
+ # First check if we can store and fetch a dehydrated device
+ stored_dehydrated_device_id = self.get_success(
+ self.handler.store_dehydrated_device(
+ user_id=user_id,
+ device_data={"device_data": {"foo": "bar"}},
+ initial_device_display_name="dehydrated device",
+ )
+ )
+
+ retrieved_device_id, device_data = self.get_success(
+ self.handler.get_dehydrated_device(user_id=user_id)
+ )
+
+ self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+ self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+ # Create a new login for the user and dehydrated the device
+ device_id, access_token = self.get_success(
+ self.registration.register_device(
+ user_id=user_id, device_id=None, initial_display_name="new device",
+ )
+ )
+
+ # Trying to claim a nonexistent device should throw an error
+ self.get_failure(
+ self.handler.rehydrate_device(
+ user_id=user_id,
+ access_token=access_token,
+ device_id="not the right device ID",
+ ),
+ synapse.api.errors.NotFoundError,
+ )
+
+ # dehydrating the right devices should succeed and change our device ID
+ # to the dehydrated device's ID
+ res = self.get_success(
+ self.handler.rehydrate_device(
+ user_id=user_id,
+ access_token=access_token,
+ device_id=retrieved_device_id,
+ )
+ )
+
+ self.assertEqual(res, {"success": True})
+
+ # make sure that our device ID has changed
+ user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
+
+ self.assertEqual(user_info["device_id"], retrieved_device_id)
+
+ # make sure the device has the display name that was set from the login
+ res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
+
+ self.assertEqual(res["display_name"], "new device")
+
+ # make sure that the device ID that we were initially assigned no longer exists
+ self.get_failure(
+ self.handler.get_device(user_id, device_id),
+ synapse.api.errors.NotFoundError,
+ )
+
+ # make sure that there's no device available for dehydrating now
+ ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+
+ self.assertIsNone(ret)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index bc0c5aefdc..2ce6dc9528 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -48,7 +48,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
federation_registry=self.mock_registry,
)
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.store = hs.get_datastore()
@@ -110,7 +110,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -173,7 +173,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
# Create user
@@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
# Create user
@@ -442,7 +442,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.result)
self.room_list_handler = hs.get_room_list_handler()
- self.directory_handler = hs.get_handlers().directory_handler
+ self.directory_handler = hs.get_directory_handler()
return hs
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 366dcfb670..924f29f051 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -33,13 +33,15 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
+ self.store = None # type: synapse.storage.Storage
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, handlers=None, federation_client=mock.Mock()
+ self.addCleanup, federation_client=mock.Mock()
)
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+ self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_query_local_devices_no_devices(self):
@@ -172,6 +174,89 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_fallback_key(self):
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ fallback_key = {"alg1:k1": "key1"}
+ otk = {"alg1:k2": "key2"}
+
+ # we shouldn't have any unused fallback keys yet
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key},
+ )
+ )
+
+ # we should now have an unused alg1 key
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, ["alg1"])
+
+ # claiming an OTK when no OTKs are available should return the fallback
+ # key
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # we shouldn't have any unused fallback keys again
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ # claiming an OTK again should return the same fallback key
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # if the user uploads a one-time key, the next claim should fetch the
+ # one-time key, and then go back to the fallback
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": otk}
+ )
+ )
+
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+ )
+
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ @defer.inlineCallbacks
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 7adde9b9de..45f201a399 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -54,7 +54,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, handlers=None, replication_layer=mock.Mock()
+ self.addCleanup, replication_layer=mock.Mock()
)
self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
self.local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 96fea58673..9ef80fe502 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -38,7 +38,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(http_client=None)
- self.handler = hs.get_handlers().federation_handler
+ self.handler = hs.get_federation_handler()
self.store = hs.get_datastore()
return hs
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index d5087e58be..b6f436c016 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -286,9 +286,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
h._validate_metadata,
)
- # Tests for configs that the userinfo endpoint
+ # Tests for configs that require the userinfo endpoint
self.assertFalse(h._uses_userinfo)
- h._scopes = [] # do not request the openid scope
+ self.assertEqual(h._user_profile_method, "auto")
+ h._user_profile_method = "userinfo_endpoint"
+ self.assertTrue(h._uses_userinfo)
+
+ # Revert the profile method and do not request the "openid" scope.
+ h._user_profile_method = "auto"
+ h._scopes = []
self.assertTrue(h._uses_userinfo)
self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 306dcfe944..914c82e7a8 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -470,7 +470,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
- self.federation_handler = hs.get_handlers().federation_handler
+ self.federation_handler = hs.get_federation_handler()
self.presence_handler = hs.get_presence_handler()
# self.event_builder_for_2 = EventBuilderFactory(hs)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 8e95e53d9e..a69fa28b41 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -20,7 +20,6 @@ from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError, SynapseError
-from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
@@ -28,11 +27,6 @@ from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class ProfileHandlers:
- def __init__(self, hs):
- self.profile_handler = MasterProfileHandler(hs)
-
-
class ProfileTestCase(unittest.TestCase):
""" Tests profile management. """
@@ -51,7 +45,6 @@ class ProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
http_client=None,
- handlers=None,
resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_server=Mock(),
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index cb7c0ed51a..bdf3d0a8a2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,6 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
-from synapse.handlers.register import RegistrationHandler
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
@@ -29,11 +28,6 @@ from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers:
- def __init__(self, hs):
- self.registration_handler = RegistrationHandler(hs)
-
-
class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
@@ -154,7 +148,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -193,7 +187,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
@@ -205,7 +199,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -237,7 +231,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -266,7 +260,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -304,7 +298,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -347,7 +341,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -384,7 +378,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -413,7 +407,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- event_creation_handler.send_nonmember_event(requester, event, context)
+ event_creation_handler.handle_new_client_event(requester, event, context)
)
# Register a second user, which won't be be in the room (or even have an invite)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 04de0b9dbe..9b573ac24d 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -12,16 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
-from synapse.module_api import ModuleApi
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import create_requester
from tests.unittest import HomeserverTestCase
class ModuleApiTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
- self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
+ self.module_api = homeserver.get_module_api()
+ self.event_creation_handler = homeserver.get_event_creation_handler()
def test_can_register_user(self):
"""Tests that an external module can register a user"""
@@ -52,3 +63,141 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the displayname was assigned
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
+
+ def test_sending_events_into_room(self):
+ """Tests that a module can send events into a room"""
+ # Mock out create_and_send_nonmember_event to check whether events are being sent
+ self.event_creation_handler.create_and_send_nonmember_event = Mock(
+ spec=[],
+ side_effect=self.event_creation_handler.create_and_send_nonmember_event,
+ )
+
+ # Create a user and room to play with
+ user_id = self.register_user("summer", "monkey")
+ tok = self.login("summer", "monkey")
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ # Create and send a non-state event
+ content = {"body": "I am a puppet", "msgtype": "m.text"}
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.message",
+ "content": content,
+ "sender": user_id,
+ }
+ event = self.get_success(
+ self.module_api.create_and_send_event_into_room(event_dict)
+ ) # type: EventBase
+ self.assertEqual(event.sender, user_id)
+ self.assertEqual(event.type, "m.room.message")
+ self.assertEqual(event.room_id, room_id)
+ self.assertFalse(hasattr(event, "state_key"))
+ self.assertDictEqual(event.content, content)
+
+ # Check that the event was sent
+ self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+ create_requester(user_id),
+ event_dict,
+ ratelimit=False,
+ ignore_shadow_ban=True,
+ )
+
+ # Create and send a state event
+ content = {
+ "events_default": 0,
+ "users": {user_id: 100},
+ "state_default": 50,
+ "users_default": 0,
+ "events": {"test.event.type": 25},
+ }
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.power_levels",
+ "content": content,
+ "sender": user_id,
+ "state_key": "",
+ }
+ event = self.get_success(
+ self.module_api.create_and_send_event_into_room(event_dict)
+ ) # type: EventBase
+ self.assertEqual(event.sender, user_id)
+ self.assertEqual(event.type, "m.room.power_levels")
+ self.assertEqual(event.room_id, room_id)
+ self.assertEqual(event.state_key, "")
+ self.assertDictEqual(event.content, content)
+
+ # Check that the event was sent
+ self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+ create_requester(user_id),
+ {
+ "type": "m.room.power_levels",
+ "content": content,
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": "",
+ },
+ ratelimit=False,
+ ignore_shadow_ban=True,
+ )
+
+ # Check that we can't send membership events
+ content = {
+ "membership": "leave",
+ }
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "content": content,
+ "sender": user_id,
+ "state_key": user_id,
+ }
+ self.get_failure(
+ self.module_api.create_and_send_event_into_room(event_dict), Exception
+ )
+
+ def test_public_rooms(self):
+ """Tests that a room can be added and removed from the public rooms list,
+ as well as have its public rooms directory state queried.
+ """
+ # Create a user and room to play with
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ # The room should not currently be in the public rooms directory
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertFalse(is_in_public_rooms)
+
+ # Let's try adding it to the public rooms directory
+ self.get_success(
+ self.module_api.public_room_list_manager.add_room_to_public_room_list(
+ room_id
+ )
+ )
+
+ # And checking whether it's in there...
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertTrue(is_in_public_rooms)
+
+ # Let's remove it again
+ self.get_success(
+ self.module_api.public_room_list_manager.remove_room_from_public_room_list(
+ room_id
+ )
+ )
+
+ # Should be gone
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertFalse(is_in_public_rooms)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..81ea985b9f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Any, Callable, List, Optional, Tuple
import attr
+import hiredis
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
@@ -27,7 +28,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
+ # Fake in memory Redis server that servers can connect to.
+ self._redis_server = FakeRedisPubSubServer()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["localhost"] = "127.0.0.1"
+
+ # A map from a HS instance to the associated HTTP Site to use for
+ # handling inbound HTTP requests to that instance.
+ self._hs_to_site = {self.hs: self.site}
+
+ if self.hs.config.redis.redis_enabled:
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost", 6379, self.connect_any_redis_attempts,
+ )
- self._worker_hs_to_resource = {}
+ self.hs.get_tcp_replication().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
+ #
+ # Register the master replication listener:
self.reactor.add_tcp_client_callback(
- "1.2.3.4", 8765, self._handle_http_replication_attempt
+ "1.2.3.4",
+ 8765,
+ lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
def create_test_json_resource(self):
@@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
**kwargs
)
+ # If the instance is in the `instance_map` config then workers may try
+ # and send HTTP requests to it, so we register it with
+ # `_handle_http_replication_attempt` like we do with the master HS.
+ instance_name = worker_hs.get_instance_name()
+ instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+ if instance_loc:
+ # Ensure the host is one that has a fake DNS entry.
+ if instance_loc.host not in self.reactor.lookups:
+ raise Exception(
+ "Host does not have an IP for instance_map[%r].host = %r"
+ % (instance_name, instance_loc.host,)
+ )
+
+ self.reactor.add_tcp_client_callback(
+ self.reactor.lookups[instance_loc.host],
+ instance_loc.port,
+ lambda: self._handle_http_replication_attempt(
+ worker_hs, instance_loc.port
+ ),
+ )
+
store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
- )
- server = self.server_factory.buildProtocol(None)
+ # Set up TCP replication between master and the new worker if we don't
+ # have Redis support enabled.
+ if not worker_hs.config.redis_enabled:
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
# Set up a resource for the worker
- resource = ReplicationRestResource(self.hs)
+ resource = ReplicationRestResource(worker_hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
- self._worker_hs_to_resource[worker_hs] = resource
+ self._hs_to_site[worker_hs] = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag="{}-{}".format(
+ worker_hs.config.server.server_name, worker_hs.get_instance_name()
+ ),
+ config=worker_hs.config.server.listeners[0],
+ resource=resource,
+ server_version_string="1",
+ )
+
+ if worker_hs.config.redis.redis_enabled:
+ worker_hs.get_tcp_replication().start_replication(worker_hs)
return worker_hs
@@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return config
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
- render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+ render(request, self._hs_to_site[worker_hs].resource, self.reactor)
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump()
- def _handle_http_replication_attempt(self):
- """Handles a connection attempt to the master replication HTTP
- listener.
+ def _handle_http_replication_attempt(self, hs, repl_port):
+ """Handles a connection attempt to the given HS replication HTTP
+ listener on the given port.
"""
# We should have at least one outbound connection attempt, where the
@@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 8765)
+ self.assertEqual(port, repl_port)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
@@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
- channel.site = self.site
+ channel.site = self._hs_to_site[hs]
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
+ def connect_any_redis_attempts(self):
+ """If redis is enabled we need to deal with workers connecting to a
+ redis server. We don't want to use a real Redis server so we use a
+ fake one.
+ """
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
+
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
+
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
+
+ return client_to_server_transport, server_to_client_transport
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -467,3 +547,105 @@ class _PullToPushProducer:
pass
self.stopProducing()
+
+
+class FakeRedisPubSubServer:
+ """A fake Redis server for pub/sub.
+ """
+
+ def __init__(self):
+ self._subscribers = set()
+
+ def add_subscriber(self, conn):
+ """A connection has called SUBSCRIBE
+ """
+ self._subscribers.add(conn)
+
+ def remove_subscriber(self, conn):
+ """A connection has called UNSUBSCRIBE
+ """
+ self._subscribers.discard(conn)
+
+ def publish(self, conn, channel, msg) -> int:
+ """A connection want to publish a message to subscribers.
+ """
+ for sub in self._subscribers:
+ sub.send(["message", channel, msg])
+
+ return len(self._subscribers)
+
+ def buildProtocol(self, addr):
+ return FakeRedisPubSubProtocol(self)
+
+
+class FakeRedisPubSubProtocol(Protocol):
+ """A connection from a client talking to the fake Redis server.
+ """
+
+ def __init__(self, server: FakeRedisPubSubServer):
+ self._server = server
+ self._reader = hiredis.Reader()
+
+ def dataReceived(self, data):
+ self._reader.feed(data)
+
+ # We might get multiple messages in one packet.
+ while True:
+ msg = self._reader.gets()
+
+ if msg is False:
+ # No more messages.
+ return
+
+ if not isinstance(msg, list):
+ # Inbound commands should always be a list
+ raise Exception("Expected redis list")
+
+ self.handle_command(msg[0], *msg[1:])
+
+ def handle_command(self, command, *args):
+ """Received a Redis command from the client.
+ """
+
+ # We currently only support pub/sub.
+ if command == b"PUBLISH":
+ channel, message = args
+ num_subscribers = self._server.publish(self, channel, message)
+ self.send(num_subscribers)
+ elif command == b"SUBSCRIBE":
+ (channel,) = args
+ self._server.add_subscriber(self)
+ self.send(["subscribe", channel, 1])
+ else:
+ raise Exception("Unknown command")
+
+ def send(self, msg):
+ """Send a message back to the client.
+ """
+ raw = self.encode(msg).encode("utf-8")
+
+ self.transport.write(raw)
+ self.transport.flush()
+
+ def encode(self, obj):
+ """Encode an object to its Redis format.
+
+ Supports: strings/bytes, integers and list/tuples.
+ """
+
+ if isinstance(obj, bytes):
+ # We assume bytes are just unicode strings.
+ obj = obj.decode("utf-8")
+
+ if isinstance(obj, str):
+ return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+ if isinstance(obj, int):
+ return ":{val}\r\n".format(val=obj)
+ if isinstance(obj, (list, tuple)):
+ items = "".join(self.encode(a) for a in obj)
+ return "*{len}\r\n{items}".format(len=len(obj), items=items)
+
+ raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
+
+ def connectionLost(self, reason):
+ self._server.remove_subscriber(self)
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 1d7edee5ba..9c4a9c3563 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -207,7 +207,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
- federation = self.hs.get_handlers().federation_handler
+ federation = self.hs.get_federation_handler()
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
new file mode 100644
index 0000000000..6068d14905
--- /dev/null
+++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+logger = logging.getLogger(__name__)
+
+
+class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks event persisting sharding works
+ """
+
+ # Event persister sharding requires postgres (due to needing
+ # `MutliWriterIdGenerator`).
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["redis"] = {"enabled": "true"}
+ conf["stream_writers"] = {"events": ["worker1", "worker2"]}
+ conf["instance_map"] = {
+ "worker1": {"host": "testserv", "port": 1001},
+ "worker2": {"host": "testserv", "port": 1002},
+ }
+ return conf
+
+ def test_basic(self):
+ """Simple test to ensure that multiple rooms can be created and joined,
+ and that different rooms get handled by different instances.
+ """
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker1"},
+ )
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker2"},
+ )
+
+ persisted_on_1 = False
+ persisted_on_2 = False
+
+ store = self.hs.get_datastore()
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Keep making new rooms until we see rooms being persisted on both
+ # workers.
+ for _ in range(10):
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+ rseponse = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+ event_id = rseponse["event_id"]
+
+ # The event position includes which instance persisted the event.
+ pos = self.get_success(store.get_position_for_event(event_id))
+
+ persisted_on_1 |= pos.instance_name == "worker1"
+ persisted_on_2 |= pos.instance_name == "worker2"
+
+ if persisted_on_1 and persisted_on_2:
+ break
+
+ self.assertTrue(persisted_on_1)
+ self.assertTrue(persisted_on_2)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index dfe4bf7762..6bb02b9630 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -78,7 +78,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_invite_3pid(self):
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
- identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler = self.hs.get_identity_handler()
identity_handler.lookup_3pid = Mock(
side_effect=AssertionError("This should not get called")
)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
new file mode 100644
index 0000000000..d03e121664
--- /dev/null
+++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+from typing import Dict
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester, StateMap
+
+from tests import unittest
+
+thread_local = threading.local()
+
+
+class ThirdPartyRulesTestModule:
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ # keep a record of the "current" rules module, so that the test can patch
+ # it if desired.
+ thread_local.rules_module = self
+ self.module_api = module_api
+
+ async def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return True
+
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+
+def current_rules_module() -> ThirdPartyRulesTestModule:
+ return thread_local.rules_module
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["third_party_event_rules"] = {
+ "module": __name__ + ".ThirdPartyRulesTestModule",
+ "config": {},
+ }
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ # Create a user and room to play with during the tests
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_third_party_rules(self):
+ """Tests that a forbidden event is forbidden from being sent, but an allowed one
+ can be sent.
+ """
+ # patch the rules module with a Mock which will return False for some event
+ # types
+ async def check(ev, state):
+ return ev.type != "foo.bar.forbidden"
+
+ callback = Mock(spec=[], side_effect=check)
+ current_rules_module().check_event_allowed = callback
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ callback.assert_called_once()
+
+ # there should be various state events in the state arg: do some basic checks
+ state_arg = callback.call_args[0][1]
+ for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+ self.assertIn(k, state_arg)
+ ev = state_arg[k]
+ self.assertEqual(ev.type, k[0])
+ self.assertEqual(ev.state_key, k[1])
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_modify_event(self):
+ """Tests that the module can successfully tweak an event before it is persisted.
+ """
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ ev.content = {"x": "y"}
+ return True
+
+ current_rules_module().check_event_allowed = check
+
+ # now send the event
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {"x": "x"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ event_id = channel.json_body["event_id"]
+
+ # ... and check that it got modified
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "y")
+
+ def test_send_event(self):
+ """Tests that the module can send an event into a room via the module api"""
+ content = {
+ "msgtype": "m.text",
+ "body": "Hello!",
+ }
+ event_dict = {
+ "room_id": self.room_id,
+ "type": "m.room.message",
+ "content": content,
+ "sender": self.user_id,
+ }
+ event = self.get_success(
+ current_rules_module().module_api.create_and_send_event_into_room(
+ event_dict
+ )
+ ) # type: EventBase
+
+ self.assertEquals(event.sender, self.user_id)
+ self.assertEquals(event.room_id, self.room_id)
+ self.assertEquals(event.type, "m.room.message")
+ self.assertEquals(event.content, content)
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
deleted file mode 100644
index 8c24add530..0000000000
--- a/tests/rest/client/third_party_rules.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-
-from tests import unittest
-
-
-class ThirdPartyRulesTestModule:
- def __init__(self, config):
- pass
-
- def check_event_allowed(self, event, context):
- if event.type == "foo.bar.forbidden":
- return False
- else:
- return True
-
- @staticmethod
- def parse_config(config):
- return config
-
-
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
- config["third_party_event_rules"] = {
- "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
- "config": {},
- }
-
- self.hs = self.setup_test_homeserver(config=config)
- return self.hs
-
- def test_third_party_rules(self):
- """Tests that a forbidden event is forbidden from being sent, but an allowed one
- can be sent.
- """
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- room_id = self.helper.create_room_as(user_id, tok=tok)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"200", channel.result)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 633b7dbda0..ea5a7f3739 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -21,6 +21,7 @@ from synapse.types import RoomAlias
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.unittest import override_config
class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -67,10 +68,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.ensure_user_joined_room()
self.set_alias_via_directory(400, alias_length=256)
- def test_state_event_in_room(self):
+ @override_config({"default_room_version": 5})
+ def test_state_event_user_in_v5_room(self):
+ """Test that a regular user can add alias events before room v6"""
self.ensure_user_joined_room()
self.set_alias_via_state_event(200)
+ @override_config({"default_room_version": 6})
+ def test_state_event_v6_room(self):
+ """Test that a regular user can *not* add alias events from room v6"""
+ self.ensure_user_joined_room()
+ self.set_alias_via_state_event(403)
+
def test_directory_in_room(self):
self.ensure_user_joined_room()
self.set_alias_via_directory(200)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index f75520877f..3397ba5579 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -42,7 +42,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
- hs.get_handlers().federation_handler = Mock()
+ hs.get_federation_handler = Mock()
return hs
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0d809d25d5..9ba5f9d943 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -32,6 +32,7 @@ from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.test_utils import make_awaitable
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -47,7 +48,10 @@ class RoomBase(unittest.HomeserverTestCase):
"red", http_client=None, federation_client=Mock(),
)
- self.hs.get_federation_handler = Mock(return_value=Mock())
+ self.hs.get_federation_handler = Mock()
+ self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
+ return_value=make_awaitable(None)
+ )
async def _insert_client_ip(*args, **kwargs):
return None
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 94d2bf2eb1..cd58ee7792 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -44,7 +44,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.event_source = hs.get_event_sources().sources["typing"]
- hs.get_handlers().federation_handler = Mock()
+ hs.get_federation_handler = Mock()
async def get_user_by_access_token(token=None, allow_guest=False):
return {
diff --git a/tests/server.py b/tests/server.py
index b404ad4e2a..4d33b84097 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,8 +1,11 @@
import json
import logging
+from collections import deque
from io import SEEK_END, BytesIO
+from typing import Callable
import attr
+from typing_extensions import Deque
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -251,6 +254,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
+ self._thread_callbacks = deque() # type: Deque[Callable[[], None]]()
@implementer(IResolverSimple)
class FakeResolver:
@@ -272,10 +276,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
Make the callback fire in the next reactor iteration.
"""
- d = Deferred()
- d.addCallback(lambda x: callback(*args, **kwargs))
- self.callLater(0, d.callback, True)
- return d
+ cb = lambda: callback(*args, **kwargs)
+ # it's not safe to call callLater() here, so we append the callback to a
+ # separate queue.
+ self._thread_callbacks.append(cb)
def getThreadPool(self):
return self.threadpool
@@ -303,6 +307,30 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
+ def advance(self, amount):
+ # first advance our reactor's time, and run any "callLater" callbacks that
+ # makes ready
+ super().advance(amount)
+
+ # now run any "callFromThread" callbacks
+ while True:
+ try:
+ callback = self._thread_callbacks.popleft()
+ except IndexError:
+ break
+ callback()
+
+ # check for more "callLater" callbacks added by the thread callback
+ # This isn't required in a regular reactor, but it ends up meaning that
+ # our database queries can complete in a single call to `advance` [1] which
+ # simplifies tests.
+ #
+ # [1]: we replace the threadpool backing the db connection pool with a
+ # mock ThreadPool which doesn't really use threads; but we still use
+ # reactor.callFromThread to feed results back from the db functions to the
+ # main thread.
+ super().advance(0)
+
class ThreadPool:
"""
@@ -339,8 +367,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
"""
server = _sth(cleanup_func, *args, **kwargs)
- database = server.config.database.get_single_database()
-
# Make the thread pool synchronous.
clock = server.get_clock()
@@ -372,6 +398,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
+ # We've just changed the Databases to run DB transactions on the same
+ # thread, so we need to disable the dedicated thread behaviour.
+ server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+
return server
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 46f94914ff..c905a38930 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
# must be done after inserts
database = hs.get_datastores().databases[0]
self.store = ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
def tearDown(self):
@@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
db_config = hs.config.get_single_database()
self.store = TestTransactionStore(
- database, make_conn(db_config, self.engine), hs
+ database, make_conn(db_config, self.engine, "test"), hs
)
def _add_service(self, url, as_token, id):
@@ -448,7 +448,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
@defer.inlineCallbacks
@@ -467,7 +467,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database,
+ make_conn(database._database_config, database.engine, "test"),
+ hs,
)
e = cm.exception
@@ -491,7 +493,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database,
+ make_conn(database._database_config, database.engine, "test"),
+ hs,
)
e = cm.exception
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 392b08832b..cc0612cf65 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ # The first ID gen will notice that it can advance its token to 7 as it
+ # has no in progress writes...
+ self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+ # ... but the second ID gen doesn't know that.
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(
- first_id_gen.get_positions(), {"first": 3, "second": 7}
+ first_id_gen.get_positions(), {"first": 7, "second": 7}
)
self.get_success(_get_next_async())
@@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first", writers=["first", "second"])
+ id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator("first", writers=["first", "second"])
- self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+ self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async():
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 5)
self.get_success(_get_next_async())
@@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("second", 5)
# Initial config has two writers
- id_gen = self._create_id_generator("first", writers=["first", "second"])
+ id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
@@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async2())
- self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 27a7fc9ed7..d39e792580 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,7 +75,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- self.handler = self.homeserver.get_handlers().federation_handler
+ self.handler = self.homeserver.get_federation_handler()
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
context
)
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 7657bddea5..e7aed092c2 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -17,7 +17,7 @@ import resource
import mock
-from synapse.app.homeserver import phone_stats_home
+from synapse.app.phone_stats_home import phone_stats_home
from tests.unittest import HomeserverTestCase
diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..5c87f6097e 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase):
# create a site to wrap the resource.
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
- site_tag="test",
+ site_tag=self.hs.config.server.server_name,
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
@@ -608,7 +608,9 @@ class HomeserverTestCase(TestCase):
if soft_failed:
event.internal_metadata.soft_failed = True
- self.get_success(event_creator.send_nonmember_event(requester, event, context))
+ self.get_success(
+ event_creator.handle_new_client_event(requester, event, context)
+ )
return event.event_id
diff --git a/tests/utils.py b/tests/utils.py
index 4673872f88..0c09f5457f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -38,6 +38,7 @@ from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -88,6 +89,7 @@ def setupdb():
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
+ db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
@@ -190,7 +192,6 @@ class TestHomeServer(HomeServer):
def setup_test_homeserver(
cleanup_func,
name="test",
- datastore=None,
config=None,
reactor=None,
homeserverToUse=TestHomeServer,
@@ -247,7 +248,7 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
- if datastore is None and isinstance(db_engine, PostgresEngine):
+ if isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
@@ -263,79 +264,66 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- if datastore is None:
- hs = homeserverToUse(
- name,
- config=config,
- version_string="Synapse/tests",
- tls_server_context_factory=Mock(),
- tls_client_options_factory=Mock(),
- reactor=reactor,
- **kargs
- )
+ hs = homeserverToUse(
+ name,
+ config=config,
+ version_string="Synapse/tests",
+ tls_server_context_factory=Mock(),
+ tls_client_options_factory=Mock(),
+ reactor=reactor,
+ **kargs
+ )
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
+ hs.setup()
+ if homeserverToUse.__name__ == "TestHomeServer":
+ hs.setup_background_tasks()
- if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
- # We need to do cleanup on PostgreSQL
- def cleanup():
- import psycopg2
+ # We need to do cleanup on PostgreSQL
+ def cleanup():
+ import psycopg2
- # Close all the db pools
- database._db_pool.close()
+ # Close all the db pools
+ database._db_pool.close()
- dropped = False
+ dropped = False
- # Drop the test database
- db_conn = db_engine.module.connect(
- database=POSTGRES_BASE_DB,
- user=POSTGRES_USER,
- host=POSTGRES_HOST,
- password=POSTGRES_PASSWORD,
- )
- db_conn.autocommit = True
- cur = db_conn.cursor()
-
- # Try a few times to drop the DB. Some things may hold on to the
- # database for a few more seconds due to flakiness, preventing
- # us from dropping it when the test is over. If we can't drop
- # it, warn and move on.
- for x in range(5):
- try:
- cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
- db_conn.commit()
- dropped = True
- except psycopg2.OperationalError as e:
- warnings.warn(
- "Couldn't drop old db: " + str(e), category=UserWarning
- )
- time.sleep(0.5)
-
- cur.close()
- db_conn.close()
-
- if not dropped:
- warnings.warn("Failed to drop old DB.", category=UserWarning)
-
- if not LEAVE_DB:
- # Register the cleanup hook
- cleanup_func(cleanup)
+ # Drop the test database
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
- else:
- hs = homeserverToUse(
- name,
- datastore=datastore,
- config=config,
- version_string="Synapse/tests",
- tls_server_context_factory=Mock(),
- tls_client_options_factory=Mock(),
- reactor=reactor,
- **kargs
- )
+ # Try a few times to drop the DB. Some things may hold on to the
+ # database for a few more seconds due to flakiness, preventing
+ # us from dropping it when the test is over. If we can't drop
+ # it, warn and move on.
+ for x in range(5):
+ try:
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ db_conn.commit()
+ dropped = True
+ except psycopg2.OperationalError as e:
+ warnings.warn(
+ "Couldn't drop old db: " + str(e), category=UserWarning
+ )
+ time.sleep(0.5)
+
+ cur.close()
+ db_conn.close()
+
+ if not dropped:
+ warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+ if not LEAVE_DB:
+ # Register the cleanup hook
+ cleanup_func(cleanup)
# bcrypt is far too slow to be doing in unit tests
# Need to let the HS build an auth handler and then mess with it
|