diff --git a/.buildkite/docker-compose-env b/.buildkite/docker-compose-env
new file mode 100644
index 0000000000..85b102d07f
--- /dev/null
+++ b/.buildkite/docker-compose-env
@@ -0,0 +1,12 @@
+CI
+BUILDKITE
+BUILDKITE_BUILD_NUMBER
+BUILDKITE_BRANCH
+BUILDKITE_JOB_ID
+BUILDKITE_BUILD_URL
+BUILDKITE_PROJECT_SLUG
+BUILDKITE_COMMIT
+BUILDKITE_PULL_REQUEST
+BUILDKITE_TAG
+CODECOV_TOKEN
+TRIAL_FLAGS
diff --git a/.buildkite/docker-compose.py35.pg95.yaml b/.buildkite/docker-compose.py35.pg95.yaml
new file mode 100644
index 0000000000..c6e8280e65
--- /dev/null
+++ b/.buildkite/docker-compose.py35.pg95.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:9.5
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
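+    # Disabling fsync trades durability for speed, which is fine for a throwaway CI database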
+ command: -c fsync=off
+
+ testenv:
+ image: python:3.5
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py37.pg11.yaml b/.buildkite/docker-compose.py37.pg11.yaml
new file mode 100644
index 0000000000..411c37f213
--- /dev/null
+++ b/.buildkite/docker-compose.py37.pg11.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:11
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ command: -c fsync=off
+
+ testenv:
+ image: python:3.7
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py37.pg95.yaml b/.buildkite/docker-compose.py37.pg95.yaml
new file mode 100644
index 0000000000..54ca794072
--- /dev/null
+++ b/.buildkite/docker-compose.py37.pg95.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:9.5
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ command: -c fsync=off
+
+ testenv:
+ image: python:3.7
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py38.pg12.yaml b/.buildkite/docker-compose.py38.pg12.yaml
new file mode 100644
index 0000000000..934a34cf02
--- /dev/null
+++ b/.buildkite/docker-compose.py38.pg12.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:12
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ command: -c fsync=off
+
+ testenv:
+ image: python:3.8
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.sytest.py37.redis.yaml b/.buildkite/docker-compose.sytest.py37.redis.yaml
new file mode 100644
index 0000000000..b9e80cc557
--- /dev/null
+++ b/.buildkite/docker-compose.sytest.py37.redis.yaml
@@ -0,0 +1,22 @@
+version: '3.1'
+
+services:
+
+ redis:
+ image: redis:5.0
+
+ sytest:
+ image: matrixdotorg/sytest-synapse:py37
+ depends_on:
+ - redis
+ env_file: docker-compose-env
+ environment:
+ POSTGRES: "1"
+ WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
+ REDIS: "redis"
+ working_dir: "/src"
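+    # Clear the image's default entrypoint so that the command supplied by CI runs directly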
+ entrypoint: ""
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}/logs:/logs
diff --git a/.buildkite/merge_base_branch.sh b/.buildkite/merge_base_branch.sh
index 361440fd1a..d0a7aef8cb 100755
--- a/.buildkite/merge_base_branch.sh
+++ b/.buildkite/merge_base_branch.sh
@@ -12,7 +12,7 @@ if [[ -z $BUILDKITE_PULL_REQUEST_BASE_BRANCH ]]; then
# It probably hasn't had a PR opened yet. Since all PRs land on develop, we
# can probably assume it's based on it and will be merged into it.
- GITBASE="develop"
+ GITBASE="dinsic"
else
# Get the reference, using the GitHub API
GITBASE=$BUILDKITE_PULL_REQUEST_BASE_BRANCH
diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
new file mode 100644
index 0000000000..5877ff0883
--- /dev/null
+++ b/.buildkite/pipeline.yml
@@ -0,0 +1,496 @@
+env:
+ COVERALLS_REPO_TOKEN: wsJWOby6j0uCYFiCes3r0XauxO27mx8lD
+
+steps:
+ - label: "\U0001F9F9 Check Style"
+ command:
+ - "python -m pip install tox"
+ - "tox -e check_codestyle"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: "\U0001F9F9 packaging"
+ command:
+ - "python -m pip install tox"
+ - "tox -e packaging"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: "\U0001F9F9 isort"
+ command:
+ - "python -m pip install tox"
+ - "tox -e check_isort"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: ":newspaper: Newsfile"
+ command:
+ - "python -m pip install tox"
+ - "scripts-dev/check-newsfragment"
+ branches: "!master !develop !release-*"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ propagate-environment: true
+ mount-buildkite-agent: false
+
+ - label: "\U0001F9F9 check-sample-config"
+ command:
+ - "python -m pip install tox"
+ - "tox -e check-sampleconfig"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: ":mypy: mypy"
+ command:
+ - "python -m pip install tox"
+ - "tox -e mypy"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.7"
+ mount-buildkite-agent: false
+
+ - wait
+
+ ################################################################################
+ #
+ # `trial` tests
+ #
+ ################################################################################
+
+ - label: ":python: 3.5 / SQLite / Old Deps"
+ command:
+ - ".buildkite/scripts/test_old_deps.sh"
+ env:
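+      # "-j 2" runs trial with two parallel test worker processes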
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "ubuntu:xenial" # We use xenial to get an old sqlite and python
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.5 / SQLite"
+ command:
+ - "python -m pip install tox"
+ - "tox -e py35,combine"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.5"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.6 / SQLite"
+ command:
+ - "python -m pip install tox"
+ - "tox -e py36,combine"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.7 / SQLite"
+ command:
+ - "python -m pip install tox"
+ - "tox -e py37,combine"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.7"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.5 / :postgres: 9.5"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ command:
+ - "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,combine'"
+ plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py35.pg95.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py35.pg95.yaml
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.7 / :postgres: 11"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ command:
+ - "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,combine'"
+ plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py37.pg11.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py37.pg11.yaml
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.8 / :postgres: 12"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ command:
+ - "bash -c 'python -m pip install tox && python -m tox -e py38-postgres,combine'"
+ plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py38.pg12.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py38.pg12.yaml
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ ################################################################################
+ #
+ # Sytest
+ #
+ ################################################################################
+
+ - label: "SyTest - :python: 3.5 / SQLite / Monolith"
+ agents:
+ queue: "medium"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3.5 / :postgres: 9.6 / Monolith"
+ agents:
+ queue: "medium"
+ env:
+ POSTGRES: "1"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3.5 / :postgres: 9.6 / Workers"
+ agents:
+ queue: "medium"
+ env:
+ MULTI_POSTGRES: "1" # Test with split out databases
+ POSTGRES: "1"
+ WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+ # - matrix-org/coveralls#v1.0:
+ # parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+
+ - label: "SyTest - :python: 3.8 / :postgres: 12 / Monolith"
+ agents:
+ queue: "medium"
+ env:
+ POSTGRES: "1"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3.7 / :postgres: 11 / Workers"
+ agents:
+ queue: "medium"
+ env:
+ MULTI_POSTGRES: "1" # Test with split out databases
+ POSTGRES: "1"
+ WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+ # - matrix-org/coveralls#v1.0:
+ # parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+# TODO: Enable once Synapse v1.13.0 is merged in
+# - label: "SyTest - :python: 3.7 / :postgres: 11 / Workers / :redis: Redis"
+# agents:
+# queue: "medium"
+# command:
+# - bash -c "cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers && ./.buildkite/merge_base_branch.sh && /bootstrap.sh synapse --redis-host redis"
+# plugins:
+# - matrix-org/download#v1.1.0:
+# urls:
+# - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.sytest.py37.redis.yaml
+# - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+# - docker-compose#v2.1.0:
+# run: sytest
+# config:
+# - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.sytest.py37.redis.yaml
+# - artifacts#v1.2.0:
+# upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+# - matrix-org/annotate:
+# path: "logs/annotate.md"
+# style: "error"
+## - matrix-org/coveralls#v1.0:
+## parallel: "true"
+# retry:
+# automatic:
+# - exit_status: -1
+# limit: 2
+# - exit_status: 2
+# limit: 2
+
+ ################################################################################
+ #
+ # synapse_port_db
+ #
+ ################################################################################
+
+ - label: "synapse_port_db / :python: 3.5 / :postgres: 9.5"
+ agents:
+ queue: "medium"
+ command:
+ - "bash .buildkite/scripts/test_synapse_port_db.sh"
+ plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py35.pg95.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py35.pg95.yaml
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+
+ - label: "synapse_port_db / :python: 3.7 / :postgres: 11"
+ agents:
+ queue: "medium"
+ command:
+ - "bash .buildkite/scripts/test_synapse_port_db.sh"
+ plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py37.pg11.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py37.pg11.yaml
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+
+# - wait: ~
+# continue_on_failure: true
+#
+# - label: Trigger webhook
+# command: "curl -k https://coveralls.io/webhook?repo_token=$COVERALLS_REPO_TOKEN -d \"payload[build_num]=$BUILDKITE_BUILD_NUMBER&payload[status]=done\""
diff --git a/CHANGES.md b/CHANGES.md
index d4cc179489..c1b8673c04 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,17 @@
+For the next release
+====================
+
+Removal warning
+---------------
+
+Some older clients used a
+[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken)
+(`:`) in the `client_secret` parameter of various endpoints. The incorrect
+behaviour was allowed for backwards compatibility, but is now being removed
+from Synapse as most users have updated their client. Further context can be
+found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
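+
+For reference, a valid `client_secret` consists only of the characters
+`[0-9a-zA-Z.=_-]`, so secrets containing `:` (or any other character outside
+that set) will be rejected once this compatibility behaviour is removed.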
+
+
Synapse 1.19.0 (2020-08-17)
===========================
diff --git a/MANIFEST.in b/MANIFEST.in
index 120ce5b776..0a9cf4f51c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,5 @@
include synctl
+include sytest-blacklist
include LICENSE
include VERSION
include *.rst
@@ -51,3 +52,11 @@ prune demo/etc
prune docker
prune snap
prune stubs
+
+exclude jenkins*
+recursive-exclude jenkins *.sh
+
+# FIXME: we shouldn't have these templates here
+recursive-include res/templates-dinsic *.css
+recursive-include res/templates-dinsic *.html
+recursive-include res/templates-dinsic *.txt
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 6492fa011f..b2069a0d26 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -75,6 +75,70 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.21.0
+====================
+
+Forwarding ``/_synapse/client`` through your reverse proxy
+----------------------------------------------------------
+
+The `reverse proxy documentation
+<https://github.com/matrix-org/synapse/blob/develop/docs/reverse_proxy.md>`_ has been updated
+to include reverse proxy directives for ``/_synapse/client/*`` endpoints. As the user password
+reset flow now uses endpoints under this prefix, **you must update your reverse proxy
+configurations for user password reset to work**.
+
+Additionally, note that the `Synapse worker documentation
+<https://github.com/matrix-org/synapse/blob/develop/docs/workers.md>`_ has been updated to
+state that the ``/_synapse/client/password_reset/email/submit_token`` endpoint can be handled
+by all workers. If you make use of Synapse's worker feature, please update your reverse proxy
+configuration to reflect this change.
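+
+With Nginx, for example, this can be as simple as extending the existing
+``location`` block to match the new prefix (a sketch only; see the reverse
+proxy documentation for the authoritative configuration)::
+
+    location ~ ^(/_matrix|/_synapse/client) {
+        proxy_pass http://localhost:8008;
+        proxy_set_header X-Forwarded-For $remote_addr;
+    }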
+
+New HTML templates
+------------------
+
+A new HTML template,
+`password_reset_confirmation.html <https://github.com/matrix-org/synapse/blob/develop/synapse/res/templates/password_reset_confirmation.html>`_,
+has been added to the ``synapse/res/templates`` directory. If you are using a
+custom template directory, you may want to copy the template over and modify it.
+
+Note that as of v1.20.0, templates do not need to be included in custom template
+directories for Synapse to start. The default templates will be used if a custom
+template cannot be found.
+
+This page will appear to the user after clicking a password reset link that has
+been emailed to them.
+
+To complete password reset, the page must include a way to make a ``POST``
+request to ``/_synapse/client/password_reset/{medium}/submit_token``
+with the query parameters from the original link, presented as a URL-encoded
+form. See the file itself for more details.
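+
+For illustration only, a minimal form might look like the following (a
+sketch: the hidden inputs must echo the query parameters carried by the
+link, typically ``token``, ``client_secret`` and ``sid``)::
+
+    <form method="post"
+          action="/_synapse/client/password_reset/email/submit_token">
+        <input type="hidden" name="token" value="{{ token }}">
+        <input type="hidden" name="client_secret" value="{{ client_secret }}">
+        <input type="hidden" name="sid" value="{{ sid }}">
+        <input type="submit" value="Confirm changing my password">
+    </form>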
+
+Updated Single Sign-on HTML Templates
+-------------------------------------
+
+The ``saml_error.html`` template was removed from Synapse and replaced with the
+``sso_error.html`` template. If your Synapse is configured to use SAML and a
+custom ``sso_redirect_confirm_template_dir`` configuration then any customisations
+of the ``saml_error.html`` template will need to be merged into the ``sso_error.html``
+template. These templates are similar, but the parameters are slightly different:
+
+* The ``msg`` parameter should be renamed to ``error_description``.
+* There is no longer a ``code`` parameter for the response code.
+* A string ``error`` parameter is available that includes a short hint of why a
+ user is seeing the error page.
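+
+For example, where a custom ``saml_error.html`` previously rendered
+``{{ msg }}``, the equivalent ``sso_error.html`` should render
+``{{ error_description }}`` instead.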
+
+ThirdPartyEventRules breaking changes
+-------------------------------------
+
+This release introduces a backwards-incompatible change to modules making use
+of ``ThirdPartyEventRules`` in Synapse.
+
+The ``http_client`` argument is no longer passed to modules as they are
+initialised. Instead, modules are expected to make use of the ``http_client``
+property on the ``ModuleApi`` class. Modules are now passed a ``module_api``
+argument during initialisation, which is an instance of ``ModuleApi``.
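+
+For illustration only, a module that previously accepted an ``http_client``
+argument might now be initialised along these lines (a sketch, using a
+hypothetical module name)::
+
+    class ExampleRulesModule:
+        def __init__(self, config, module_api):
+            self.config = config
+            # The HTTP client is now reached through the ModuleApi instance
+            self.http_client = module_api.http_client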
+
Upgrading to v1.18.0
====================
diff --git a/changelog.d/1.feature b/changelog.d/1.feature
new file mode 100644
index 0000000000..845642e445
--- /dev/null
+++ b/changelog.d/1.feature
@@ -0,0 +1 @@
+Forbid changing the name, avatar or topic of a direct room.
diff --git a/changelog.d/10.bugfix b/changelog.d/10.bugfix
new file mode 100644
index 0000000000..51f89f46dd
--- /dev/null
+++ b/changelog.d/10.bugfix
@@ -0,0 +1 @@
+Don't apply retention policy based filtering on state events.
diff --git a/changelog.d/11.feature b/changelog.d/11.feature
new file mode 100644
index 0000000000..362e4b1efd
--- /dev/null
+++ b/changelog.d/11.feature
@@ -0,0 +1 @@
+Allow server admins to configure a custom global rate-limiting for third party invites.
\ No newline at end of file
diff --git a/changelog.d/12.feature b/changelog.d/12.feature
new file mode 100644
index 0000000000..8e6e7a28af
--- /dev/null
+++ b/changelog.d/12.feature
@@ -0,0 +1 @@
+Add `/user/:user_id/info` CS servlet to give user deactivated/expired information.
\ No newline at end of file
diff --git a/changelog.d/13.feature b/changelog.d/13.feature
new file mode 100644
index 0000000000..c2d2e93abf
--- /dev/null
+++ b/changelog.d/13.feature
@@ -0,0 +1 @@
+Hide expired users from the user directory, and optionally re-add them on renewal.
\ No newline at end of file
diff --git a/changelog.d/14.feature b/changelog.d/14.feature
new file mode 100644
index 0000000000..020d0bac1e
--- /dev/null
+++ b/changelog.d/14.feature
@@ -0,0 +1 @@
+User display names now have capitalised letters after `-` symbols.
\ No newline at end of file
diff --git a/changelog.d/15.misc b/changelog.d/15.misc
new file mode 100644
index 0000000000..4cc4a5175f
--- /dev/null
+++ b/changelog.d/15.misc
@@ -0,0 +1 @@
+Fix the ordering on `scripts/generate_signing_key.py`'s import statement.
diff --git a/changelog.d/17.misc b/changelog.d/17.misc
new file mode 100644
index 0000000000..58120ab5c7
--- /dev/null
+++ b/changelog.d/17.misc
@@ -0,0 +1 @@
+Blacklist some flaky sytests until they're fixed.
\ No newline at end of file
diff --git a/changelog.d/18.feature b/changelog.d/18.feature
new file mode 100644
index 0000000000..f5aa29a6e8
--- /dev/null
+++ b/changelog.d/18.feature
@@ -0,0 +1 @@
+Add option `limit_profile_requests_to_known_users` to control whether a user must share a room with another user in order to query their profile information.
\ No newline at end of file
diff --git a/changelog.d/19.feature b/changelog.d/19.feature
new file mode 100644
index 0000000000..95a44a4a89
--- /dev/null
+++ b/changelog.d/19.feature
@@ -0,0 +1 @@
+Add `max_avatar_size` and `allowed_avatar_mimetypes` to restrict the size of user avatars and their file type respectively.
\ No newline at end of file
diff --git a/changelog.d/2.bugfix b/changelog.d/2.bugfix
new file mode 100644
index 0000000000..4fe5691468
--- /dev/null
+++ b/changelog.d/2.bugfix
@@ -0,0 +1 @@
+Don't treat 3PID revocation as a new 3PID invite.
diff --git a/changelog.d/20.bugfix b/changelog.d/20.bugfix
new file mode 100644
index 0000000000..8ba53c28f9
--- /dev/null
+++ b/changelog.d/20.bugfix
@@ -0,0 +1 @@
+Validate `client_secret` parameter against the regex provided by the C-S spec.
\ No newline at end of file
diff --git a/changelog.d/21.bugfix b/changelog.d/21.bugfix
new file mode 100644
index 0000000000..630d7812f7
--- /dev/null
+++ b/changelog.d/21.bugfix
@@ -0,0 +1 @@
+Fix resetting user passwords via a phone number.
diff --git a/changelog.d/28.bugfix b/changelog.d/28.bugfix
new file mode 100644
index 0000000000..38d7455971
--- /dev/null
+++ b/changelog.d/28.bugfix
@@ -0,0 +1 @@
+Fix a bug causing account validity renewal emails to be sent even if the feature is turned off in some cases.
diff --git a/changelog.d/29.misc b/changelog.d/29.misc
new file mode 100644
index 0000000000..720e0ddcfb
--- /dev/null
+++ b/changelog.d/29.misc
@@ -0,0 +1 @@
+Improve performance when making `.well-known` requests by sharing the SSL options between requests.
diff --git a/changelog.d/3.bugfix b/changelog.d/3.bugfix
new file mode 100644
index 0000000000..cc4bcefa80
--- /dev/null
+++ b/changelog.d/3.bugfix
@@ -0,0 +1 @@
+Fix encoding on password reset HTML responses in Python 2.
diff --git a/changelog.d/30.misc b/changelog.d/30.misc
new file mode 100644
index 0000000000..ae68554be3
--- /dev/null
+++ b/changelog.d/30.misc
@@ -0,0 +1 @@
+Improve performance when making HTTP requests to sygnal, sydent, etc., by sharing the SSL context object between connections.
diff --git a/changelog.d/32.bugfix b/changelog.d/32.bugfix
new file mode 100644
index 0000000000..b6e7b90710
--- /dev/null
+++ b/changelog.d/32.bugfix
@@ -0,0 +1 @@
+Fix a bug when using the default display name during registration.
diff --git a/changelog.d/39.feature b/changelog.d/39.feature
new file mode 100644
index 0000000000..426b7ef27e
--- /dev/null
+++ b/changelog.d/39.feature
@@ -0,0 +1 @@
+Merge Synapse v1.12.4 `master` into the `dinsic` branch.
\ No newline at end of file
diff --git a/changelog.d/4.bugfix b/changelog.d/4.bugfix
new file mode 100644
index 0000000000..fe717920a6
--- /dev/null
+++ b/changelog.d/4.bugfix
@@ -0,0 +1 @@
+Fix handling of filtered strings in Python 3.
diff --git a/changelog.d/45.feature b/changelog.d/45.feature
new file mode 100644
index 0000000000..d45ac34ac1
--- /dev/null
+++ b/changelog.d/45.feature
@@ -0,0 +1 @@
+Merge Synapse mainline releases v1.13.0 through v1.14.0 into the `dinsic` branch.
\ No newline at end of file
diff --git a/changelog.d/46.feature b/changelog.d/46.feature
new file mode 100644
index 0000000000..7872d956e3
--- /dev/null
+++ b/changelog.d/46.feature
@@ -0,0 +1 @@
+Add a bulk version of the User Info API. Deprecate the single-use version.
\ No newline at end of file
diff --git a/changelog.d/47.misc b/changelog.d/47.misc
new file mode 100644
index 0000000000..1d6596d788
--- /dev/null
+++ b/changelog.d/47.misc
@@ -0,0 +1 @@
+Improve performance of `mark_expired_users_as_inactive` background job.
\ No newline at end of file
diff --git a/changelog.d/48.feature b/changelog.d/48.feature
new file mode 100644
index 0000000000..b7939f3f51
--- /dev/null
+++ b/changelog.d/48.feature
@@ -0,0 +1 @@
+Prevent `/register` from raising `M_USER_IN_USE` until UI Auth has been completed. Have `/register/available` always return true.
diff --git a/changelog.d/5.bugfix b/changelog.d/5.bugfix
new file mode 100644
index 0000000000..53f57f46ca
--- /dev/null
+++ b/changelog.d/5.bugfix
@@ -0,0 +1 @@
+Fix room retention policy management in worker mode.
diff --git a/changelog.d/50.feature b/changelog.d/50.feature
new file mode 100644
index 0000000000..0801622c8a
--- /dev/null
+++ b/changelog.d/50.feature
@@ -0,0 +1 @@
+Merge Synapse mainline v1.15.1 into the `dinsic` branch.
\ No newline at end of file
diff --git a/changelog.d/5083.feature b/changelog.d/5083.feature
new file mode 100644
index 0000000000..2ffdd37eef
--- /dev/null
+++ b/changelog.d/5083.feature
@@ -0,0 +1 @@
+Add `auth_profile_reqs` option to require an access token for `GET /profile` endpoints on the C-S API.
diff --git a/changelog.d/5098.misc b/changelog.d/5098.misc
new file mode 100644
index 0000000000..9cd83bf226
--- /dev/null
+++ b/changelog.d/5098.misc
@@ -0,0 +1 @@
+Add workarounds for pep-517 install errors.
diff --git a/changelog.d/51.feature b/changelog.d/51.feature
new file mode 100644
index 0000000000..e5c9990ad6
--- /dev/null
+++ b/changelog.d/51.feature
@@ -0,0 +1 @@
+Add `bind_new_user_emails_to_sydent` option for automatically binding users' emails after registration.
diff --git a/changelog.d/5214.feature b/changelog.d/5214.feature
new file mode 100644
index 0000000000..6c0f15c901
--- /dev/null
+++ b/changelog.d/5214.feature
@@ -0,0 +1 @@
+Allow server admins to define and enforce a password policy (MSC2000).
diff --git a/changelog.d/53.feature b/changelog.d/53.feature
new file mode 100644
index 0000000000..96c628e824
--- /dev/null
+++ b/changelog.d/53.feature
@@ -0,0 +1 @@
+Merge mainline Synapse v1.18.0 into the `dinsic` branch.
\ No newline at end of file
diff --git a/changelog.d/5416.misc b/changelog.d/5416.misc
new file mode 100644
index 0000000000..155e8c7cd3
--- /dev/null
+++ b/changelog.d/5416.misc
@@ -0,0 +1 @@
+Add unique index to the profile_replication_status table.
diff --git a/changelog.d/5420.feature b/changelog.d/5420.feature
new file mode 100644
index 0000000000..745864b903
--- /dev/null
+++ b/changelog.d/5420.feature
@@ -0,0 +1 @@
+Add configuration option to hide new users from the user directory.
diff --git a/changelog.d/56.misc b/changelog.d/56.misc
new file mode 100644
index 0000000000..f66c55af21
--- /dev/null
+++ b/changelog.d/56.misc
@@ -0,0 +1 @@
+Temporarily revert commit a3fbc23.
diff --git a/changelog.d/5610.feature b/changelog.d/5610.feature
new file mode 100644
index 0000000000..b99514f97e
--- /dev/null
+++ b/changelog.d/5610.feature
@@ -0,0 +1 @@
+Implement new custom event rules for power levels.
diff --git a/changelog.d/57.misc b/changelog.d/57.misc
new file mode 100644
index 0000000000..1bbe8611cd
--- /dev/null
+++ b/changelog.d/57.misc
@@ -0,0 +1 @@
+Add `user_id` back to presence in worker mode too (https://github.com/matrix-org/synapse/commit/0bbbd10513008d30c17eb1d1e7ba1d091fb44ec7).
diff --git a/changelog.d/5702.bugfix b/changelog.d/5702.bugfix
new file mode 100644
index 0000000000..43b6e39b13
--- /dev/null
+++ b/changelog.d/5702.bugfix
@@ -0,0 +1 @@
+Fix 3PID invite to invite association detection in the Tchap room access rules.
diff --git a/changelog.d/5760.feature b/changelog.d/5760.feature
new file mode 100644
index 0000000000..90302d793e
--- /dev/null
+++ b/changelog.d/5760.feature
@@ -0,0 +1 @@
+Force the access rule to be "restricted" if the join rule is "public".
diff --git a/changelog.d/58.misc b/changelog.d/58.misc
new file mode 100644
index 0000000000..64098a68a4
--- /dev/null
+++ b/changelog.d/58.misc
@@ -0,0 +1 @@
+Don't push if a user account has expired.
diff --git a/changelog.d/59.feature b/changelog.d/59.feature
new file mode 100644
index 0000000000..aa07f762d1
--- /dev/null
+++ b/changelog.d/59.feature
@@ -0,0 +1 @@
+Freeze a room when the last administrator in the room leaves.
\ No newline at end of file
diff --git a/changelog.d/6.bugfix b/changelog.d/6.bugfix
new file mode 100644
index 0000000000..43ab65cc95
--- /dev/null
+++ b/changelog.d/6.bugfix
@@ -0,0 +1 @@
+Don't forbid membership events whose membership isn't 'join' or 'invite' in restricted rooms, so that users who got into these rooms before the access rules started to be enforced can leave them.
diff --git a/changelog.d/60.misc b/changelog.d/60.misc
new file mode 100644
index 0000000000..d2625a4f65
--- /dev/null
+++ b/changelog.d/60.misc
@@ -0,0 +1 @@
+Make all rooms noisy by default.
diff --git a/changelog.d/61.misc b/changelog.d/61.misc
new file mode 100644
index 0000000000..0c3ba98628
--- /dev/null
+++ b/changelog.d/61.misc
@@ -0,0 +1 @@
+Change the minimum power levels for invites and other state events in new rooms.
\ No newline at end of file
diff --git a/changelog.d/62.misc b/changelog.d/62.misc
new file mode 100644
index 0000000000..1e26456595
--- /dev/null
+++ b/changelog.d/62.misc
@@ -0,0 +1 @@
+Type hinting and other cleanups for `synapse.third_party_rules.access_rules`.
\ No newline at end of file
diff --git a/changelog.d/63.feature b/changelog.d/63.feature
new file mode 100644
index 0000000000..b45f38fa94
--- /dev/null
+++ b/changelog.d/63.feature
@@ -0,0 +1 @@
+Make AccessRules use the public rooms directory instead of checking a room's join rules on rule change.
diff --git a/changelog.d/64.bugfix b/changelog.d/64.bugfix
new file mode 100644
index 0000000000..60c077af94
--- /dev/null
+++ b/changelog.d/64.bugfix
@@ -0,0 +1 @@
+Ensure a `RoomAccessRules` test doesn't accidentally modify a room's access rule and then test that room assuming its access rule has not changed.
diff --git a/changelog.d/65.bugfix b/changelog.d/65.bugfix
new file mode 100644
index 0000000000..71b498cbc8
--- /dev/null
+++ b/changelog.d/65.bugfix
@@ -0,0 +1 @@
+Fix `nextLink` parameters being checked on validation endpoints even if they weren't provided by the client.
\ No newline at end of file
diff --git a/changelog.d/66.bugfix b/changelog.d/66.bugfix
new file mode 100644
index 0000000000..9547cfeddd
--- /dev/null
+++ b/changelog.d/66.bugfix
@@ -0,0 +1 @@
+Create a mapping between user ID and threepid when binding via the internal Sydent bind API.
\ No newline at end of file
diff --git a/changelog.d/7864.bugfix b/changelog.d/7864.bugfix
new file mode 100644
index 0000000000..8623355fe9
--- /dev/null
+++ b/changelog.d/7864.bugfix
@@ -0,0 +1 @@
+Fix a memory leak by limiting the length of time that messages will be queued for a remote server that has been unreachable.
diff --git a/changelog.d/8013.feature b/changelog.d/8013.feature
new file mode 100644
index 0000000000..b1eaf1e78a
--- /dev/null
+++ b/changelog.d/8013.feature
@@ -0,0 +1 @@
+Iteratively encode JSON to avoid blocking the reactor.
diff --git a/changelog.d/8037.feature b/changelog.d/8037.feature
new file mode 100644
index 0000000000..2e5127477d
--- /dev/null
+++ b/changelog.d/8037.feature
@@ -0,0 +1 @@
+Use the default template file when its equivalent is not found in a custom template directory.
\ No newline at end of file
diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8071.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8072.misc b/changelog.d/8072.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8072.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8074.misc b/changelog.d/8074.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8074.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8075.misc b/changelog.d/8075.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8075.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8076.misc b/changelog.d/8076.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8076.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8081.bugfix b/changelog.d/8081.bugfix
new file mode 100644
index 0000000000..9ebcbf5b84
--- /dev/null
+++ b/changelog.d/8081.bugfix
@@ -0,0 +1 @@
+Fix `Re-starting finished log context PUT-nnnn` warning when event persistence failed.
diff --git a/changelog.d/8085.misc b/changelog.d/8085.misc
new file mode 100644
index 0000000000..c3da1e297c
--- /dev/null
+++ b/changelog.d/8085.misc
@@ -0,0 +1 @@
+Remove some unused database functions.
diff --git a/changelog.d/8087.misc b/changelog.d/8087.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8087.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8090.misc b/changelog.d/8090.misc
new file mode 100644
index 0000000000..725a03ae88
--- /dev/null
+++ b/changelog.d/8090.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.handlers.room`.
diff --git a/changelog.d/8092.feature b/changelog.d/8092.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8092.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/changelog.d/8093.misc b/changelog.d/8093.misc
new file mode 100644
index 0000000000..80045dde1a
--- /dev/null
+++ b/changelog.d/8093.misc
@@ -0,0 +1 @@
+Return the previous stream token if a non-member event is a duplicate.
diff --git a/changelog.d/8100.misc b/changelog.d/8100.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8100.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8101.bugfix b/changelog.d/8101.bugfix
new file mode 100644
index 0000000000..703bba4234
--- /dev/null
+++ b/changelog.d/8101.bugfix
@@ -0,0 +1 @@
+Synapse now correctly enforces the valid characters in the `client_secret` parameter used in various endpoints.
diff --git a/changelog.d/8106.bugfix b/changelog.d/8106.bugfix
new file mode 100644
index 0000000000..c46c60448f
--- /dev/null
+++ b/changelog.d/8106.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where invalid JSON would be accepted by Synapse.
diff --git a/changelog.d/8107.feature b/changelog.d/8107.feature
new file mode 100644
index 0000000000..2e5127477d
--- /dev/null
+++ b/changelog.d/8107.feature
@@ -0,0 +1 @@
+Use the default template file when its equivalent is not found in a custom template directory.
\ No newline at end of file
diff --git a/changelog.d/8111.doc b/changelog.d/8111.doc
new file mode 100644
index 0000000000..d3f7435452
--- /dev/null
+++ b/changelog.d/8111.doc
@@ -0,0 +1 @@
+Link to matrix-synapse-rest-password-provider in the password provider documentation.
diff --git a/changelog.d/8112.misc b/changelog.d/8112.misc
new file mode 100644
index 0000000000..80045dde1a
--- /dev/null
+++ b/changelog.d/8112.misc
@@ -0,0 +1 @@
+Return the previous stream token if a non-member event is a duplicate.
diff --git a/changelog.d/8113.misc b/changelog.d/8113.misc
new file mode 100644
index 0000000000..00bec4f8ef
--- /dev/null
+++ b/changelog.d/8113.misc
@@ -0,0 +1 @@
+Separate `get_current_token` into two since there are two different use cases for it.
diff --git a/changelog.d/8116.feature b/changelog.d/8116.feature
new file mode 100644
index 0000000000..b1eaf1e78a
--- /dev/null
+++ b/changelog.d/8116.feature
@@ -0,0 +1 @@
+Iteratively encode JSON to avoid blocking the reactor.
diff --git a/changelog.d/8119.misc b/changelog.d/8119.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8119.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8120.doc b/changelog.d/8120.doc
new file mode 100644
index 0000000000..877ef79fd2
--- /dev/null
+++ b/changelog.d/8120.doc
@@ -0,0 +1 @@
+Updated documentation to note that Synapse does not follow `HTTP 308` redirects due to an upstream library not supporting them. Contributed by Ryan Cole.
\ No newline at end of file
diff --git a/changelog.d/8121.misc b/changelog.d/8121.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8121.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8123.misc b/changelog.d/8123.misc
new file mode 100644
index 0000000000..7245122896
--- /dev/null
+++ b/changelog.d/8123.misc
@@ -0,0 +1 @@
+Remove `ChainedIdGenerator`.
diff --git a/changelog.d/8129.bugfix b/changelog.d/8129.bugfix
new file mode 100644
index 0000000000..79eae9db6b
--- /dev/null
+++ b/changelog.d/8129.bugfix
@@ -0,0 +1 @@
+Return a proper error code when the rooms of an invalid group are requested.
diff --git a/changelog.d/8131.bugfix b/changelog.d/8131.bugfix
new file mode 100644
index 0000000000..5110f235d1
--- /dev/null
+++ b/changelog.d/8131.bugfix
@@ -0,0 +1 @@
+Fix a bug which could cause a leaked postgres connection if synapse was set to daemonize.
diff --git a/changelog.d/8133.misc b/changelog.d/8133.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8133.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8275.feature b/changelog.d/8275.feature
new file mode 100644
index 0000000000..17549c3df3
--- /dev/null
+++ b/changelog.d/8275.feature
@@ -0,0 +1 @@
+Add a config option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number.
\ No newline at end of file
diff --git a/changelog.d/8479.feature b/changelog.d/8479.feature
new file mode 100644
index 0000000000..11adeec8a9
--- /dev/null
+++ b/changelog.d/8479.feature
@@ -0,0 +1 @@
+Add the ability to send non-membership events into a room via the `ModuleApi`.
\ No newline at end of file
diff --git a/changelog.d/9.misc b/changelog.d/9.misc
new file mode 100644
index 0000000000..24fd12c978
--- /dev/null
+++ b/changelog.d/9.misc
@@ -0,0 +1 @@
+Add SyTest to the BuildKite CI.
diff --git a/contrib/systemd/README.md b/contrib/systemd/README.md
deleted file mode 100644
index 5d42b3464f..0000000000
--- a/contrib/systemd/README.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# Setup Synapse with Systemd
-This is a setup for managing synapse with a user contributed systemd unit
-file. It provides a `matrix-synapse` systemd unit file that should be tailored
-to accommodate your installation in accordance with the installation
-instructions provided in [installation instructions](../../INSTALL.md).
-
-## Setup
-1. Under the service section, ensure the `User` variable matches which user
-you installed synapse under and wish to run it as.
-2. Under the service section, ensure the `WorkingDirectory` variable matches
-where you have installed synapse.
-3. Under the service section, ensure the `ExecStart` variable matches the
-appropriate locations of your installation.
-4. Copy the `matrix-synapse.service` to `/etc/systemd/system/`
-5. Start Synapse: `sudo systemctl start matrix-synapse`
-6. Verify Synapse is running: `sudo systemctl status matrix-synapse`
-7. *optional* Enable Synapse to start at system boot: `sudo systemctl enable matrix-synapse`
diff --git a/docs/federate.md b/docs/federate.md
index a0786b9cf7..b15cd724d1 100644
--- a/docs/federate.md
+++ b/docs/federate.md
@@ -47,6 +47,18 @@ you invite them to. This can be caused by an incorrectly-configured reverse
proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly
configure a reverse proxy.
+### Known issues
+
+**HTTP `308 Permanent Redirect` redirects are not followed**: Due to missing features
+in the HTTP library used by Synapse, 308 redirects are currently not followed by
+federating servers, which can cause `M_UNKNOWN` or `401 Unauthorized` errors. This
+may affect users who are redirecting apex-to-www (e.g. `example.com` -> `www.example.com`),
+and especially users of the Kubernetes *Nginx Ingress* module, which uses 308 redirect
+codes by default. For those Kubernetes users, [this Stackoverflow post](https://stackoverflow.com/a/52617528/5096871)
+might be helpful. For other users, switching to a `301 Moved Permanently` code may be
+an option. 308 redirect codes will be supported properly in a future
+release of Synapse.
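+
+As a sketch (assuming an Nginx apex-to-www redirect; adapt to your own
+setup), switching the redirect code from 308 to 301 looks like this:
+
+```
+server {
+    listen 443 ssl;
+    server_name example.com;
+    # 301 rather than 308, so that federating homeservers follow the redirect
+    return 301 https://www.example.com$request_uri;
+}
+```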
+
## Running a demo federation of Synapses
If you want to get up and running quickly with a trio of homeservers in a
diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md
index fef1d47e85..7d98d9f255 100644
--- a/docs/password_auth_providers.md
+++ b/docs/password_auth_providers.md
@@ -14,6 +14,7 @@ password auth provider module implementations:
* [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/)
* [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth)
+* [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider)
## Required methods
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 9235b89fb1..13376c8a42 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -340,6 +340,74 @@ limit_remote_rooms:
#
#allow_per_room_profiles: false
+# Whether to show the users on this homeserver in the user directory. Defaults to
+# 'true'.
+#
+#show_users_in_user_directory: false
+
+# Message retention policy at the server level.
+#
+# Room admins and mods can define a retention period for their rooms using the
+# 'm.room.retention' state event, and server admins can cap this period by setting
+# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+#
+# If this feature is enabled, Synapse will regularly look for and purge events
+# which are older than the room's maximum retention period. Synapse will also
+# filter events received over federation so that events that should have been
+# purged are ignored and not stored again.
+#
+retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+  # Server admins can define the settings of the background jobs purging the
+  # events whose lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+  # whose 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+  # of outdated messages on a very frequent basis (e.g. every 5min), but where
+  # you don't want that purge to be performed by a job that iterates over every
+  # room it knows, which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+  #    interval: 5m
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
@@ -426,6 +494,24 @@ retention:
#
#request_token_inhibit_3pid_errors: true
+# A list of domains that the domain portion of 'next_link' parameters
+# must match.
+#
+# This parameter is optionally provided by clients while requesting
+# validation of an email or phone number, and maps to a link that
+# users will be automatically redirected to after validation
+# succeeds. Clients can make use of this parameter to aid the validation
+# process.
+#
+# The whitelist is applied whether the homeserver or an
+# identity server is handling validation.
+#
+# The default value is no whitelist functionality; all domains are
+# allowed. Setting this value to an empty list will instead disallow
+# all domains.
+#
+#next_link_domain_whitelist: ["matrix.org"]
+
## TLS ##
@@ -743,6 +829,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+# - one that ratelimits third-party invite requests based on the account
+# that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -772,6 +860,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# per_second: 0.17
# burst_count: 3
#
+#rc_third_party_invite:
+# per_second: 0.2
+# burst_count: 10
+#
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
@@ -846,6 +938,30 @@ media_store_path: "DATADIR/media_store"
#
#max_upload_size: 10M
+# The largest allowed size for a user avatar. If not defined, no
+# restriction will be imposed.
+#
+# Note that this only applies when an avatar is changed globally.
+# Per-room avatar changes are not affected. See allow_per_room_profiles
+# for disabling that functionality.
+#
+# Note that user avatar changes will not work if this is set without
+# using Synapse's local media repo.
+#
+#max_avatar_size: 10M
+
+# Allowed mimetypes for a user avatar. If not defined, no restriction will
+# be imposed.
+#
+# Note that this only applies when an avatar is changed globally.
+# Per-room avatar changes are not affected. See allow_per_room_profiles
+# for disabling that functionality.
+#
+# Note that user avatar changes will not work if this is set without
+# using Synapse's local media repo.
+#
+#allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
@@ -1130,9 +1246,32 @@ account_validity:
#
#disable_msisdn_registration: true
+# Derive the user's matrix ID from a type of 3PID used when registering.
+# This overrides any matrix ID the user proposes when calling /register.
+# The 3PID type should be present in registrations_require_3pid to avoid
+# users failing to register if they don't specify the right kind of 3pid.
+#
+#register_mxid_from_3pid: email
+
+# Uncomment to set the display name of new users to their email address,
+# rather than using the default heuristic.
+#
+#register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+# Use an identity server to establish which 3PIDs are allowed to register.
+# This overrides allowed_local_3pids below.
+#
+#check_is_for_allowed_local_3pids: matrix.org
+#
+# If you are using an IS you can also check whether that IS registers
+# pending invites for the given 3PID (and then allow them to sign up on
+# the platform):
+#
+#allow_invited_3pids: false
+#
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\.org'
@@ -1141,6 +1280,11 @@ account_validity:
# - medium: msisdn
# pattern: '\+44'
+# If true, stop users from trying to change the 3PIDs associated with
+# their accounts.
+#
+#disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -1172,6 +1316,30 @@ account_validity:
#
#default_identity_server: https://matrix.org
+# If enabled, user IDs, display names and avatar URLs will be replicated
+# to this server whenever they change.
+# This is an experimental API currently implemented by sydent to support
+# cross-homeserver user directories.
+#
+#replicate_user_profiles_to: example.com
+
+# If specified, attempt to replay registrations, profile changes & 3pid
+# bindings on the given target homeserver via the AS API. The HS is authed
+# via a given AS token.
+#
+#shadow_server:
+# hs_url: https://shadow.example.com
+# hs: shadow.example.com
+# as_token: 12u394refgbdhivsia
+
+# If enabled, don't let users set their own display names/avatars
+# other than for the very first time (unless they are a server admin).
+# Useful when provisioning users based on the contents of a 3rd party
+# directory and to avoid ambiguities.
+#
+#disable_set_displayname: false
+#disable_set_avatar_url: false
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -1298,6 +1466,31 @@ account_threepid_delegates:
#
#auto_join_rooms_for_guests: false
+# Rewrite identity server URLs with a map from one URL to another. Applies to URLs
+# provided by clients (which have https:// prepended) and those specified
+# in `account_threepid_delegates`. URLs should not feature a trailing slash.
+#
+#rewrite_identity_server_urls:
+# "https://somewhere.example.com": "https://somewhereelse.example.com"
+
+# When a user registers an account with an email address, it can be useful to
+# bind that email address to their mxid on an identity server. Typically, this
+# requires the user to validate their email address with the identity server.
+# However if Synapse itself is handling email validation on registration, the
+# user ends up needing to validate their email twice, which leads to poor UX.
+#
+# It is possible to force Sydent, one identity server implementation, to bind
+# threepids using its internal, unauthenticated bind API:
+# https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api
+#
+# Configure the address of a Sydent server here to have Synapse attempt
+# to automatically bind users' emails following registration. The
+# internal bind API must be reachable from Synapse, but should NOT be
+# exposed to any third party, as it allows the creation of bindings
+# without validation.
+#
+#bind_new_user_emails_to_sydent: https://example.com:8091
+
## Metrics ###
@@ -2002,9 +2195,7 @@ email:
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
- # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
- # If you *do* uncomment it, you will need to make sure that all the templates
- # below are in the directory.
+ # Do not uncomment this setting unless you want to customise the templates.
#
# Synapse will look for the following templates in this directory:
#
@@ -2213,6 +2404,11 @@ spam_checker:
#user_directory:
# enabled: true
# search_all_users: false
+#
+# # If this is set, user search will be delegated to this ID server instead
+# # of synapse performing the search itself.
+# # This is an experimental API.
+# defer_to_id_server: https://id.example.com
# User Consent configuration
diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py
index ca4b879526..5c5a115ca9 100644
--- a/docs/sphinx/conf.py
+++ b/docs/sphinx/conf.py
@@ -12,8 +12,8 @@
# All configuration values have a default; values that are commented out
# serve to show the default.
-import sys
import os
+import sys
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
@@ -191,11 +191,11 @@ htmlhelp_basename = "Synapsedoc"
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
- #'papersize': 'letterpaper',
+ # 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
- #'pointsize': '10pt',
+ # 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
- #'preamble': '',
+ # 'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
diff --git a/res/templates-dinsic/mail-Vector.css b/res/templates-dinsic/mail-Vector.css
new file mode 100644
index 0000000000..6a3e36eda1
--- /dev/null
+++ b/res/templates-dinsic/mail-Vector.css
@@ -0,0 +1,7 @@
+.header {
+    border-bottom: 4px solid #e4f7ed !important;
+}
+
+.notif_link a, .footer a {
+    color: #76CFA6 !important;
+}
diff --git a/res/templates-dinsic/mail.css b/res/templates-dinsic/mail.css
new file mode 100644
index 0000000000..5ab3e1b06d
--- /dev/null
+++ b/res/templates-dinsic/mail.css
@@ -0,0 +1,156 @@
+body {
+ margin: 0px;
+}
+
+pre, code {
+ word-break: break-word;
+ white-space: pre-wrap;
+}
+
+#page {
+ font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
+    color: #454545;
+ font-size: 12pt;
+ width: 100%;
+ padding: 20px;
+}
+
+#inner {
+ width: 640px;
+}
+
+.header {
+ width: 100%;
+ height: 87px;
+ color: #454545;
+ border-bottom: 4px solid #e5e5e5;
+}
+
+.logo {
+ text-align: right;
+ margin-left: 20px;
+}
+
+.salutation {
+ padding-top: 10px;
+ font-weight: bold;
+}
+
+.summarytext {
+}
+
+.room {
+ width: 100%;
+ color: #454545;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.room_header td {
+ padding-top: 38px;
+ padding-bottom: 10px;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.room_name {
+ vertical-align: middle;
+ font-size: 18px;
+ font-weight: bold;
+}
+
+.room_header h2 {
+ margin-top: 0px;
+ margin-left: 75px;
+ font-size: 20px;
+}
+
+.room_avatar {
+ width: 56px;
+ line-height: 0px;
+ text-align: center;
+ vertical-align: middle;
+}
+
+.room_avatar img {
+ width: 48px;
+ height: 48px;
+ object-fit: cover;
+ border-radius: 24px;
+}
+
+.notif {
+ border-bottom: 1px solid #e5e5e5;
+ margin-top: 16px;
+ padding-bottom: 16px;
+}
+
+.historical_message .sender_avatar {
+ opacity: 0.3;
+}
+
+/* spell out opacity and historical_message class names for Outlook aka Word */
+.historical_message .sender_name {
+ color: #e3e3e3;
+}
+
+.historical_message .message_time {
+ color: #e3e3e3;
+}
+
+.historical_message .message_body {
+ color: #c7c7c7;
+}
+
+.historical_message td,
+.message td {
+ padding-top: 10px;
+}
+
+.sender_avatar {
+ width: 56px;
+ text-align: center;
+ vertical-align: top;
+}
+
+.sender_avatar img {
+ margin-top: -2px;
+ width: 32px;
+ height: 32px;
+ border-radius: 16px;
+}
+
+.sender_name {
+ display: inline;
+ font-size: 13px;
+ color: #a2a2a2;
+}
+
+.message_time {
+ text-align: right;
+ width: 100px;
+ font-size: 11px;
+ color: #a2a2a2;
+}
+
+.message_body {
+}
+
+.notif_link td {
+ padding-top: 10px;
+ padding-bottom: 10px;
+ font-weight: bold;
+}
+
+.notif_link a, .footer a {
+ color: #454545;
+ text-decoration: none;
+}
+
+.debug {
+ font-size: 10px;
+ color: #888;
+}
+
+.footer {
+ margin-top: 20px;
+ text-align: center;
+}
\ No newline at end of file
diff --git a/res/templates-dinsic/notif.html b/res/templates-dinsic/notif.html
new file mode 100644
index 0000000000..bcdfeea9da
--- /dev/null
+++ b/res/templates-dinsic/notif.html
@@ -0,0 +1,45 @@
+{% for message in notif.messages %}
+ <tr class="{{ "historical_message" if message.is_historical else "message" }}">
+ <td class="sender_avatar">
+ {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ {% if message.sender_avatar_url %}
+ <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
+ {% else %}
+ {% if message.sender_hash % 3 == 0 %}
+                    <img alt="" class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
+ {% elif message.sender_hash % 3 == 1 %}
+                    <img alt="" class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
+ {% else %}
+                    <img alt="" class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
+ {% endif %}
+ {% endif %}
+ {% endif %}
+ </td>
+ <td class="message_contents">
+ {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
+ {% endif %}
+ <div class="message_body">
+ {% if message.msgtype == "m.text" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.emote" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.notice" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.image" %}
+ <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
+ {% elif message.msgtype == "m.file" %}
+ <span class="filename">{{ message.body_text_plain }}</span>
+ {% endif %}
+ </div>
+ </td>
+ <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
+ </tr>
+{% endfor %}
+<tr class="notif_link">
+ <td></td>
+ <td>
+ <a href="{{ notif.link }}">Voir {{ room.title }}</a>
+ </td>
+ <td></td>
+</tr>
diff --git a/res/templates-dinsic/notif.txt b/res/templates-dinsic/notif.txt
new file mode 100644
index 0000000000..3dff1bb570
--- /dev/null
+++ b/res/templates-dinsic/notif.txt
@@ -0,0 +1,16 @@
+{% for message in notif.messages %}
+{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
+{% if message.msgtype == "m.text" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.emote" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.notice" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.image" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.file" %}
+{{ message.body_text_plain }}
+{% endif %}
+{% endfor %}
+
+Voir {{ room.title }} à {{ notif.link }}
diff --git a/res/templates-dinsic/notif_mail.html b/res/templates-dinsic/notif_mail.html
new file mode 100644
index 0000000000..1e1efa74b2
--- /dev/null
+++ b/res/templates-dinsic/notif_mail.html
@@ -0,0 +1,55 @@
+<!doctype html>
+<html lang="fr">
+ <head>
+ <style type="text/css">
+ {% include 'mail.css' without context %}
+ {% include "mail-%s.css" % app_name ignore missing without context %}
+ </style>
+ </head>
+ <body>
+ <table id="page">
+ <tr>
+ <td> </td>
+ <td id="inner">
+ <table class="header">
+ <tr>
+ <td>
+ <div class="salutation">Bonjour {{ user_display_name }},</div>
+ <div class="summarytext">{{ summary_text }}</div>
+ </td>
+ <td class="logo">
+ {% if app_name == "Riot" %}
+ <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
+ {% elif app_name == "Vector" %}
+ <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
+ {% else %}
+ <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
+ {% endif %}
+ </td>
+ </tr>
+ </table>
+ {% for room in rooms %}
+ {% include 'room.html' with context %}
+ {% endfor %}
+ <div class="footer">
+ <a href="{{ unsubscribe_link }}">Se désinscrire</a>
+ <br/>
+ <br/>
+ <div class="debug">
+ Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
+ an event was received at {{ reason.received_at|format_ts("%c") }}
+ which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
+ {% if reason.last_sent_ts %}
+ and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
+ which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
+ {% else %}
+ and we don't have a last time we sent a mail for this room.
+ {% endif %}
+ </div>
+ </div>
+ </td>
+ <td> </td>
+ </tr>
+ </table>
+ </body>
+</html>
diff --git a/res/templates-dinsic/notif_mail.txt b/res/templates-dinsic/notif_mail.txt
new file mode 100644
index 0000000000..fae877426f
--- /dev/null
+++ b/res/templates-dinsic/notif_mail.txt
@@ -0,0 +1,10 @@
+Bonjour {{ user_display_name }},
+
+{{ summary_text }}
+
+{% for room in rooms %}
+{% include 'room.txt' with context %}
+{% endfor %}
+
+Vous pouvez désactiver ces notifications en cliquant ici {{ unsubscribe_link }}
+
diff --git a/res/templates-dinsic/room.html b/res/templates-dinsic/room.html
new file mode 100644
index 0000000000..0487b1b11c
--- /dev/null
+++ b/res/templates-dinsic/room.html
@@ -0,0 +1,33 @@
+<table class="room">
+ <tr class="room_header">
+ <td class="room_avatar">
+ {% if room.avatar_url %}
+ <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
+ {% else %}
+ {% if room.hash % 3 == 0 %}
+ <img alt="" src="https://vector.im/beta/img/76cfa6.png" />
+ {% elif room.hash % 3 == 1 %}
+ <img alt="" src="https://vector.im/beta/img/50e2c2.png" />
+ {% else %}
+ <img alt="" src="https://vector.im/beta/img/f4c371.png" />
+ {% endif %}
+ {% endif %}
+ </td>
+ <td class="room_name" colspan="2">
+ {{ room.title }}
+ </td>
+ </tr>
+ {% if room.invite %}
+ <tr>
+ <td></td>
+ <td>
+ <a href="{{ room.link }}">Rejoindre la conversation.</a>
+ </td>
+ <td></td>
+ </tr>
+ {% else %}
+ {% for notif in room.notifs %}
+ {% include 'notif.html' with context %}
+ {% endfor %}
+ {% endif %}
+</table>
diff --git a/res/templates-dinsic/room.txt b/res/templates-dinsic/room.txt
new file mode 100644
index 0000000000..dd36d01d21
--- /dev/null
+++ b/res/templates-dinsic/room.txt
@@ -0,0 +1,9 @@
+{{ room.title }}
+
+{% if room.invite %}
+ Vous avez été invité, rejoignez la conversation en cliquant sur le lien suivant {{ room.link }}
+{% else %}
+ {% for notif in room.notifs %}
+ {% include 'notif.txt' with context %}
+ {% endfor %}
+{% endif %}
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index 448cadb829..d742c522b5 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -7,9 +7,9 @@ echo -e "+++ \033[32mChecking newsfragment\033[m"
set -e
-# make sure that origin/develop is up to date
-git remote set-branches --add origin develop
-git fetch -q origin develop
+# make sure that origin/dinsic is up to date
+git remote set-branches --add origin dinsic
+git fetch -q origin dinsic
pr="$BUILDKITE_PULL_REQUEST"
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index a34bdf1830..604e0fd662 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -45,6 +45,7 @@ from synapse.storage.databases.main.events_bg_updates import (
from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
+from synapse.storage.databases.main.profile import ProfileStore
from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
@@ -163,6 +164,7 @@ class Store(
DeviceBackgroundUpdateStore,
EventsBackgroundUpdatesStore,
MediaRepositoryBackgroundUpdateStore,
+ ProfileStore,
RegistrationBackgroundUpdateStore,
RoomBackgroundUpdateStore,
RoomMemberBackgroundUpdateStore,
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d8190f92ab..432b5885c8 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -190,7 +190,7 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
- user_id, app_service = await self._get_appservice_user_id(request)
+ user_id, app_service = self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
@@ -213,6 +213,7 @@ class Auth(object):
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
+ shadow_banned = user_info["shadow_banned"]
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
@@ -252,15 +253,21 @@ class Auth(object):
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
- user, token_id, is_guest, device_id, app_service=app_service
+ user,
+ token_id,
+ is_guest,
+ shadow_banned,
+ device_id,
+ app_service=app_service,
)
except KeyError:
raise MissingClientTokenError()
- async def _get_appservice_user_id(self, request):
+ def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
+
if app_service is None:
return None, None
@@ -278,8 +285,12 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (await self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
+ # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+        # if not (await self.store.get_user_by_id(user_id)):
+ # raise AuthError(
+ # 403,
+ # "Application service has not registered this user"
+ # )
return user_id, app_service
async def get_user_by_access_token(
@@ -297,6 +308,7 @@ class Auth(object):
dict that includes:
`user` (UserID)
`is_guest` (bool)
+ `shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
@@ -356,6 +368,7 @@ class Auth(object):
ret = {
"user": user,
"is_guest": True,
+ "shadow_banned": False,
"token_id": None,
# all guests get the same device id
"device_id": GUEST_DEVICE_ID,
@@ -365,6 +378,7 @@ class Auth(object):
ret = {
"user": user,
"is_guest": False,
+ "shadow_banned": False,
"token_id": None,
"device_id": None,
}
@@ -488,6 +502,7 @@ class Auth(object):
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
+ "shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 6e40630ab6..28a078a7b4 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,10 +22,10 @@ import typing
from http import HTTPStatus
from typing import Dict, List, Optional, Union
-from canonicaljson import json
-
from twisted.web import http
+from synapse.util import json_decoder
+
if typing.TYPE_CHECKING:
from synapse.types import JsonDict
@@ -593,7 +594,7 @@ class HttpResponseException(CodeMessageException):
# try to parse the body as json, to get better errcode/msg, but
# default to M_UNKNOWN with the HTTP status as the error text
try:
- j = json.loads(self.response.decode("utf-8"))
+ j = json_decoder.decode(self.response.decode("utf-8"))
except ValueError:
j = {}
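The `json_decoder` used above comes from `synapse.util` and replaces per-call `json.loads`; presumably it is just a shared `JSONDecoder` instance along these lines (a sketch, not the verbatim utility):

    import json

    # One decoder instance, reused across calls rather than going through the
    # json.loads convenience wrapper every time.
    json_decoder = json.JSONDecoder()

    # Mirrors the change above:
    # j = json_decoder.decode(response_body.decode("utf-8"))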
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 7393d6cb74..a8937d2595 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -23,7 +23,7 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
-from synapse.storage.presence import UserPresenceState
+from synapse.api.presence import UserPresenceState
from synapse.types import RoomID, UserID
FILTER_SCHEMA = {
diff --git a/synapse/storage/presence.py b/synapse/api/presence.py
index 18a462f0ee..18a462f0ee 100644
--- a/synapse/storage/presence.py
+++ b/synapse/api/presence.py
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 739b013d4c..89e6c0d327 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -157,7 +157,7 @@ class PresenceStatusStubServlet(RestServlet):
async def on_GET(self, request, user_id):
await self.auth.get_user_by_req(request)
- return 200, {"presence": "offline"}
+ return 200, {"presence": "offline", "user_id": user_id}
async def on_PUT(self, request, user_id):
await self.auth.get_user_by_req(request)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fd137853b1..1477b27326 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,12 +18,17 @@
import argparse
import errno
import os
+import time
+import urllib.parse
from collections import OrderedDict
from hashlib import sha256
+from io import open as io_open
from textwrap import dedent
-from typing import Any, List, MutableMapping, Optional
+from typing import Any, Callable, List, MutableMapping, Optional
import attr
+import jinja2
+import pkg_resources
import yaml
@@ -100,6 +105,11 @@ class Config(object):
def __init__(self, root_config=None):
self.root = root_config
+ # Get the path to the default Synapse template directory
+ self.default_template_dir = pkg_resources.resource_filename(
+ "synapse", "res/templates"
+ )
+
def __getattr__(self, item: str) -> Any:
"""
Try and fetch a configuration option that does not exist on this class.
@@ -181,9 +191,98 @@ class Config(object):
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
+ with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
+ def read_templates(
+ self, filenames: List[str], custom_template_directory: Optional[str] = None,
+ ) -> List[jinja2.Template]:
+        """Load a list of template files from disk.
+
+ This function will attempt to load the given templates from the default Synapse
+ template directory. If `custom_template_directory` is supplied, that directory
+ is tried first.
+
+ Files read are treated as Jinja templates. These templates are not rendered yet.
+
+ Args:
+ filenames: A list of template filenames to read.
+
+ custom_template_directory: A directory to try to look for the templates
+ before using the default Synapse template directory instead.
+
+ Raises:
+ ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+ Returns:
+ A list of jinja2 templates.
+ """
+ templates = []
+ search_directories = [self.default_template_dir]
+
+ # The loader will first look in the custom template directory (if specified) for the
+ # given filename. If it doesn't find it, it will use the default template dir instead
+ if custom_template_directory:
+ # Check that the given template directory exists
+ if not self.path_exists(custom_template_directory):
+ raise ConfigError(
+ "Configured template directory does not exist: %s"
+ % (custom_template_directory,)
+ )
+
+ # Search the custom template directory as well
+ search_directories.insert(0, custom_template_directory)
+
+ loader = jinja2.FileSystemLoader(search_directories)
+ env = jinja2.Environment(loader=loader, autoescape=True)
+
+ # Update the environment with our custom filters
+ env.filters.update(
+ {
+ "format_ts": _format_ts_filter,
+ "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+ }
+ )
+
+ for filename in filenames:
+ # Load the template
+ template = env.get_template(filename)
+ templates.append(template)
+
+ return templates
+
+
+def _format_ts_filter(value: int, format: str):
+ return time.strftime(format, time.localtime(value / 1000))
+
+
+def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
+ """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+ Args:
+ public_baseurl: The public, accessible base URL of the homeserver
+ """
+
+ def mxc_to_http_filter(value, width, height, resize_method="crop"):
+ if value[0:6] != "mxc://":
+ return ""
+
+ server_and_media_id = value[6:]
+ fragment = None
+ if "#" in server_and_media_id:
+ server_and_media_id, fragment = server_and_media_id.split("#", 1)
+ fragment = "#" + fragment
+
+ params = {"width": width, "height": height, "method": resize_method}
+ return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+ public_baseurl,
+ server_and_media_id,
+ urllib.parse.urlencode(params),
+ fragment or "",
+ )
+
+ return mxc_to_http_filter
+
class RootConfig(object):
"""
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index a63acbdc63..7a796996c0 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -23,7 +23,6 @@ from enum import Enum
from typing import Optional
import attr
-import pkg_resources
from ._base import Config, ConfigError
@@ -98,21 +97,18 @@ class EmailConfig(Config):
if parsed[1] == "":
raise RuntimeError("Invalid notif_from address")
+ # A user-configurable template directory
template_dir = email_config.get("template_dir")
- # we need an absolute path, because we change directory after starting (and
- # we don't yet know what auxiliary templates like mail.css we will need).
- # (Note that loading as package_resources with jinja.PackageLoader doesn't
- # work for the same reason.)
- if not template_dir:
- template_dir = pkg_resources.resource_filename("synapse", "res/templates")
-
- self.email_template_dir = os.path.abspath(template_dir)
+ if isinstance(template_dir, str):
+ # We need an absolute path, because we change directory after starting (and
+ # we don't yet know what auxiliary templates like mail.css we will need).
+ template_dir = os.path.abspath(template_dir)
+ elif template_dir is not None:
+            # If template_dir is something other than a str or None, raise an error
+ raise ConfigError("Config option email.template_dir must be type str")
self.email_enable_notifs = email_config.get("enable_notifs", False)
- account_validity_config = config.get("account_validity") or {}
- account_validity_renewal_enabled = account_validity_config.get("renew_at")
-
self.threepid_behaviour_email = (
# Have Synapse handle the email sending if account_threepid_delegates.email
# is not defined
@@ -166,19 +162,6 @@ class EmailConfig(Config):
email_config.get("validation_token_lifetime", "1h")
)
- if (
- self.email_enable_notifs
- or account_validity_renewal_enabled
- or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
- ):
- # make sure we can import the required deps
- import bleach
- import jinja2
-
- # prevent unused warnings
- jinja2
- bleach
-
if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
missing = []
if not self.email_notif_from:
@@ -196,49 +179,49 @@ class EmailConfig(Config):
# These email templates have placeholders in them, and thus must be
# parsed using a templating engine during a request
- self.email_password_reset_template_html = email_config.get(
+ password_reset_template_html = email_config.get(
"password_reset_template_html", "password_reset.html"
)
- self.email_password_reset_template_text = email_config.get(
+ password_reset_template_text = email_config.get(
"password_reset_template_text", "password_reset.txt"
)
- self.email_registration_template_html = email_config.get(
+ registration_template_html = email_config.get(
"registration_template_html", "registration.html"
)
- self.email_registration_template_text = email_config.get(
+ registration_template_text = email_config.get(
"registration_template_text", "registration.txt"
)
- self.email_add_threepid_template_html = email_config.get(
+ add_threepid_template_html = email_config.get(
"add_threepid_template_html", "add_threepid.html"
)
- self.email_add_threepid_template_text = email_config.get(
+ add_threepid_template_text = email_config.get(
"add_threepid_template_text", "add_threepid.txt"
)
- self.email_password_reset_template_failure_html = email_config.get(
+ password_reset_template_failure_html = email_config.get(
"password_reset_template_failure_html", "password_reset_failure.html"
)
- self.email_registration_template_failure_html = email_config.get(
+ registration_template_failure_html = email_config.get(
"registration_template_failure_html", "registration_failure.html"
)
- self.email_add_threepid_template_failure_html = email_config.get(
+ add_threepid_template_failure_html = email_config.get(
"add_threepid_template_failure_html", "add_threepid_failure.html"
)
# These templates do not support any placeholder variables, so we
# will read them from disk once during setup
- email_password_reset_template_success_html = email_config.get(
+ password_reset_template_success_html = email_config.get(
"password_reset_template_success_html", "password_reset_success.html"
)
- email_registration_template_success_html = email_config.get(
+ registration_template_success_html = email_config.get(
"registration_template_success_html", "registration_success.html"
)
- email_add_threepid_template_success_html = email_config.get(
+ add_threepid_template_success_html = email_config.get(
"add_threepid_template_success_html", "add_threepid_success.html"
)
- # Check templates exist
- for f in [
+ # Read all templates from disk
+ (
self.email_password_reset_template_html,
self.email_password_reset_template_text,
self.email_registration_template_html,
@@ -248,32 +231,36 @@ class EmailConfig(Config):
self.email_password_reset_template_failure_html,
self.email_registration_template_failure_html,
self.email_add_threepid_template_failure_html,
- email_password_reset_template_success_html,
- email_registration_template_success_html,
- email_add_threepid_template_success_html,
- ]:
- p = os.path.join(self.email_template_dir, f)
- if not os.path.isfile(p):
- raise ConfigError("Unable to find template file %s" % (p,))
-
- # Retrieve content of web templates
- filepath = os.path.join(
- self.email_template_dir, email_password_reset_template_success_html
+ password_reset_template_success_html_template,
+ registration_template_success_html_template,
+ add_threepid_template_success_html_template,
+ ) = self.read_templates(
+ [
+ password_reset_template_html,
+ password_reset_template_text,
+ registration_template_html,
+ registration_template_text,
+ add_threepid_template_html,
+ add_threepid_template_text,
+ password_reset_template_failure_html,
+ registration_template_failure_html,
+ add_threepid_template_failure_html,
+ password_reset_template_success_html,
+ registration_template_success_html,
+ add_threepid_template_success_html,
+ ],
+ template_dir,
)
- self.email_password_reset_template_success_html = self.read_file(
- filepath, "email.password_reset_template_success_html"
- )
- filepath = os.path.join(
- self.email_template_dir, email_registration_template_success_html
- )
- self.email_registration_template_success_html_content = self.read_file(
- filepath, "email.registration_template_success_html"
+
+ # Render templates that do not contain any placeholders
+ self.email_password_reset_template_success_html_content = (
+ password_reset_template_success_html_template.render()
)
- filepath = os.path.join(
- self.email_template_dir, email_add_threepid_template_success_html
+ self.email_registration_template_success_html_content = (
+ registration_template_success_html_template.render()
)
- self.email_add_threepid_template_success_html_content = self.read_file(
- filepath, "email.add_threepid_template_success_html"
+ self.email_add_threepid_template_success_html_content = (
+ add_threepid_template_success_html_template.render()
)
if self.email_enable_notifs:
@@ -290,17 +277,19 @@ class EmailConfig(Config):
% (", ".join(missing),)
)
- self.email_notif_template_html = email_config.get(
+ notif_template_html = email_config.get(
"notif_template_html", "notif_mail.html"
)
- self.email_notif_template_text = email_config.get(
+ notif_template_text = email_config.get(
"notif_template_text", "notif_mail.txt"
)
- for f in self.email_notif_template_text, self.email_notif_template_html:
- p = os.path.join(self.email_template_dir, f)
- if not os.path.isfile(p):
- raise ConfigError("Unable to find email template file %s" % (p,))
+ (
+ self.email_notif_template_html,
+ self.email_notif_template_text,
+ ) = self.read_templates(
+ [notif_template_html, notif_template_text], template_dir,
+ )
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
@@ -309,18 +298,20 @@ class EmailConfig(Config):
"client_base_url", email_config.get("riot_base_url", None)
)
- if account_validity_renewal_enabled:
- self.email_expiry_template_html = email_config.get(
+ if self.account_validity.renew_by_email_enabled:
+ expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html"
)
- self.email_expiry_template_text = email_config.get(
+ expiry_template_text = email_config.get(
"expiry_template_text", "notice_expiry.txt"
)
- for f in self.email_expiry_template_text, self.email_expiry_template_html:
- p = os.path.join(self.email_template_dir, f)
- if not os.path.isfile(p):
- raise ConfigError("Unable to find email template file %s" % (p,))
+ (
+ self.account_validity_template_html,
+ self.account_validity_template_text,
+ ) = self.read_templates(
+ [expiry_template_html, expiry_template_text], template_dir,
+ )
subjects_config = email_config.get("subjects", {})
subjects = {}
@@ -400,9 +391,7 @@ class EmailConfig(Config):
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
- # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
- # If you *do* uncomment it, you will need to make sure that all the templates
- # below are in the directory.
+ # Do not uncomment this setting unless you want to customise the templates.
#
# Synapse will look for the following templates in this directory:
#
diff --git a/synapse/config/password.py b/synapse/config/password.py
index 9c0ea8c30a..6b2dae78b0 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index b2c78ac40c..c3a1d377c5 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -76,6 +76,9 @@ class RatelimitConfig(Config):
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {})
+ )
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -124,6 +127,8 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+    # - one that ratelimits third-party invite requests based on the account
+    #   making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -153,6 +158,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
+ #
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
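The new `rc_third_party_invite` setting follows the same per-second/burst shape as the other ratelimiters, which boils down to a token bucket per account. A toy, self-contained sketch of that semantics (not Synapse's `Ratelimiter` class):

    import time

    class TokenBucket:
        """Toy per-key limiter: refill at `per_second`, capacity `burst_count`."""

        def __init__(self, per_second: float, burst_count: int):
            self.per_second = per_second
            self.burst_count = burst_count
            self.state = {}  # key -> (tokens, last_timestamp)

        def allow(self, key: str) -> bool:
            now = time.monotonic()
            tokens, last = self.state.get(key, (float(self.burst_count), now))
            # Refill based on time elapsed, capped at the burst size.
            tokens = min(self.burst_count, tokens + (now - last) * self.per_second)
            allowed = tokens >= 1
            self.state[key] = (tokens - 1 if allowed else tokens, now)
            return allowed

    # Mirrors the commented-out sample values above.
    limiter = TokenBucket(per_second=0.2, burst_count=10)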
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index a185655774..3c2e951a71 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -100,8 +100,19 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.check_is_for_allowed_local_3pids = config.get(
+ "check_is_for_allowed_local_3pids", None
+ )
+ self.allow_invited_3pids = config.get("allow_invited_3pids", False)
+
+ self.disable_3pid_changes = config.get("disable_3pid_changes", False)
+
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
+ self.register_just_use_email_for_display_name = config.get(
+ "register_just_use_email_for_display_name", False
+ )
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
@@ -109,7 +120,21 @@ class RegistrationConfig(Config):
)
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+ if (
+ self.account_threepid_delegate_email
+ and not self.account_threepid_delegate_email.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.email must begin with http:// or https://"
+ )
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+ if (
+ self.account_threepid_delegate_msisdn
+ and not self.account_threepid_delegate_msisdn.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.msisdn must begin with http:// or https://"
+ )
if self.account_threepid_delegate_msisdn and not self.public_baseurl:
raise ConfigError(
"The configuration option `public_baseurl` is required if "
@@ -178,6 +203,15 @@ class RegistrationConfig(Config):
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+ self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
+ if not isinstance(self.replicate_user_profiles_to, list):
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
+
+ self.shadow_server = config.get("shadow_server", None)
+ self.rewrite_identity_server_urls = (
+ config.get("rewrite_identity_server_urls") or {}
+ )
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -187,6 +221,23 @@ class RegistrationConfig(Config):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime
+ self.bind_new_user_emails_to_sydent = config.get(
+ "bind_new_user_emails_to_sydent"
+ )
+
+ if self.bind_new_user_emails_to_sydent:
+ if not isinstance(
+ self.bind_new_user_emails_to_sydent, str
+ ) or not self.bind_new_user_emails_to_sydent.startswith("http"):
+ raise ConfigError(
+ "Option bind_new_user_emails_to_sydent has invalid value"
+ )
+
+ # Remove trailing slashes
+            self.bind_new_user_emails_to_sydent = self.bind_new_user_emails_to_sydent.rstrip(
+ "/"
+ )
+
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -291,9 +342,32 @@ class RegistrationConfig(Config):
#
#disable_msisdn_registration: true
+ # Derive the user's matrix ID from a type of 3PID used when registering.
+    # This overrides any matrix ID the user proposes when calling /register.
+ # The 3PID type should be present in registrations_require_3pid to avoid
+ # users failing to register if they don't specify the right kind of 3pid.
+ #
+ #register_mxid_from_3pid: email
+
+ # Uncomment to set the display name of new users to their email address,
+ # rather than using the default heuristic.
+ #
+ #register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+    # Use an Identity Server to establish which 3PIDs are allowed to register.
+    # This overrides allowed_local_3pids below.
+ #
+ #check_is_for_allowed_local_3pids: matrix.org
+ #
+    # If you are using an Identity Server, you can also check whether that IS
+    # holds pending invites for the given 3PID (and then allow it to sign up on
+    # the platform):
+ #
+ #allow_invited_3pids: false
+ #
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
@@ -302,6 +376,11 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+    # If true, prevent users from changing the 3PIDs associated with
+    # their accounts.
+ #
+ #disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -333,6 +412,30 @@ class RegistrationConfig(Config):
#
#default_identity_server: https://matrix.org
+    # If enabled, user IDs, display names and avatar URLs will be replicated
+    # to the given server(s) whenever they change.
+ # This is an experimental API currently implemented by sydent to support
+ # cross-homeserver user directories.
+ #
+ #replicate_user_profiles_to: example.com
+
+    # If specified, attempt to replay registrations, profile changes and 3PID
+    # bindings on the given target homeserver via the AS API. Requests to the
+    # target homeserver are authenticated with the given AS token.
+ #
+ #shadow_server:
+ # hs_url: https://shadow.example.com
+ # hs: shadow.example.com
+ # as_token: 12u394refgbdhivsia
+
+ # If enabled, don't let users set their own display names/avatars
+ # other than for the very first time (unless they are a server admin).
+ # Useful when provisioning users based on the contents of a 3rd party
+ # directory and to avoid ambiguities.
+ #
+ #disable_set_displayname: false
+ #disable_set_avatar_url: false
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -458,6 +561,31 @@ class RegistrationConfig(Config):
# Defaults to true.
#
#auto_join_rooms_for_guests: false
+
+ # Rewrite identity server URLs with a map from one URL to another. Applies to URLs
+ # provided by clients (which have https:// prepended) and those specified
+ # in `account_threepid_delegates`. URLs should not feature a trailing slash.
+ #
+ #rewrite_identity_server_urls:
+ # "https://somewhere.example.com": "https://somewhereelse.example.com"
+
+ # When a user registers an account with an email address, it can be useful to
+ # bind that email address to their mxid on an identity server. Typically, this
+ # requires the user to validate their email address with the identity server.
+ # However if Synapse itself is handling email validation on registration, the
+ # user ends up needing to validate their email twice, which leads to poor UX.
+ #
+ # It is possible to force Sydent, one identity server implementation, to bind
+ # threepids using its internal, unauthenticated bind API:
+ # https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api
+ #
+ # Configure the address of a Sydent server here to have Synapse attempt
+ # to automatically bind users' emails following registration. The
+ # internal bind API must be reachable from Synapse, but should NOT be
+ # exposed to any third party, as it allows the creation of bindings
+ # without validation.
+ #
+ #bind_new_user_emails_to_sydent: https://example.com:8091
"""
% locals()
)
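For reference, the Sydent internal bind call that `bind_new_user_emails_to_sydent` relies on is a plain HTTP POST. A rough sketch, with the endpoint and payload taken from the Sydent README linked above (treat the details as illustrative, and note that Synapse itself uses its own HTTP client rather than `requests`):

    import requests

    def internal_bind(sydent_base_url: str, mxid: str, email: str) -> None:
        # The internal bind API must only be reachable from the homeserver,
        # as it creates bindings without any validation.
        resp = requests.post(
            sydent_base_url + "/_matrix/identity/internal/bind",
            json={"medium": "email", "address": email, "mxid": mxid},
        )
        resp.raise_for_status()

    # internal_bind("https://example.com:8091", "@alice:example.com", "alice@example.com")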
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 01009f3924..54f565ad5b 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -104,6 +104,12 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
+
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
@@ -244,6 +250,30 @@ class ContentRepositoryConfig(Config):
#
#max_upload_size: 10M
+ # The largest allowed size for a user avatar. If not defined, no
+ # restriction will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #max_avatar_size: 10M
+
+    # The MIME types allowed for user avatars. If not defined, no restriction
+    # will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 9277b5f342..036f8c0e90 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -18,8 +18,6 @@ import logging
from typing import Any, List
import attr
-import jinja2
-import pkg_resources
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@@ -171,15 +169,9 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "15m")
)
- template_dir = saml2_config.get("template_dir")
- if not template_dir:
- template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
-
- loader = jinja2.FileSystemLoader(template_dir)
- # enable auto-escape here, to having to remember to escape manually in the
- # template
- env = jinja2.Environment(loader=loader, autoescape=True)
- self.saml2_error_html_template = env.get_template("saml_error.html")
+ self.saml2_error_html_template = self.read_templates(
+ ["saml_error.html"], saml2_config.get("template_dir")
+        )[0]
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 9f15ed109e..5a6a55cc4d 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,14 +19,13 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Set
import attr
import yaml
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
-from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -277,6 +276,12 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # True.
+ self.show_users_in_user_directory = config.get(
+ "show_users_in_user_directory", True
+ )
+
retention_config = config.get("retention")
if retention_config is None:
retention_config = {}
@@ -508,8 +513,6 @@ class ServerConfig(Config):
)
)
- _check_resource_config(self.listeners)
-
self.cleanup_extremities_with_dummy_events = config.get(
"cleanup_extremities_with_dummy_events", True
)
@@ -545,6 +548,19 @@ class ServerConfig(Config):
users_new_default_push_rules
) # type: set
+ # Whitelist of domain names that given next_link parameters must have
+ next_link_domain_whitelist = config.get(
+ "next_link_domain_whitelist"
+ ) # type: Optional[List[str]]
+
+ self.next_link_domain_whitelist = None # type: Optional[Set[str]]
+ if next_link_domain_whitelist is not None:
+ if not isinstance(next_link_domain_whitelist, list):
+ raise ConfigError("'next_link_domain_whitelist' must be a list")
+
+ # Turn the list into a set to improve lookup speed.
+ self.next_link_domain_whitelist = set(next_link_domain_whitelist)
+
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
@@ -926,6 +942,74 @@ class ServerConfig(Config):
#
#allow_per_room_profiles: false
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # 'true'.
+ #
+ #show_users_in_user_directory: false
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+    # events whose lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+    # whose 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+    # of outdated messages on a very frequent basis (e.g. every 5min), but where
+    # you do not want that purge to be performed by a job that iterates over every
+    # room it knows of, which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+    #    interval: 5m
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
@@ -1011,6 +1095,24 @@ class ServerConfig(Config):
# act as if no error happened and return a fake session ID ('sid') to clients.
#
#request_token_inhibit_3pid_errors: true
+
+ # A list of domains that the domain portion of 'next_link' parameters
+ # must match.
+ #
+ # This parameter is optionally provided by clients while requesting
+ # validation of an email or phone number, and maps to a link that
+ # users will be automatically redirected to after validation
+    # succeeds. Clients can make use of this parameter to aid the validation
+ # process.
+ #
+ # The whitelist is applied whether the homeserver or an
+ # identity server is handling validation.
+ #
+ # The default value is no whitelist functionality; all domains are
+ # allowed. Setting this value to an empty list will instead disallow
+ # all domains.
+ #
+ #next_link_domain_whitelist: ["matrix.org"]
"""
% locals()
)
@@ -1133,20 +1235,3 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
-
-
-def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
- resource_names = {
- res_name
- for listener in listeners
- if listener.http_options
- for res in listener.http_options.resources
- for res_name in res.names
- }
-
- for resource in resource_names:
- if resource == "consent":
- try:
- check_requirements("resources.consent")
- except DependencyException as e:
- raise ConfigError(e.message)
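A hypothetical sketch of the check the `next_link_domain_whitelist` option implies; Synapse's actual validation lives in the request handlers and may differ in detail:

    from typing import Optional, Set
    from urllib.parse import urlparse

    def next_link_allowed(next_link: str, whitelist: Optional[Set[str]]) -> bool:
        if whitelist is None:
            # No whitelist configured: all domains are allowed.
            return True
        # An empty set disallows every domain, as the config comment notes.
        return urlparse(next_link).hostname in whitelist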
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 73b7296399..4427676167 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -12,11 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
from typing import Any, Dict
-import pkg_resources
-
from ._base import Config
@@ -29,22 +26,32 @@ class SSOConfig(Config):
def read_config(self, config, **kwargs):
sso_config = config.get("sso") or {} # type: Dict[str, Any]
- # Pick a template directory in order of:
- # * The sso-specific template_dir
- # * /path/to/synapse/install/res/templates
+ # The sso-specific template_dir
template_dir = sso_config.get("template_dir")
- if not template_dir:
- template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
- self.sso_template_dir = template_dir
- self.sso_account_deactivated_template = self.read_file(
- os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
- "sso_account_deactivated_template",
+ # Read templates from disk
+ (
+ self.sso_redirect_confirm_template,
+ self.sso_auth_confirm_template,
+ self.sso_error_template,
+ sso_account_deactivated_template,
+ sso_auth_success_template,
+ ) = self.read_templates(
+ [
+ "sso_redirect_confirm.html",
+ "sso_auth_confirm.html",
+ "sso_error.html",
+ "sso_account_deactivated.html",
+ "sso_auth_success.html",
+ ],
+ template_dir,
)
- self.sso_auth_success_template = self.read_file(
- os.path.join(self.sso_template_dir, "sso_auth_success.html"),
- "sso_auth_success_template",
+
+ # These templates have no placeholders, so render them here
+ self.sso_account_deactivated_template = (
+ sso_account_deactivated_template.render()
)
+ self.sso_auth_success_template = sso_auth_success_template.render()
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c8d19c5d6b..43b6c40456 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -26,6 +26,7 @@ class UserDirectoryConfig(Config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
+ self.user_directory_defer_to_id_server = None
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
@@ -34,6 +35,9 @@ class UserDirectoryConfig(Config):
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
@@ -52,4 +56,9 @@ class UserDirectoryConfig(Config):
#user_directory:
# enabled: true
# search_all_users: false
+ #
+ # # If this is set, user search will be delegated to this ID server instead
+    # # of Synapse performing the search itself.
+ # # This is an experimental API.
+ # defer_to_id_server: https://id.example.com
"""
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 28ef7cfdb9..81c4b430b2 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -757,9 +757,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
- return await yieldable_gather_results(
- get_key, keys_to_fetch.items()
- ).addCallback(lambda _: results)
+ await yieldable_gather_results(get_key, keys_to_fetch.items())
+ return results
async def get_server_verify_key_v2_direct(self, server_name, key_ids):
"""
@@ -769,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
key_ids (iterable[str]):
Returns:
- Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+ dict[str, FetchKeyResult]: map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c0981eee62..8c907ad596 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -47,7 +47,7 @@ def check(
Args:
room_version_obj: the version of the room
event: the event being checked.
- auth_events (dict: event-key -> event): the existing room state.
+ auth_events: the existing room state.
Raises:
AuthError if the checks fail
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 1ffc9525d1..3c11e317fd 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,7 +15,7 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
from synapse.spam_checker_api import SpamCheckerApi
@@ -58,42 +58,82 @@ class SpamChecker(object):
return False
def user_may_invite(
- self, inviter_userid: str, invitee_userid: str, room_id: str
+ self,
+ inviter_userid: str,
+ invitee_userid: str,
+ third_party_invite: Optional[Dict],
+ room_id: str,
+ new_room: bool,
+ published_room: bool,
) -> bool:
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- inviter_userid: The user ID of the sender of the invitation
- invitee_userid: The user ID targeted in the invitation
- room_id: The room ID
+            inviter_userid: The user ID of the sender of the invitation.
+            invitee_userid: The user ID of the invitee. Is None
+                if this is a third party invite and the 3PID is not bound to a
+                user ID.
+            third_party_invite: If this is a third party invite, a
+                dict containing the medium and address of the invitee.
+            room_id: The ID of the room the invite is for.
+            new_room: Whether the user is being invited to the room as
+                part of a room creation; if so, the invitee would have been
+ included in the call to `user_may_create_room`.
+ published_room: Whether the room the user is being invited
+ to has been published in the local homeserver's public room
+ directory.
Returns:
True if the user may send an invite, otherwise False
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ spam_checker.user_may_invite(
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
+ )
is False
):
return False
return True
- def user_may_create_room(self, userid: str) -> bool:
+ def user_may_create_room(
+ self,
+ userid: str,
+ invite_list: List[str],
+ third_party_invite_list: List[Dict],
+ cloning: bool,
+ ) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid: The ID of the user attempting to create a room
+ invite_list: List of user IDs that would be invited to
+ the new room.
+ third_party_invite_list: List of third party invites
+ for the new room.
+ cloning: Whether the user is cloning an existing room, e.g.
+ upgrading a room.
Returns:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room(userid) is False:
+ if (
+ spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
+ is False
+ ):
return False
return True
@@ -134,6 +174,25 @@ class SpamChecker(object):
return True
+    def user_may_join_room(self, userid: str, room_id: str, is_invited: bool) -> bool:
+        """Checks if a given user is allowed to join a room.
+
+        Not called when a user creates a room.
+
+        Args:
+            userid: The ID of the user wanting to join the room.
+            room_id: The ID of the room.
+            is_invited: Whether the user is invited into the room.
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ for spam_checker in self.spam_checkers:
+ if spam_checker.user_may_join_room(userid, room_id, is_invited) is False:
+ return False
+
+ return True
+
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 2956a64234..7fe1525e0f 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Callable
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Requester
+from synapse.module_api import ModuleApi
+from synapse.types import Requester, StateMap
class ThirdPartyEventRules(object):
@@ -38,7 +40,7 @@ class ThirdPartyEventRules(object):
if module is not None:
self.third_party_rules = module(
- config=config, http_client=hs.get_simple_http_client()
+ config=config, module_api=ModuleApi(hs, hs.get_auth_handler()),
)
async def check_event_allowed(
@@ -106,6 +108,48 @@ class ThirdPartyEventRules(object):
if self.third_party_rules is None:
return True
+ state_events = await self._get_state_map_for_room(room_id)
+
+ ret = await self.third_party_rules.check_threepid_can_be_invited(
+ medium, address, state_events
+ )
+ return ret
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, new_visibility: str
+ ) -> bool:
+ """Check if a room is allowed to be published to, or removed from, the public room
+ list.
+
+ Args:
+ room_id: The ID of the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ True if the room's visibility can be modified, False if not.
+ """
+ if self.third_party_rules is None:
+ return True
+
+ check_func = getattr(
+ self.third_party_rules, "check_visibility_can_be_modified", None
+ )
+ if not check_func or not isinstance(check_func, Callable):
+ return True
+
+ state_events = await self._get_state_map_for_room(room_id)
+
+ return await check_func(room_id, state_events, new_visibility)
+
+ async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
+ """Given a room ID, return the state events of that room.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ A dict mapping (event type, state key) to state event.
+ """
state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())
@@ -113,7 +157,4 @@ class ThirdPartyEventRules(object):
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
- ret = await self.third_party_rules.check_threepid_can_be_invited(
- medium, address, state_events
- )
- return ret
+ return state_events
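A third party rules module opting into the new `check_visibility_can_be_modified` hook receives the room's state map as built by `_get_state_map_for_room` above. A minimal sketch, with illustrative names and policy:

class ExampleThirdPartyRules:
    def __init__(self, config, module_api):
        self._config = config
        self._module_api = module_api

    async def check_visibility_can_be_modified(
        self, room_id, state_events, new_visibility
    ):
        # Illustrative policy: allow rooms to be published to the
        # directory, but refuse removals.
        return new_visibility == "public"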
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 11c5d63298..630f571cd4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -28,7 +28,6 @@ from typing import (
Union,
)
-from canonicaljson import json
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -63,7 +62,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import glob_to_regex, unwrapFirstError
+from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -551,7 +550,7 @@ class FederationServer(FederationBase):
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json.loads(json_str)
+ key_id: json_decoder.decode(json_str)
}
logger.info(
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 2b0ab2dcbf..4d65d4aeea 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -37,8 +37,8 @@ from sortedcontainers import SortedDict
from twisted.internet import defer
+from synapse.api.presence import UserPresenceState
from synapse.metrics import LaterGauge
-from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
from .units import Edu
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 94cc63001e..e53b6ac456 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
import synapse
import synapse.metrics
+from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
@@ -39,7 +40,6 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.metrics import Measure, measure_func
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index dd150f89a6..c09ffcaf4c 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -24,12 +24,12 @@ from synapse.api.errors import (
HttpResponseException,
RequestSendFailed,
)
+from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@@ -337,6 +337,28 @@ class PerDestinationQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
+
+ if e.retry_interval > 60 * 60 * 1000:
+ # we won't retry for another hour!
+ # (this suggests a significant outage)
+ # We drop pending PDUs and EDUs because otherwise they will
+ # rack up indefinitely.
+ # Note that:
+ # - the EDUs that are being dropped here are those that we can
+ # afford to drop (specifically, only typing notifications,
+ # read receipts and presence updates are being dropped here)
+ # - Other EDUs such as to_device messages are queued with a
+ # different mechanism
+ # - this is all volatile state that would be lost if the
+ # federation sender restarted anyway
+
+ # dropping read receipts is a bit sad but should be solved
+ # through another mechanism, because this is all volatile!
+ self._pending_pdus = []
+ self._pending_edus = []
+ self._pending_edus_keyed = {}
+ self._pending_presence = {}
+ self._pending_rrs = {}
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index c7f6cb3d73..9bd534a313 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -15,8 +15,6 @@
import logging
from typing import TYPE_CHECKING, List, Tuple
-from canonicaljson import json
-
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -28,6 +26,7 @@ from synapse.logging.opentracing import (
tags,
whitelisted_homeserver,
)
+from synapse.util import json_decoder
from synapse.util.metrics import measure_func
if TYPE_CHECKING:
@@ -71,7 +70,7 @@ class TransactionManager(object):
for edu in pending_edus:
context = edu.get_context()
if context:
- span_contexts.append(extract_text_map(json.loads(context)))
+ span_contexts.append(extract_text_map(json_decoder.decode(context)))
if keep_destination:
edu.strip_context()
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9ea821dbb2..c189296660 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -16,7 +16,7 @@
import logging
import urllib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
@@ -1004,6 +1004,20 @@ class TransportLayerClient(object):
return self.client.get_json(destination=destination, path=path)
+ def get_info_of_users(self, destination: str, user_ids: List[str]):
+ """
+ Args:
+ destination: The remote server
+ user_ids: A list of user IDs to query info about
+
+ Returns:
+ Deferred[dict]: A dictionary mapping user ID to information about that user.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
+ data = {"user_ids": user_ids}
+
+ return self.client.post_json(destination=destination, path=path, data=data)
+
def _create_path(federation_prefix, path, *args):
"""
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5e111aa902..b518dace8a 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -31,6 +31,7 @@ from synapse.api.urls import (
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
@@ -845,6 +846,57 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
+class FederationUserInfoServlet(BaseFederationServlet):
+ """
+ Return information about a set of users.
+
+ This API returns expiration and deactivation information about a set of
+ users. Requested users not local to this homeserver will be ignored.
+
+ Example request:
+ POST /users/info
+
+ {
+ "user_ids": [
+ "@alice:example.com",
+ "@bob:example.com"
+ ]
+ }
+
+ Example response:
+ {
+ "@alice:example.com": {
+ "expired": false,
+ "deactivated": true
+ }
+ }
+ """
+
+ PATH = "/users/info"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
+ super(FederationUserInfoServlet, self).__init__(
+ handler, authenticator, ratelimiter, server_name
+ )
+ self.handler = handler
+
+ async def on_POST(self, origin, content, query):
+ assert_params_in_dict(content, required=["user_ids"])
+
+ user_ids = content.get("user_ids", [])
+
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ data = await self.handler.store.get_info_for_users(user_ids)
+ return 200, data
+
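On the sending side, the matching `TransportLayerClient.get_info_of_users` helper added earlier drives this endpoint; a hedged usage sketch (the user IDs are examples):

async def fetch_remote_user_info(transport_client, destination: str):
    info = await transport_client.get_info_of_users(
        destination, ["@alice:example.com", "@bob:example.com"]
    )
    # e.g. {"@alice:example.com": {"expired": False, "deactivated": True}}
    return info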
+
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
@@ -1406,6 +1458,7 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 590135d19c..3cdbf247ea 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -26,11 +26,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
-try:
- from synapse.push.mailer import load_jinja2_templates
-except ImportError:
- load_jinja2_templates = None
-
logger = logging.getLogger(__name__)
@@ -43,13 +38,17 @@ class AccountValidityHandler(object):
self.clock = self.hs.get_clock()
self._account_validity = self.hs.config.account_validity
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
if (
self._account_validity.enabled
and self._account_validity.renew_by_email_enabled
- and load_jinja2_templates
):
# Don't do email-specific configuration if renewal by email is disabled.
+ self._template_html = self.config.account_validity_template_html
+ self._template_text = self.config.account_validity_template_text
+
try:
app_name = self.hs.config.email_app_name
@@ -65,17 +64,6 @@ class AccountValidityHandler(object):
self._raw_from = email.utils.parseaddr(self._from_string)[1]
- self._template_html, self._template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_expiry_template_html,
- self.config.email_expiry_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
-
# Check the renewal emails to send and send them every 30min.
def send_emails():
# run as a background process to make sure that the database transactions
@@ -86,6 +74,18 @@ class AccountValidityHandler(object):
self.clock.looping_call(send_emails, 30 * 60 * 1000)
+ # Mark users as inactive when they expired. Check once every hour
+ if self._account_validity.enabled:
+
+ def mark_expired_users_as_inactive():
+ # run as a background process to allow async functions to work
+ return run_as_background_process(
+ "_mark_expired_users_as_inactive",
+ self._mark_expired_users_as_inactive,
+ )
+
+ self.clock.looping_call(mark_expired_users_as_inactive, 60 * 60 * 1000)
+
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
@@ -266,4 +266,24 @@ class AccountValidityHandler(object):
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ await self.profile_handler.set_active(
+ [UserID.from_string(user_id)], True, True
+ )
+
return expiration_ts
+
+ async def _mark_expired_users_as_inactive(self):
+ """Iterate over active, expired users. Mark them as inactive in order to hide them
+ from the user directory.
+ """
+ # Get active, expired users
+ active_expired_users = await self.store.get_expired_users()
+
+ # Mark each as non-active
+ await self.profile_handler.set_active(active_expired_users, False, True)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c24e7bafe0..68d6870e40 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -42,7 +42,6 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
-from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.threepids import canonicalise_email
@@ -132,18 +131,17 @@ class AuthHandler(BaseHandler):
# after the SSO completes and before redirecting them back to their client.
# It notifies the user they are about to give access to their matrix account
# to the client.
- self._sso_redirect_confirm_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
- )[0]
+ self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
+
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
- self._sso_auth_confirm_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_auth_confirm.html"],
- )[0]
+ self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
+
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
+
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 25169157c1..0e26a32750 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -35,6 +35,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_handlers().identity_handler
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
@@ -108,6 +109,9 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ await self._profile_handler.set_active([user], False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
await self.store.add_user_pending_deactivation(user_id)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 79a2df6201..af9936f7e2 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -45,6 +45,7 @@ class DirectoryHandler(BaseHandler):
self.config = hs.config
self.enable_room_list_search = hs.config.enable_room_list_search
self.require_membership = hs.config.require_membership_for_aliases
+ self.third_party_event_rules = hs.get_third_party_event_rules()
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -448,6 +449,15 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
+ # Check if publishing is blocked by a third party module
+ allowed_by_third_party_rules = await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
+ )
+ if not allowed_by_third_party_rules:
+ raise SynapseError(403, "Not allowed to publish room")
+
await self.store.set_room_is_public(room_id, making_public)
async def edit_published_appservice_room_list(
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 84169c1022..d8def45e38 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -19,7 +19,7 @@ import logging
from typing import Dict, List, Optional, Tuple
import attr
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
@@ -35,7 +35,7 @@ from synapse.types import (
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
-from synapse.util import unwrapFirstError
+from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -404,7 +404,7 @@ class E2eKeysHandler(object):
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json.loads(json_bytes)
+ key_id: json_decoder.decode(json_bytes)
}
@trace
@@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
def _one_time_keys_match(old_key_json, new_key):
- old_key = json.loads(old_key_json)
+ old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 593932adb7..29863c029b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -176,7 +176,7 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("handling received PDU: %s", pdu)
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = await self.store.get_event(
@@ -291,6 +291,14 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
+ elif missing_prevs:
+ logger.info(
+ "[%s %s] Not recursively fetching %d missing prev_events: %s",
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -335,12 +343,6 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- logger.info(
- "Event %s is missing prev_events: calculating state for a "
- "backwards extremity",
- event_id,
- )
-
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
@@ -358,7 +360,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "Requesting state at missing prev_event %s", event_id,
+ "[%s %s] Requesting state at missing prev_event %s",
+ room_id,
+ event_id,
+ p,
)
with nested_logging_context(p):
@@ -393,9 +398,7 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
+ list(state_map.values()), get_prev_content=False,
)
event_map.update(evs)
@@ -1575,8 +1578,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = await self.store.is_room_published(event.room_id)
+
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1777,9 +1787,7 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
@@ -1805,9 +1813,7 @@ class FederationHandler(BaseHandler):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -2155,9 +2161,9 @@ class FederationHandler(BaseHandler):
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
- current_auth_events = await self.store.get_events(current_state_ids)
+ auth_events_map = await self.store.get_events(current_state_ids)
current_auth_events = {
- (e.type, e.state_key): e for e in current_auth_events.values()
+ (e.type, e.state_key): e for e in auth_events_map.values()
}
try:
@@ -2173,9 +2179,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 92b7404706..b5676b248b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,19 +21,20 @@ import logging
import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
-from canonicaljson import json
-
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
+ AuthError,
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, Requester
+from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -41,31 +42,36 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-id_server_scheme = "https://"
-
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
- self.http_client = SimpleHttpClient(hs)
+ self.hs = hs
+ self.http_client = hs.get_simple_http_client()
# We create a blacklisting instance of SimpleHttpClient for contacting identity
# servers specified by clients
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
self.federation_http_client = hs.get_http_client()
- self.hs = hs
+
+ self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
+ self.trust_any_id_server_just_for_testing_do_not_use = (
+ hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
+ )
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
async def threepid_from_creds(
- self, id_server: str, creds: Dict[str, str]
+ self, id_server_url: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
- id_server: The identity server to validate 3PIDs against. Must be a
+ id_server_url: The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
creds: Dictionary containing the following keys:
* client_secret|clientSecret: A unique secret str provided by the client
@@ -90,7 +96,14 @@ class IdentityHandler(BaseHandler):
query_params = {"sid": session_id, "client_secret": client_secret}
- url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
+ url = "%s%s" % (
+ id_server_url,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
try:
data = await self.http_client.get_json(url, query_params)
@@ -99,7 +112,7 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info(
"%s returned %i for threepid validation for: %s",
- id_server,
+ id_server_url,
e.code,
creds,
)
@@ -113,7 +126,7 @@ class IdentityHandler(BaseHandler):
if "medium" in data:
return data
- logger.info("%s reported non-validated threepid: %s", id_server, creds)
+ logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
return None
async def bind_threepid(
@@ -145,14 +158,19 @@ class IdentityHandler(BaseHandler):
if id_access_token is None:
use_v2 = False
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,)
headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
else:
- bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,)
try:
# Use the blacklisting http client as this call is only to identity servers
@@ -177,7 +195,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
- data = json.loads(e.msg) # XXX WAT?
+ data = json_decoder.decode(e.msg) # XXX WAT?
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
@@ -239,9 +257,6 @@ class IdentityHandler(BaseHandler):
True on success, otherwise False if the identity
server doesn't support unbinding
"""
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
-
content = {
"mxid": mxid,
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
@@ -250,6 +265,7 @@ class IdentityHandler(BaseHandler):
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
+ url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method=b"POST",
@@ -259,6 +275,15 @@ class IdentityHandler(BaseHandler):
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,)
+
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
@@ -372,9 +397,28 @@ class IdentityHandler(BaseHandler):
return session_id
+ def rewrite_id_server_url(self, url: str, add_https=False) -> str:
+ """Given an identity server URL, optionally add a protocol scheme
+ before rewriting it according to the rewrite_identity_server_urls
+ config option.
+
+ If add_https is set, "https://" is prepended to the URL before the
+ lookup. Returns the rewritten URL if a rewrite rule matches, otherwise
+ the (possibly scheme-prefixed) URL unchanged.
+ """
+ rewritten_url = url
+ if add_https:
+ rewritten_url = "https://" + rewritten_url
+
+ rewritten_url = self.rewrite_identity_server_urls.get(
+ rewritten_url, rewritten_url
+ )
+ logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
+ return rewritten_url
+
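To make the rewrite behaviour concrete, assuming a hypothetical config mapping of {"https://matrix.org": "http://localhost:8090"}:

# handler.rewrite_id_server_url("matrix.org", add_https=True)
#   -> "http://localhost:8090"   (https:// prepended, then rule applied)
# handler.rewrite_id_server_url("https://other.example.com")
#   -> "https://other.example.com"   (no matching rule; returned unchanged)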
async def requestEmailToken(
self,
- id_server: str,
+ id_server_url: str,
email: str,
client_secret: str,
send_attempt: int,
@@ -385,7 +429,7 @@ class IdentityHandler(BaseHandler):
validation.
Args:
- id_server: The identity server to proxy to
+ id_server_url: The identity server to proxy to
email: The email to send the message to
+ client_secret: The unique client_secret sent by the client
send_attempt: Which attempt this is
@@ -399,6 +443,11 @@ class IdentityHandler(BaseHandler):
"client_secret": client_secret,
"send_attempt": send_attempt,
}
+
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
if next_link:
params["next_link"] = next_link
@@ -413,7 +462,8 @@ class IdentityHandler(BaseHandler):
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
+ "%s/_matrix/identity/api/v1/validate/email/requestToken"
+ % (id_server_url,),
params,
)
return data
@@ -425,7 +475,7 @@ class IdentityHandler(BaseHandler):
async def requestMsisdnToken(
self,
- id_server: str,
+ id_server_url: str,
country: str,
phone_number: str,
client_secret: str,
@@ -436,7 +486,7 @@ class IdentityHandler(BaseHandler):
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
- id_server: The identity server to proxy to
+ id_server_url: The identity server to proxy to
country: The country code of the phone number
phone_number: The number to send the message to
+ client_secret: The unique client_secret sent by the client
@@ -464,9 +514,13 @@ class IdentityHandler(BaseHandler):
"details and update your config file."
)
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
+ "%s/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ % (id_server_url,),
params,
)
except HttpResponseException as e:
@@ -560,6 +614,86 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
+ # TODO: The following two methods are used for proxying IS requests using
+ # the CS API. They should be consolidated with those in RoomMemberHandler
+ # https://github.com/matrix-org/synapse-dinsic/issues/25
+
+ async def proxy_lookup_3pid(
+ self, id_server: str, medium: str, address: str
+ ) -> JsonDict:
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server: The server name (including port, if required)
+ of the identity server to use.
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
+ {"medium": medium, "address": address},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
+ async def proxy_bulk_lookup_3pid(
+ self, id_server: str, threepids: List[List[str]]
+ ) -> JsonDict:
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server: The server name (including port, if required)
+ of the identity server to use.
+ threepids: The third party identifiers to look up, as a list
+ of two-element lists of strings ([medium, address]).
+
+ Returns:
+ The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
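A caller proxying a single lookup through these helpers might look like this sketch ("matrix.org" is only an example identity server):

async def find_mxid(identity_handler, medium: str, address: str):
    result = await identity_handler.proxy_lookup_3pid("matrix.org", medium, address)
    # Per the v1 association lookup API, a successful match includes "mxid".
    return result.get("mxid")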
async def lookup_3pid(
self,
id_server: str,
@@ -580,10 +714,13 @@ class IdentityHandler(BaseHandler):
Returns:
the matrix ID of the 3pid, or None if it is not recognized.
"""
+ # Rewrite id_server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
if id_access_token is not None:
try:
results = await self._lookup_3pid_v2(
- id_server, id_access_token, medium, address
+ id_server_url, id_access_token, medium, address
)
return results
@@ -601,16 +738,17 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return await self._lookup_3pid_v1(id_server, medium, address)
+ return await self._lookup_3pid_v1(id_server, id_server_url, medium, address)
async def _lookup_3pid_v1(
- self, id_server: str, medium: str, address: str
+ self, id_server: str, id_server_url: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server: The server name (including port, if required)
of the identity server to use.
+ id_server_url: The actual, reachable URL of the id server
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
@@ -618,8 +756,8 @@ class IdentityHandler(BaseHandler):
the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
{"medium": medium, "address": address},
)
@@ -636,13 +774,12 @@ class IdentityHandler(BaseHandler):
return None
async def _lookup_3pid_v2(
- self, id_server: str, id_access_token: str, medium: str, address: str
+ self, id_server_url: str, id_access_token: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
- id_server: The server name (including port, if required)
- of the identity server to use.
+ id_server_url: The protocol scheme and domain of the id server
id_access_token: The access token to authenticate to the identity server with
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
@@ -652,8 +789,8 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
+ hash_details = await self.http_client.get_json(
+ "%s/_matrix/identity/v2/hash_details" % (id_server_url,),
{"access_token": id_access_token},
)
except TimeoutError:
@@ -661,15 +798,14 @@ class IdentityHandler(BaseHandler):
if not isinstance(hash_details, dict):
logger.warning(
- "Got non-dict object when checking hash details of %s%s: %s",
- id_server_scheme,
- id_server,
+ "Got non-dict object when checking hash details of %s: %s",
+ id_server_url,
hash_details,
)
raise SynapseError(
400,
- "Non-dict object from %s%s during v2 hash_details request: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Non-dict object from %s during v2 hash_details request: %s"
+ % (id_server_url, hash_details),
)
# Extract information from hash_details
@@ -683,8 +819,8 @@ class IdentityHandler(BaseHandler):
):
raise SynapseError(
400,
- "Invalid hash details received from identity server %s%s: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Invalid hash details received from identity server %s: %s"
+ % (id_server_url, hash_details),
)
# Check if any of the supported lookup algorithms are present
@@ -706,7 +842,7 @@ class IdentityHandler(BaseHandler):
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
- id_server,
+ id_server_url,
supported_lookup_algorithms,
)
raise SynapseError(
@@ -719,8 +855,8 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = await self.blacklisting_http_client.post_json_get_json(
- "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
+ lookup_results = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/v2/lookup" % (id_server_url,),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
@@ -805,15 +941,17 @@ class IdentityHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url,
}
+ # Rewrite the identity server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
- base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+ base_url = "%s/_matrix/identity" % (id_server_url,)
if id_access_token:
- key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (
+ id_server_url,
)
# Attempt a v2 lookup
@@ -832,9 +970,8 @@ class IdentityHandler(BaseHandler):
raise e
if data is None:
- key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+ id_server_url,
)
url = base_url + "/api/v1/store-invite"
@@ -846,10 +983,7 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
- "Error trying to call /store-invite on %s%s: %s",
- id_server_scheme,
- id_server,
- e,
+ "Error trying to call /store-invite on %s: %s", id_server_url, e,
)
if data is None:
@@ -862,10 +996,9 @@ class IdentityHandler(BaseHandler):
)
except HttpResponseException as e:
logger.warning(
- "Error calling /store-invite on %s%s with fallback "
+ "Error calling /store-invite on %s with fallback "
"encoding: %s",
- id_server_scheme,
- id_server,
+ id_server_url,
e,
)
raise e
@@ -886,6 +1019,39 @@ class IdentityHandler(BaseHandler):
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
+ async def bind_email_using_internal_sydent_api(
+ self, id_server_url: str, email: str, user_id: str,
+ ):
+ """Bind an email to a fully qualified user ID using the internal API of an
+ instance of Sydent.
+
+ Args:
+ id_server_url: The URL of the Sydent instance
+ email: The email address to bind
+ user_id: The user ID to bind the email to
+
+ Raises:
+ HttpResponseException: On a non-2xx HTTP response.
+ """
+ # Extract the domain name from the IS URL as we store IS domains instead of URLs
+ id_server = urllib.parse.urlparse(id_server_url).hostname
+
+ # id_server_url is assumed to have no trailing slashes
+ url = id_server_url + "/_matrix/identity/internal/bind"
+ body = {
+ "address": email,
+ "medium": "email",
+ "mxid": user_id,
+ }
+
+ # Bind the threepid
+ await self.http_client.post_json_get_json(url, body)
+
+ # Remember where we bound the threepid
+ await self.store.add_user_bound_threepid(
+ user_id=user_id, medium="email", address=email, id_server=id_server,
+ )
+
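A hedged usage sketch for the internal Sydent bind, e.g. from a registration flow (the Sydent URL is an assumption, and per the docstring must carry no trailing slash):

async def bind_on_register(identity_handler, user_id: str, email: str):
    # Assumed internal Sydent instance; binds the verified email to user_id.
    await identity_handler.bind_email_using_internal_sydent_api(
        "https://sydent.internal.example.com", email, user_id
    )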
def create_id_access_token_header(id_access_token: str) -> List[str]:
"""Create an Authorization header for passing to SimpleHttpClient as the header value
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2643438e84..d5b12403f9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -17,7 +17,7 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet.interfaces import IDelayedCall
@@ -55,6 +55,7 @@ from synapse.types import (
UserID,
create_requester,
)
+from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
@@ -63,6 +64,7 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
+ from synapse.events.third_party_rules import ThirdPartyEventRules
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -396,7 +398,9 @@ class EventCreationHandler(object):
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
- self.third_party_event_rules = hs.get_third_party_event_rules()
+ self.third_party_event_rules = (
+ self.hs.get_third_party_event_rules()
+ ) # type: ThirdPartyEventRules
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
@@ -667,14 +671,14 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
- prev_state = await self.deduplicate_state_event(event, context)
- if prev_state is not None:
+ prev_event = await self.deduplicate_state_event(event, context)
+ if prev_event is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
event.event_id,
- prev_state.event_id,
+ prev_event.event_id,
)
- return prev_state
+ return await self.store.get_stream_id_for_event(prev_event.event_id)
return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
@@ -682,27 +686,32 @@ class EventCreationHandler(object):
async def deduplicate_state_event(
self, event: EventBase, context: EventContext
- ) -> None:
+ ) -> Optional[EventBase]:
"""
Checks whether event is in the latest resolved state in context.
- If so, returns the version of the event in context.
- Otherwise, returns None.
+ Args:
+ event: The event to check for duplication.
+ context: The event context.
+
+ Returns:
+ The previous version of the event, if it is found in the
+ event context. Otherwise, None.
"""
prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
- return
+ return None
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
- return
+ return None
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
return prev_event
- return
+ return None
async def create_and_send_nonmember_event(
self,
@@ -859,7 +868,7 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db
try:
dump = frozendict_json_encoder.encode(event.content)
- json.loads(dump)
+ json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
@@ -891,9 +900,7 @@ class EventCreationHandler(object):
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
+ await self.store.remove_push_actions_from_staging(event.event_id)
raise
async def _validate_canonical_alias(
@@ -957,7 +964,7 @@ class EventCreationHandler(object):
allow_none=True,
)
- is_admin_redaction = (
+ is_admin_redaction = bool(
original_event and event.sender != original_event.sender
)
@@ -1077,8 +1084,8 @@ class EventCreationHandler(object):
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_map = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index fa5ee5de8f..dd3703cbd2 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -38,8 +37,8 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -123,9 +122,7 @@ class OidcHandler:
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_error.html"]
- )[0]
+ self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
@@ -370,7 +367,7 @@ class OidcHandler:
# and check for an error field. If not, we respond with a generic
# error message.
try:
- resp = json.loads(resp_body.decode("utf-8"))
+ resp = json_decoder.decode(resp_body.decode("utf-8"))
error = resp["error"]
description = resp.get("error_description", error)
except (ValueError, KeyError):
@@ -387,7 +384,7 @@ class OidcHandler:
# Since it is a not a 5xx code, body should be a valid JSON. It will
# raise if not.
- resp = json.loads(resp_body.decode("utf-8"))
+ resp = json_decoder.decode(resp_body.decode("utf-8"))
if "error" in resp:
error = resp["error"]
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5387b3724f..24e1940ee5 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -33,13 +33,13 @@ from typing_extensions import ContextManager
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
+from synapse.api.presence import UserPresenceState
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
-from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 31a2e5ea18..4f3198896e 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,11 @@
# limitations under the License.
import logging
+from typing import List
+
+from signedjson.sign import sign_json
+
+from twisted.internet import defer, reactor
from synapse.api.errors import (
AuthError,
@@ -23,6 +29,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester, get_domain_from_id
@@ -42,6 +49,8 @@ class BaseProfileHandler(BaseHandler):
subclass MasterProfileHandler
"""
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs):
super(BaseProfileHandler, self).__init__(hs)
@@ -52,6 +61,92 @@ class BaseProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+ self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._assign_profile_replication_batches)
+ reactor.callWhenRunning(self._replicate_profiles)
+ # Add a looping call to replicate_profiles: this handles retries
+ # if the replication is unsuccessful when the user updated their
+ # profile.
+ self.clock.looping_call(
+ self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ @defer.inlineCallbacks
+ def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = yield self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ @defer.inlineCallbacks
+ def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = yield self.store.get_replication_hosts()
+ latest_batch = yield self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ yield self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ @defer.inlineCallbacks
+ def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = yield self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ yield defer.ensureDeferred(
+ self.http_client.post_json_get_json(url, signed_body)
+ )
+ yield defer.ensureDeferred(
+ self.store.update_replication_batch_for_host(host, batchnum)
+ )
+ logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
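For reference, the (pre-signing) body POSTed to each replication target has the following shape; `sign_json` then adds a "signatures" object keyed by this server's name:

# {
#     "batchnum": 42,
#     "batch": {
#         "@alice:example.com": {
#             "display_name": "Alice",
#             "avatar_url": "mxc://example.com/abc123"
#         },
#         "@bob:example.com": null    # inactive (hidden/deactivated) profile
#     },
#     "origin_server": "example.com"
# }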
async def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
@@ -148,7 +243,7 @@ class BaseProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
- if not by_admin and target_user != requester.user:
+ if not by_admin and requester and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
@@ -168,13 +263,23 @@ class BaseProfileHandler(BaseHandler):
if new_displayname == "":
new_displayname = None
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
# If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms.
# This must be done by the target user themselves.
if by_admin:
requester = create_requester(target_user)
- await self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ await self.store.set_profile_displayname(
+ target_user.localpart, new_displayname, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -184,6 +289,50 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ @defer.inlineCallbacks
+ def set_active(
+ self, users: List[UserID], active: bool, hide: bool,
+ ):
+ """
+ Sets the 'active' flag on a set of user profiles. If set to false, the
+ accounts are considered deactivated or hidden.
+
+ If 'hide' is true, then we interpret active=False as a request to try to
+ hide the users rather than deactivating them. This means withholding
+ the profiles from replication (and marking them as inactive) rather
+ than clearing the profiles from the HS DB.
+
+ Note that unlike set_displayname and set_avatar_url, this does *not*
+ perform authorization checks! This is because the only place it's used
+ currently is in account deactivation where we've already done these
+ checks anyway.
+
+ Args:
+ users: The users to modify
+ active: Whether to set the user to active or inactive
+ hide: Whether to hide the user (withhold from replication). If
+ False and active is False, the user will have their profile
+ erased
+
+ Returns:
+ Deferred
+ """
+ if len(self.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ yield self.store.set_profiles_active(users, active, hide, new_batchnum)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
async def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user):
try:
@@ -233,11 +382,51 @@ class BaseProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
+ # Enforce a max avatar size if one is defined
+ if self.max_avatar_size or self.allowed_avatar_mimetypes:
+ media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url)
+
+ # Check that this media exists locally
+ media_info = await self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
        # Same as in set_displayname
if by_admin:
requester = create_requester(target_user)
- await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ await self.store.set_profile_avatar_url(
+ target_user.localpart, new_avatar_url, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -247,6 +436,23 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
async def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
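
The profile changes above all hang off the same batch-based replication scheme: each write stamps the profile row with a new batch number, and _replicate_host_profile_batch later pushes any batches a configured identity server has not acknowledged, signed with the homeserver's key. A minimal sketch of the numbering logic, with hypothetical helper names standing in for the store methods:

    def next_batchnum(cur_batchnum):
        # Mirrors the "0 if cur_batchnum is None else cur_batchnum + 1" logic
        # used by set_displayname, set_avatar_url and set_active above.
        return 0 if cur_batchnum is None else cur_batchnum + 1

    def pending_batchnums(latest_batchnum, host_acked_batchnum):
        # Batches a host still needs, oldest first; a failed push is simply
        # retried when the looping replication call next comes around.
        start = 0 if host_acked_batchnum is None else host_acked_batchnum + 1
        return range(start, latest_batchnum + 1)
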
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c94209ab3d..e17b402e68 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -47,11 +47,14 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -67,8 +70,18 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
async def check_username(
- self, localpart, guest_access_token=None, assigned_user_id=None
+ self, localpart, guest_access_token=None, assigned_user_id=None,
):
+ """
+
+ Args:
+ localpart (str|None): The user's localpart
+ guest_access_token (str|None): A guest's access token
+ assigned_user_id (str|None): An existing User ID for this user if pre-calculated
+
+ Returns:
+ Deferred
+ """
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -111,6 +124,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
+
+ # Retrieve guest user information from provided access token
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
@@ -142,6 +157,7 @@ class RegistrationHandler(BaseHandler):
address=None,
bind_emails=[],
by_admin=False,
+ shadow_banned=False,
):
"""Registers a new client on the server.
@@ -159,6 +175,7 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
+ shadow_banned (bool): Shadow-ban the created user.
Returns:
str: user_id
Raises:
@@ -194,8 +211,14 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
+ shadow_banned=shadow_banned,
)
+ if default_display_name:
+ await self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
@@ -224,6 +247,11 @@ class RegistrationHandler(BaseHandler):
make_guest=make_guest,
create_profile_with_displayname=default_display_name,
address=address,
+ shadow_banned=shadow_banned,
+ )
+
+ await self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
)
# Successfully registered
@@ -259,7 +287,15 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- await self._register_email_threepid(user_id, threepid_dict, None)
+ await self.register_email_threepid(user_id, threepid_dict, None)
+
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ await self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ await self.profile_handler.set_active([user], False, True)
return user_id
@@ -453,7 +489,10 @@ class RegistrationHandler(BaseHandler):
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart, as_token):
+ async def appservice_register(
+ self, user_localpart, as_token, password_hash, display_name
+ ):
+ # FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -470,12 +509,25 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
+ display_name = display_name or user.localpart
+
await self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
+ )
+
+ await self.profile_handler.set_displayname(
+ user, None, display_name, by_admin=True
)
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = await self.store.get_profileinfo(user_localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
return user_id
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
@@ -502,6 +554,49 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
+ async def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+ other registration UI auth flows (e.g. validated 3pids)
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+ # that the displayname is correctly set by the master server
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
+ async def _generate_user_id(self):
+ if self._next_generated_user_id is None:
+ with await self._generate_user_id_linearizer.queue(()):
+ if self._next_generated_user_id is None:
+ self._next_generated_user_id = (
+ await self.store.find_next_generated_user_id_localpart()
+ )
+
+ id = self._next_generated_user_id
+ self._next_generated_user_id += 1
+ return str(id)
+
def check_registration_ratelimit(self, address):
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
@@ -529,6 +624,7 @@ class RegistrationHandler(BaseHandler):
admin=False,
user_type=None,
address=None,
+ shadow_banned=False,
):
"""Register user in the datastore.
@@ -546,6 +642,7 @@ class RegistrationHandler(BaseHandler):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration.
+ shadow_banned (bool): Whether to shadow-ban the user
Returns:
Awaitable
@@ -561,6 +658,7 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
+ shadow_banned=shadow_banned,
)
else:
return self.store.register_user(
@@ -572,6 +670,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
+ shadow_banned=shadow_banned,
)
async def register_device(
@@ -643,6 +742,7 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
+
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
@@ -650,7 +750,32 @@ class RegistrationHandler(BaseHandler):
):
await self.store.upsert_monthly_active_user(user_id)
- await self._register_email_threepid(user_id, threepid, access_token)
+ await self.register_email_threepid(user_id, threepid, access_token)
+
+ if self.hs.config.bind_new_user_emails_to_sydent:
+ # Attempt to call Sydent's internal bind API on the given identity server
+ # to bind this threepid
+ id_server_url = self.hs.config.bind_new_user_emails_to_sydent
+
+ logger.debug(
+ "Attempting the bind email of %s to identity server: %s using "
+ "internal Sydent bind API.",
+ user_id,
+ self.hs.config.bind_new_user_emails_to_sydent,
+ )
+
+ try:
+ await self.identity_handler.bind_email_using_internal_sydent_api(
+ id_server_url, threepid["address"], user_id
+ )
+ except Exception as e:
+ logger.warning(
+ "Failed to bind email of '%s' to Sydent instance '%s' ",
+ "using Sydent internal bind API: %s",
+ user_id,
+ id_server_url,
+ e,
+ )
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
@@ -671,7 +796,7 @@ class RegistrationHandler(BaseHandler):
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def _register_email_threepid(self, user_id, threepid, token):
+ async def register_email_threepid(self, user_id, threepid, token):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
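
Several of the registration paths above are driven by configuration that is only visible through its accessors. A hedged sketch of the expected shape, where the key names are taken from the .get(...) calls in the diff and every value is purely illustrative:

    # Hedged sketch of the config consumed above; key names come from the
    # accessors in the diff, values are invented for the example.
    config = {
        "shadow_server": {
            "hs": "shadow.example.com",              # server_name for shadow MXIDs
            "hs_url": "https://shadow.example.com",  # base URL for /register calls
            "as_token": "SECRET_AS_TOKEN",           # appservice token for auth
        },
        # If set, validated emails of new users are also bound via Sydent's
        # internal bind API at this base URL.
        "bind_new_user_emails_to_sydent": "https://sydent.example.com",
    }
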
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a8545255b1..b1dd3af7b1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -22,7 +22,7 @@ import logging
import math
import string
from collections import OrderedDict
-from typing import Awaitable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
@@ -32,11 +32,14 @@ from synapse.api.constants import (
RoomEncryptionAlgorithms,
)
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
+ JsonDict,
Requester,
RoomAlias,
RoomID,
@@ -53,6 +56,9 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -61,7 +67,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
@@ -92,7 +98,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
- }
+ } # type: Dict[str, Dict[str, Any]]
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@@ -215,6 +221,9 @@ class RoomCreationHandler(BaseHandler):
old_room_state = await tombstone_context.get_current_state_ids()
+ # We know the tombstone event isn't an outlier so it has current state.
+ assert old_room_state is not None
+
# update any aliases
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
@@ -332,7 +341,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -528,17 +549,21 @@ class RoomCreationHandler(BaseHandler):
logger.error("Unable to send updated alias events in new room: %s", e)
async def create_room(
- self, requester, config, ratelimit=True, creator_join_profile=None
+ self,
+ requester: Requester,
+ config: JsonDict,
+ ratelimit: bool = True,
+ creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[dict, int]:
""" Creates a new room.
Args:
- requester (synapse.types.Requester):
+ requester:
The user who requested the room creation.
- config (dict) : A dict of configuration options.
- ratelimit (bool): set to False to disable the rate limiter
+ config : A dict of configuration options.
+ ratelimit: set to False to disable the rate limiter
- creator_join_profile (dict|None):
+ creator_join_profile:
Set to override the displayname and avatar for the creating
user in this room. If unset, displayname and avatar will be
derived from the user's profile. If set, should contain the
@@ -578,8 +603,14 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -601,6 +632,7 @@ class RoomCreationHandler(BaseHandler):
Codes.UNSUPPORTED_ROOM_VERSION,
)
+ room_alias = None
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
@@ -611,10 +643,7 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
- else:
- room_alias = None
- invite_list = config.get("invite", [])
for i in invite_list:
try:
uid = UserID.from_string(i)
@@ -636,8 +665,6 @@ class RoomCreationHandler(BaseHandler):
% (user_id,),
)
- invite_3pid_list = config.get("invite_3pid", [])
-
visibility = config.get("visibility", None)
is_public = visibility == "public"
@@ -645,6 +672,15 @@ class RoomCreationHandler(BaseHandler):
creator_id=user_id, is_public=is_public, room_version=room_version,
)
+ # Check whether this visibility value is blocked by a third party module
+ allowed_by_third_party_rules = await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
+ )
+ if not allowed_by_third_party_rules:
+ raise SynapseError(403, "Room visibility value not allowed.")
+
directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
await directory_handler.create_association(
@@ -739,6 +775,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -754,6 +791,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
id_access_token=id_access_token,
)
@@ -771,23 +809,30 @@ class RoomCreationHandler(BaseHandler):
async def _send_events_for_new_room(
self,
- creator, # A Requester object.
- room_id,
- preset_config,
- invite_list,
- initial_state,
- creation_content,
- room_alias=None,
- power_level_content_override=None, # Doesn't apply when initial state has power level state event content
- creator_join_profile=None,
+ creator: Requester,
+ room_id: str,
+ preset_config: str,
+ invite_list: List[str],
+ initial_state: StateMap,
+ creation_content: JsonDict,
+ room_alias: Optional[RoomAlias] = None,
+ power_level_content_override: Optional[JsonDict] = None,
+ creator_join_profile: Optional[JsonDict] = None,
) -> int:
"""Sends the initial events into a new room.
+ `power_level_content_override` doesn't apply when initial state has
+ power level state event content.
+
Returns:
The stream_id of the last event persisted.
"""
- def create(etype, content, **kwargs):
+ creator_id = creator.user.to_string()
+
+ event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
+
+ def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -795,7 +840,7 @@ class RoomCreationHandler(BaseHandler):
return e
- async def send(etype, content, **kwargs) -> int:
+ async def send(etype: str, content: JsonDict, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
(
@@ -808,10 +853,6 @@ class RoomCreationHandler(BaseHandler):
config = self._presets_dict[preset_config]
- creator_id = creator.user.to_string()
-
- event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
-
creation_content.update({"creator": creator_id})
await send(etype=EventTypes.Create, content=creation_content)
@@ -823,6 +864,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=False,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
@@ -852,7 +894,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
- }
+ } # type: JsonDict
if config["original_invitees_have_ops"]:
for invitee in invite_list:
@@ -906,7 +948,7 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
- self, creator_id: str, is_public: str, room_version: RoomVersion,
+ self, creator_id: str, is_public: bool, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -930,23 +972,30 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler(object):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
- async def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(
+ self,
+ user: UserID,
+ room_id: str,
+ event_id: str,
+ limit: int,
+ event_filter: Optional[Filter],
+ ) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
- user (UserID)
- room_id (str)
- event_id (str)
- limit (int): The maximum number of events to return in total
+ user
+ room_id
+ event_id
+ limit: The maximum number of events to return in total
(excluding state).
- event_filter (Filter|None): the filter to apply to the events returned
+ event_filter: the filter to apply to the events returned
(excluding the target event_id)
Returns:
@@ -1033,12 +1082,18 @@ class RoomContextHandler(object):
class RoomEventSource(object):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
async def get_new_events(
- self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
- ):
+ self,
+ user: UserID,
+ from_key: str,
+ limit: int,
+ room_ids: List[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
+ ) -> Tuple[List[EventBase], str]:
# We just ignore the key for now.
to_key = self.get_current_key()
@@ -1096,7 +1151,7 @@ class RoomShutdownHandler(object):
)
DEFAULT_ROOM_NAME = "Content Violation Notification"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.room_member_handler = hs.get_room_member_handler()
self._room_creation_handler = hs.get_room_creation_handler()
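
Both new checks in create_room delegate to pluggable modules. A sketch of checker classes matching the call sites above; the method names and arguments mirror the calls, but the policies themselves are invented for illustration:

    class ExampleSpamChecker:
        def user_may_create_room(
            self, user_id, invite_list, third_party_invite_list, cloning
        ):
            # e.g. always allow room upgrades (cloning); otherwise require the
            # creator to invite at least one other user or 3pid.
            return cloning or bool(invite_list) or bool(third_party_invite_list)


    class ExampleThirdPartyRules:
        async def check_visibility_can_be_modified(self, room_id, new_visibility):
            # e.g. never allow rooms to be published to the public directory.
            return new_visibility != "public"
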
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0cd59bce3b..1486584f25 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -65,6 +65,7 @@ class RoomMemberHandler(object):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.identity_handler = hs.get_handlers().identity_handler
self.member_linearizer = Linearizer(name="member")
@@ -283,8 +284,10 @@ class RoomMemberHandler(object):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
+ """Update a user's membership in a room"""
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -298,6 +301,7 @@ class RoomMemberHandler(object):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
@@ -314,6 +318,7 @@ class RoomMemberHandler(object):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
content_specified = bool(content)
@@ -378,8 +383,15 @@ class RoomMemberHandler(object):
)
block_invite = True
+ is_published = await self.store.is_room_published(room_id)
+
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(),
+ target.to_string(),
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -457,6 +469,25 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ inviter = await self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if is_host_in_room:
time_now_s = self.clock.time()
(
@@ -718,7 +749,7 @@ class RoomMemberHandler(object):
guest_access = await self.store.get_event(guest_access_id)
- return (
+ return bool(
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -773,6 +804,7 @@ class RoomMemberHandler(object):
id_server: str,
requester: Requester,
txn_id: Optional[str],
+ new_room: bool = False,
id_access_token: Optional[str] = None,
) -> int:
if self.config.block_non_admin_invites:
@@ -796,6 +828,16 @@ class RoomMemberHandler(object):
Codes.FORBIDDEN,
)
+ can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
+
if not self._enable_lookup:
raise SynapseError(
403, "Looking up third-party identifiers is denied from this server"
@@ -805,6 +847,19 @@ class RoomMemberHandler(object):
id_server, medium, address, id_access_token
)
+ is_published = await self.store.is_room_published(room_id)
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
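
The membership paths gain matching hooks. A sketch of the remaining spam-checker methods implied by the call sites above (signatures mirror the calls; the policies are illustrative):

    class ExampleSpamChecker:
        def user_may_invite(
            self,
            inviter_userid,
            invitee_userid,
            third_party_invite,   # None, or {"medium": ..., "address": ...}
            room_id,
            new_room,
            published_room,
        ):
            # e.g. block 3pid invites into rooms on the public room list.
            return not (third_party_invite and published_room)

        def user_may_join_room(self, userid, room_id, is_invited):
            # e.g. only allow joining rooms the user has been invited to.
            return is_invited
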
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 4d245b618b..5d34989f21 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index a011e9fe29..9146dc1a3b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -16,13 +16,12 @@
import logging
from typing import Any
-from canonicaljson import json
-
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
- resp_body = json.loads(data.decode("utf-8"))
+ resp_body = json_decoder.decode(data.decode("utf-8"))
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8aeb70cdec..dad01a8e56 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -19,7 +19,7 @@ import urllib
from io import BytesIO
import treq
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -47,6 +47,7 @@ from synapse.http import (
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
logger = logging.getLogger(__name__)
@@ -391,7 +392,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body.decode("utf-8"))
+ return json_decoder.decode(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -433,7 +434,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body.decode("utf-8"))
+ return json_decoder.decode(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -463,7 +464,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
body = await self.get_raw(uri, args, headers=headers)
- return json.loads(body.decode("utf-8"))
+ return json_decoder.decode(body.decode("utf-8"))
async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
@@ -506,7 +507,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body.decode("utf-8"))
+ return json_decoder.decode(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
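
The json.loads call sites in this file (and in the files below) move to a shared decoder. synapse.util is assumed here to expose module-level encoder/decoder instances along these lines, so hot paths avoid re-creating decoder objects on every call; roughly:

    import json

    # Roughly what synapse.util is assumed to provide: single shared instances
    # reused by every call site instead of canonicaljson's re-exported module.
    json_decoder = json.JSONDecoder()
    json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))

    parsed = json_decoder.decode('{"m.server": "matrix.example.org:443"}')
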
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 89a3b041ce..f794315deb 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
import random
import time
@@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from synapse.logging.context import make_deferred_yieldable
-from synapse.util import Clock
+from synapse.util import Clock, json_decoder
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.metrics import Measure
@@ -181,7 +180,7 @@ class WellKnownResolver(object):
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,))
- parsed_body = json.loads(body.decode("utf-8"))
+ parsed_body = json_decoder.decode(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body)
result = parsed_body["m.server"].encode("ascii")
diff --git a/synapse/http/server.py b/synapse/http/server.py
index ffe6cfa09e..8d791bd2ca 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,12 +22,13 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
-from canonicaljson import encode_canonical_json, encode_pretty_printed_json
+from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
+from zope.interface import implementer
-from twisted.internet import defer
+from twisted.internet import defer, interfaces
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request
@@ -499,6 +500,90 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
pass
+@implementer(interfaces.IPushProducer)
+class _ByteProducer:
+ """
+ Iteratively write bytes to the request.
+ """
+
+ # The minimum number of bytes for each chunk. Note that the last chunk will
+ # usually be smaller than this.
+ min_chunk_size = 1024
+
+ def __init__(
+ self, request: Request, iterator: Iterator[bytes],
+ ):
+ self._request = request
+ self._iterator = iterator
+ self._paused = False
+
+ # Register the producer and start producing data.
+ self._request.registerProducer(self, True)
+ self.resumeProducing()
+
+ def _send_data(self, data: List[bytes]) -> None:
+ """
+ Send a list of bytes as a chunk of a response.
+ """
+ if not data:
+ return
+ self._request.write(b"".join(data))
+
+ def pauseProducing(self) -> None:
+ self._paused = True
+
+ def resumeProducing(self) -> None:
+ # Producing may have been stopped in the meantime (note that this
+ # method can be re-entered after calling write).
+ if not self._request:
+ return
+
+ self._paused = False
+
+ # Write until there's backpressure telling us to stop.
+ while not self._paused:
+ # Get the next chunk and write it to the request.
+ #
+ # The output of the JSON encoder is buffered and coalesced until
+ # min_chunk_size is reached. This is because JSON encoders produce
+ # very small output per iteration and the Request object converts
+ # each call to write() to a separate chunk. Without this there would
+ # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
+ #
+ # Note that buffer stores a list of bytes (instead of appending to
+ # bytes) to hopefully avoid many allocations.
+ buffer = []
+ buffered_bytes = 0
+ while buffered_bytes < self.min_chunk_size:
+ try:
+ data = next(self._iterator)
+ buffer.append(data)
+ buffered_bytes += len(data)
+ except StopIteration:
+ # The entire JSON object has been serialized, write any
+ # remaining data, finalize the producer and the request, and
+ # clean-up any references.
+ self._send_data(buffer)
+ self._request.unregisterProducer()
+ self._request.finish()
+ self.stopProducing()
+ return
+
+ self._send_data(buffer)
+
+ def stopProducing(self) -> None:
+ # Clear a circular reference.
+ self._request = None
+
+
+def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
+ """
+ Encode an object into JSON. Returns an iterator of bytes.
+ """
+ for chunk in json_encoder.iterencode(json_object):
+ yield chunk.encode("utf-8")
+
+
def respond_with_json(
request: Request,
code: int,
@@ -533,15 +618,22 @@ def respond_with_json(
return None
if pretty_print:
- json_bytes = encode_pretty_printed_json(json_object) + b"\n"
+ encoder = iterencode_pretty_printed_json
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
- # canonicaljson already encodes to bytes
- json_bytes = encode_canonical_json(json_object)
+ encoder = iterencode_canonical_json
else:
- json_bytes = json_encoder.encode(json_object).encode("utf-8")
+ encoder = _encode_json_bytes
+
+ request.setResponseCode(code)
+ request.setHeader(b"Content-Type", b"application/json")
+ request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
- return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
+ if send_cors:
+ set_cors_headers(request)
+
+ _ByteProducer(request, encoder(json_object))
+ return NOT_DONE_YET
def respond_with_json_bytes(
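
The switch from whole-body encode_* calls to iterencode_* plus _ByteProducer exists because JSON iterencoders yield many tiny fragments, and Twisted frames every write() as its own HTTP chunk; coalescing up to min_chunk_size keeps the chunk-framing overhead bounded. A standalone sketch of the buffering behaviour:

    import json

    def coalesce(fragments, min_chunk_size=1024):
        # Same idea as _ByteProducer's inner loop: buffer tiny fragments and
        # emit pieces of at least min_chunk_size (the last may be smaller).
        buffer, buffered = [], 0
        for fragment in fragments:
            buffer.append(fragment)
            buffered += len(fragment)
            if buffered >= min_chunk_size:
                yield b"".join(buffer)
                buffer, buffered = [], 0
        if buffer:
            yield b"".join(buffer)

    fragments = (c.encode("utf-8") for c in json.JSONEncoder().iterencode({"a": [1, 2]}))
    chunks = list(coalesce(fragments, min_chunk_size=8))
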
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a34e5ead88..53acba56cb 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -17,9 +17,8 @@
import logging
-from canonicaljson import json
-
from synapse.api.errors import Codes, SynapseError
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return None
try:
- content = json.loads(content_bytes.decode("utf-8"))
+ content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 21dbd9f415..abe532d350 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -177,6 +177,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.config import ConfigError
+from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.http.site import SynapseRequest
@@ -499,7 +500,9 @@ def start_active_span_from_edu(
if opentracing is None:
return _noop_context_manager()
- carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+ carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
+ "opentracing", {}
+ )
context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
_references = [
opentracing.child_of(span_context_from_string(x))
@@ -699,7 +702,7 @@ def span_context_from_string(carrier):
Returns:
The active span context decoded from a string.
"""
- carrier = json.loads(carrier)
+ carrier = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index f766d16db6..4cd7932e5b 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -175,7 +175,7 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
It returns a Deferred which completes when the function completes, but it doesn't
follow the synapse logcontext rules, which makes it appropriate for passing to
clock.looping_call and friends (or for firing-and-forgetting in the middle of a
- normal synapse inlineCallbacks function).
+ normal synapse async function).
Args:
desc: a description for this background process type
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index c2fb757d9a..5600a4107a 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,12 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.internet import defer
+from synapse.events import EventBase
+from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID, create_requester
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
"""
This package defines the 'stable' API which can be used by extension modules which
@@ -31,6 +37,50 @@ __all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"
logger = logging.getLogger(__name__)
+class PublicRoomListManager:
+ """Contains methods for adding to, removing from and querying whether a room
+ is in the public room list.
+
+ Args:
+ hs: The Homeserver object
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastore()
+
+ async def room_is_in_public_room_list(self, room_id: str) -> bool:
+ """Checks whether a room is in the public room list.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ Whether the room is in the public room list. Returns False if the room does
+ not exist.
+ """
+ room = await self._store.get_room(room_id)
+ if not room:
+ return False
+
+ return room.get("is_public", False)
+
+ async def add_room_to_public_room_list(self, room_id: str) -> None:
+ """Publishes a room to the public room list.
+
+ Args:
+ room_id: The ID of the room.
+ """
+ await self._store.set_room_is_public(room_id, True)
+
+ async def remove_room_from_public_room_list(self, room_id: str) -> None:
+ """Removes a room from the public room list.
+
+ Args:
+ room_id: The ID of the room.
+ """
+ await self._store.set_room_is_public(room_id, False)
+
+
class ModuleApi(object):
"""A proxy object that gets passed to various plugin modules so they
can register new users etc if necessary.
@@ -43,6 +93,9 @@ class ModuleApi(object):
self._auth = hs.get_auth()
self._auth_handler = auth_handler
+ self.http_client = hs.get_simple_http_client() # type: SimpleHttpClient
+ self.public_room_list_manager = PublicRoomListManager(hs)
+
def get_user_by_req(self, req, allow_guest=False):
"""Check the access_token provided for a request
@@ -167,8 +220,10 @@ class ModuleApi(object):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
- return self._store.record_user_external_id(
- auth_provider_id, remote_user_id, registered_user_id
+ return defer.ensureDeferred(
+ self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, registered_user_id
+ )
)
def generate_short_term_login_token(
@@ -223,7 +278,9 @@ class ModuleApi(object):
Returns:
Deferred[object]: result of func
"""
- return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
+ return defer.ensureDeferred(
+ self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
+ )
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
@@ -262,3 +319,30 @@ class ModuleApi(object):
await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url,
)
+
+ async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
+ """Create and send an event into a room. Membership events are currently not supported.
+
+ Args:
+ event_dict: A dictionary representing the event to send.
+ Required keys are `type`, `room_id`, `sender` and `content`.
+
+ Returns:
+ The event that was sent. If state event deduplication happened, then
+ the previous, duplicate event instead.
+
+ Raises:
+ SynapseError if the event was not allowed.
+ """
+ # Create a requester object
+ requester = create_requester(event_dict["sender"])
+
+ # Create and send the event
+ (
+ event,
+ _,
+ ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
+ requester, event_dict, ratelimit=False
+ )
+
+ return event
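
Taken together, the new module-facing surface might be used from a plugin along these lines; the module class and event contents are invented for the example:

    class ExampleModule:
        def __init__(self, config, api):
            self._api = api

        async def unpublish_and_notify(self, room_id: str):
            # Remove the room from the public room list if it is listed.
            prlm = self._api.public_room_list_manager
            if await prlm.room_is_in_public_room_list(room_id):
                await prlm.remove_room_from_public_room_list(room_id)

            # Send a notice into the room via the new ModuleApi helper.
            await self._api.create_and_send_event_into_room(
                {
                    "type": "m.room.message",
                    "room_id": room_id,
                    "sender": "@moderation:example.com",
                    "content": {"msgtype": "m.notice", "body": "Room unpublished."},
                }
            )
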
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 8047873ff1..172af1a5a4 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -482,7 +482,11 @@ BASE_APPEND_UNDERRIDE_RULES = [
"_id": "_message",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
@@ -496,7 +500,11 @@ BASE_APPEND_UNDERRIDE_RULES = [
"_id": "_encrypted",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
]
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index af117fddf9..c38e037281 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -16,8 +16,7 @@
import email.mime.multipart
import email.utils
import logging
-import time
-import urllib
+import urllib.parse
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar
@@ -640,72 +639,3 @@ def string_ordinal_total(s):
for c in s:
tot += ord(c)
return tot
-
-
-def format_ts_filter(value, format):
- return time.strftime(format, time.localtime(value / 1000))
-
-
-def load_jinja2_templates(
- template_dir,
- template_filenames,
- apply_format_ts_filter=False,
- apply_mxc_to_http_filter=False,
- public_baseurl=None,
-):
- """Loads and returns one or more jinja2 templates and applies optional filters
-
- Args:
- template_dir (str): The directory where templates are stored
- template_filenames (list[str]): A list of template filenames
- apply_format_ts_filter (bool): Whether to apply a template filter that formats
- timestamps
- apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts
- mxc urls to http urls
- public_baseurl (str|None): The public baseurl of the server. Required for
- apply_mxc_to_http_filter to be enabled
-
- Returns:
- A list of jinja2 templates corresponding to the given list of filenames,
- with order preserved
- """
- logger.info(
- "loading email templates %s from '%s'", template_filenames, template_dir
- )
- loader = jinja2.FileSystemLoader(template_dir)
- env = jinja2.Environment(loader=loader)
-
- if apply_format_ts_filter:
- env.filters["format_ts"] = format_ts_filter
-
- if apply_mxc_to_http_filter and public_baseurl:
- env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl)
-
- templates = []
- for template_filename in template_filenames:
- template = env.get_template(template_filename)
- templates.append(template)
-
- return templates
-
-
-def _create_mxc_to_http_filter(public_baseurl):
- def mxc_to_http_filter(value, width, height, resize_method="crop"):
- if value[0:6] != "mxc://":
- return ""
-
- serverAndMediaId = value[6:]
- fragment = None
- if "#" in serverAndMediaId:
- (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
- fragment = "#" + fragment
-
- params = {"width": width, "height": height, "method": resize_method}
- return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
- public_baseurl,
- serverAndMediaId,
- urllib.parse.urlencode(params),
- fragment or "",
- )
-
- return mxc_to_http_filter
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 8ad0bf5936..f626797133 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -15,22 +15,13 @@
import logging
+from synapse.push.emailpusher import EmailPusher
+from synapse.push.mailer import Mailer
+
from .httppusher import HttpPusher
logger = logging.getLogger(__name__)
-# We try importing this if we can (it will fail if we don't
-# have the optional email dependencies installed). We don't
-# yet have the config to know if we need the email pusher,
-# but importing this after daemonizing seems to fail
-# (even though a simple test of importing from a daemonized
-# process works fine)
-try:
- from synapse.push.emailpusher import EmailPusher
- from synapse.push.mailer import Mailer, load_jinja2_templates
-except Exception:
- pass
-
class PusherFactory(object):
def __init__(self, hs):
@@ -43,16 +34,8 @@ class PusherFactory(object):
if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer
- self.notif_template_html, self.notif_template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_notif_template_html,
- self.config.email_notif_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
+ self._notif_template_html = hs.config.email_notif_template_html
+ self._notif_template_text = hs.config.email_notif_template_text
self.pusher_types["email"] = self._create_email_pusher
@@ -73,8 +56,8 @@ class PusherFactory(object):
mailer = Mailer(
hs=self.hs,
app_name=app_name,
- template_html=self.notif_template_html,
- template_text=self.notif_template_text,
+ template_html=self._notif_template_html,
+ template_text=self._notif_template_text,
)
self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3c3262a88c..8ac29ff725 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -64,6 +64,8 @@ class PusherPool:
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
+ self._account_validity = hs.config.account_validity
+
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index e5f22fb858..dd77a44b8d 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -43,7 +43,7 @@ REQUIREMENTS = [
"jsonschema>=2.5.1",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
- "canonicaljson>=1.2.0",
+ "canonicaljson>=1.3.0",
# we use the type definitions added in signedjson 1.1.
"signedjson>=1.1.0",
"pynacl>=1.2.1",
@@ -78,8 +78,6 @@ CONDITIONAL_REQUIREMENTS = {
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
# we use execute_batch, which arrived in psycopg 2.7.
"postgres": ["psycopg2>=2.7"],
- # ConsentResource uses select_autoescape, which arrived in jinja 2.9
- "resources.consent": ["Jinja2>=2.9"],
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": [
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index ce9420aa69..a02b27474d 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin,
user_type,
address,
+ shadow_banned,
):
"""
Args:
@@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
 address (str|None): the IP address used to perform the registration.
+ shadow_banned (bool): Whether to shadow-ban the user
"""
return {
"password_hash": password_hash,
@@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"admin": admin,
"user_type": user_type,
"address": address,
+ "shadow_banned": shadow_banned,
}
async def _handle_request(self, request, user_id):
@@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin=content["admin"],
user_type=content["user_type"],
address=content["address"],
+ shadow_banned=content["shadow_banned"],
)
return 200, {}
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..d43eaf3a29 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
int
"""
return self._current
+
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to
+ `get_current_token`.
+ """
+ return self.get_current_token()
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 590187df46..90d90833f9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -21,16 +22,13 @@ from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def get_push_rules_stream_token(self):
- return (
- self._push_rules_stream_id_gen.get_current_token(),
- self._stream_id_gen.get_current_token(),
- )
-
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
+ # We assert this for the benefit of mypy
+ assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
for row in rows:
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index d853e4447e..8cd47770c1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -21,9 +21,7 @@ import abc
import logging
from typing import Tuple, Type
-from canonicaljson import json
-
-from synapse.util import json_encoder as _json_encoder
+from synapse.util import json_decoder, json_encoder
logger = logging.getLogger(__name__)
@@ -125,7 +123,7 @@ class RdataCommand(Command):
stream_name,
instance_name,
None if token == "batch" else int(token),
- json.loads(row_json),
+ json_decoder.decode(row_json),
)
def to_line(self):
@@ -134,7 +132,7 @@ class RdataCommand(Command):
self.stream_name,
self.instance_name,
str(self.token) if self.token is not None else "batch",
- _json_encoder.encode(self.row),
+ json_encoder.encode(self.row),
)
)
@@ -359,7 +357,7 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
- access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
+ access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
@@ -367,7 +365,7 @@ class UserIpCommand(Command):
return (
self.user_id
+ " "
- + _json_encoder.encode(
+ + json_encoder.encode(
(
self.access_token,
self.ip,
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a42de3f7d..8c3caf30c9 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
)
def _current_token(self, instance_name: str) -> int:
- push_rules_token, _ = self.store.get_push_rules_stream_token()
+ push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token
@@ -405,7 +405,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token,
+ store.get_cache_stream_token_for_writer,
store.get_all_updated_caches,
)
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 46e458e95b..2e81eeff65 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -118,6 +118,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 970fdd5834..ceaa28c212 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -49,9 +49,7 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(
- state, self.clock.time_msec(), include_user_id=False
- )
+ state = format_user_presence_state(state, self.clock.time_msec())
return 200, state
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..165313b572 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -63,11 +65,27 @@ class ProfileDisplaynameRestServlet(RestServlet):
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -76,6 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -114,11 +133,27 @@ class ProfileAvatarURLRestServlet(RestServlet):
user, requester, new_avatar_url, is_admin
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 00831879f3..e781a3bcf4 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.api.errors import (
NotFoundError,
StoreError,
@@ -160,10 +159,10 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
def notify_user(self, user_id):
- stream_id, _ = self.store.get_push_rules_stream_token()
+ stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
- def set_rule_attr(self, user_id, spec, val):
+ async def set_rule_attr(self, user_id, spec, val):
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -173,7 +172,9 @@ class PushRuleRestServlet(RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
+ return await self.store.set_push_rule_enabled(
+ user_id, namespaced_rule_id, val
+ )
elif spec["attr"] == "actions":
actions = val.get("actions")
_check_actions(actions)
@@ -188,7 +189,7 @@ class PushRuleRestServlet(RestServlet):
if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
- return self.store.set_push_rule_actions(
+ return await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 2ab30ce897..bc914d920e 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,8 +21,6 @@ import re
from typing import List, Optional
from urllib import parse as urlparse
-from canonicaljson import json
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@@ -46,6 +44,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.util import json_decoder
MYPY = False
if MYPY:
@@ -519,7 +518,9 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
- event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@@ -631,7 +632,9 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
- event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
else:
event_filter = None
@@ -724,7 +727,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index fead85074b..6b945e1849 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
from http import HTTPStatus
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -32,7 +40,8 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import Mailer, load_jinja2_templates
+from synapse.push.mailer import Mailer
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
@@ -53,21 +62,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_password_reset_template_html,
- self.config.email_password_reset_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_password_reset_template_html,
+ template_text=self.config.email_password_reset_template_text,
)
async def on_POST(self, request):
@@ -103,10 +102,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -169,9 +172,8 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_password_reset_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_password_reset_template_failure_html
)
async def on_GET(self, request, medium):
@@ -214,14 +216,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
- html = self.config.email_password_reset_template_success_html
+ html = self.config.email_password_reset_template_success_html_content
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
@@ -237,6 +239,7 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore()
self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -262,26 +265,33 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- try:
- params, session_id = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
- except InteractiveAuthIncompleteError as e:
- # The user needs to provide more steps to complete auth, but
- # they're not required to provide the password again.
- #
- # If a password is available now, hash the provided password and
- # store it for later.
- if new_password:
- password_hash = await self.auth_handler.hash(new_password)
- await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ try:
+ (
+ params,
+ session_id,
+ ) = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
)
- raise
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
user_id = requester.user.to_string()
else:
requester = None
@@ -346,11 +356,29 @@ class PasswordRestServlet(RestServlet):
user_id, password_hash, logout_devices, requester
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_password(params, shadow_user.to_string())
+
return 200, {}
def on_OPTIONS(self, _):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -411,19 +439,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_add_threepid_template_html,
- self.config.email_add_threepid_template_text,
- ],
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_add_threepid_template_html,
+ template_text=self.config.email_add_threepid_template_text,
)
async def on_POST(self, request):
@@ -454,13 +474,17 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
@@ -522,13 +546,17 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
@@ -578,9 +606,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_add_threepid_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_add_threepid_template_failure_html
)
async def on_GET(self, request):
@@ -613,15 +640,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
- if next_link.startswith("file:///"):
- logger.warning(
- "Not redirecting to next_link as it is a local file: address"
- )
- else:
- request.setResponseCode(302)
- request.setHeader("Location", next_link)
- finish_request(request)
- return None
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
# Otherwise show the success template
html = self.config.email_add_threepid_template_success_html_content
@@ -631,7 +653,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
@@ -687,7 +709,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -706,6 +729,29 @@ class ThreepidRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ # skip validation if this is a shadow 3PID from an AS
+ if requester.app_service:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
+
+ await self.auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return 200, {}
+
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
if threepid_creds is None:
raise SynapseError(
@@ -727,12 +773,36 @@ class ThreepidRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
@@ -743,6 +813,7 @@ class ThreepidAddRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -778,12 +849,34 @@ class ThreepidAddRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
@@ -853,6 +946,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
@@ -877,6 +971,12 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
@@ -884,6 +984,114 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+ @defer.inlineCallbacks
+ def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ async def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ await self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. proxy_lookup_3pid handles
+ # checking whether the lookup is allowed, so we don't need to do it here.
+ ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+
+ return 200, ret
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidBulkLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ async def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ await self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. proxy_bulk_lookup_3pid handles
+ # checking whether the lookup is allowed, so we don't need to do it here.
+ ret = await self.identity_handler.proxy_bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ return 200, ret
+
+
+def assert_valid_next_link(hs: "HomeServer", next_link: str):
+ """
+ Raises a SynapseError if a given next_link value is invalid
+
+ next_link is valid if its scheme does not point at the local filesystem (file:)
+ and the next_link_domain_whitelist config option is either unset or contains
+ the domain of the given next_link
+
+ Args:
+ hs: The homeserver object
+ next_link: The next_link value given by the client
+
+ Raises:
+ SynapseError: If the next_link is invalid
+ """
+ valid = True
+
+ # Parse the contents of the URL
+ next_link_parsed = urlparse(next_link)
+
+ # Scheme must not point to the local drive
+ if next_link_parsed.scheme == "file":
+ valid = False
+
+ # If the domain whitelist is set, the domain must be in it
+ if (
+ valid
+ and hs.config.next_link_domain_whitelist is not None
+ and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+ ):
+ valid = False
+
+ if not valid:
+ raise SynapseError(
+ 400,
+ "'next_link' domain not included in whitelist, or not http(s)",
+ errcode=Codes.INVALID_PARAM,
+ )
+
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
@@ -912,4 +1120,6 @@ def register_servlets(hs, http_server):
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
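Note that `assert_valid_next_link` above only rejects `file:` URLs and, when `next_link_domain_whitelist` is configured, hostnames outside the whitelist. A self-contained sketch of the same checks, with `whitelist` standing in for `hs.config.next_link_domain_whitelist`:

```python
from urllib.parse import urlparse

def next_link_is_valid(next_link: str, whitelist) -> bool:
    parsed = urlparse(next_link)
    if parsed.scheme == "file":  # must not point at the local filesystem
        return False
    if whitelist is not None and parsed.hostname not in whitelist:
        return False
    return True

assert next_link_is_valid("https://example.com/done", None)
assert not next_link_is_valid("file:///etc/passwd", None)
assert not next_link_is_valid("https://evil.example/x", ["example.com"])
```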
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index c1d4cd0caf..d31ec7c29d 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -39,6 +40,7 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
+ self._profile_handler = hs.get_profile_handler()
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
@@ -50,6 +52,11 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ await self._profile_handler.set_active([user], not hide_profile, True)
+
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
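With the hook above, hiding a profile becomes a plain account-data write. A hypothetical client call and its effect (the third `set_active` argument is passed as in this diff; its exact meaning is not shown here):

```python
# PUT /_matrix/client/r0/user/{user_id}/account_data/im.vector.hide_profile
body = {"hide_profile": True}

# The servlet then deactivates the profile before storing the account data:
#   await self._profile_handler.set_active([user], not body["hide_profile"], True)
# i.e. the second argument is False (inactive) when hide_profile is True.
```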
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index d84a6d7e11..13ecf7005d 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,6 +16,7 @@
import logging
+from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
@@ -325,6 +326,9 @@ class GroupRoomServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
+ if not GroupID.is_valid(group_id):
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index f808175698..e0d83a962d 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
import hmac
import logging
+import re
from typing import List, Union
import synapse
@@ -44,7 +46,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import load_jinja2_templates
+from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -81,23 +83,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.config = hs.config
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- from synapse.push.mailer import Mailer, load_jinja2_templates
-
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_registration_template_html,
- self.config.email_registration_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_registration_template_html,
+ template_text=self.config.email_registration_template_text,
)
async def on_POST(self, request):
@@ -128,10 +118,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
@@ -200,7 +190,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -262,15 +254,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
- )
-
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_registration_template_failure_html
)
async def on_GET(self, request, medium):
@@ -318,7 +303,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
@@ -356,15 +341,9 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
- ip = self.hs.get_ip_from_request(request)
- with self.ratelimiter.ratelimit(ip) as wait_deferred:
- await wait_deferred
-
- username = parse_string(request, "username", required=True)
-
- await self.registration_handler.check_username(username)
-
- return 200, {"available": True}
+ # We are not interested in logging in via a username in this deployment.
+ # Simply allow anything here as it won't be used later.
+ return 200, {"available": True}
class RegisterRestServlet(RestServlet):
@@ -414,18 +393,27 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
- # Pull out the provided username and do basic sanity checks early since
- # the auth layer will store these in sessions.
+ # We don't care about usernames for this deployment. In fact, the act
+ # of checking whether they exist already can leak metadata about
+ # which users are already registered.
+ #
+ # Usernames are instead derived from the provided email.
+ # Since they're not needed here, just ignore them.
+ #
+ # (we do still allow appservices to set them below)
desired_username = None
- if "username" in body:
- if not isinstance(body["username"], str) or len(body["username"]) > 512:
- raise SynapseError(400, "Invalid username")
- desired_username = body["username"]
+
+ desired_display_name = body.get("display_name")
appservice = None
if self.auth.has_access_token(request):
appservice = self.auth.get_appservice_by_req(request)
+ # We need to retrieve the password early in order to pass it to
+ # application service registration
+ # This is specific to shadow server registration of users via an AS
+ password = body.pop("password", None)
+
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -434,7 +422,7 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
- desired_username = body.get("user", desired_username)
+ desired_username = body.get("user", body.get("username"))
# XXX we should check that desired_username is valid. Currently
# we give appservices carte blanche for any insanity in mxids,
@@ -445,7 +433,7 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, str):
result = await self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username, password, desired_display_name, access_token, body
)
return 200, result # we throw for non 200 responses
@@ -453,16 +441,6 @@ class RegisterRestServlet(RestServlet):
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
- # For regular registration, convert the provided username to lowercase
- # before attempting to register it. This should mean that people who try
- # to register with upper-case in their usernames don't get a nasty surprise.
- #
- # Note that we treat usernames case-insensitively in login, so they are
- # free to carry on imagining that their username is CrAzYh4cKeR if that
- # keeps them happy.
- if desired_username is not None:
- desired_username = desired_username.lower()
-
# Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
@@ -471,7 +449,6 @@ class RegisterRestServlet(RestServlet):
# Note that we remove the password from the body since the auth layer
# will store the body in the session and we don't want a plaintext
# password store there.
- password = body.pop("password", None)
if password is not None:
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
@@ -501,14 +478,6 @@ class RegisterRestServlet(RestServlet):
session_id, "password_hash", None
)
- # Ensure that the username is valid.
- if desired_username is not None:
- await self.registration_handler.check_username(
- desired_username,
- guest_access_token=guest_access_token,
- assigned_user_id=registered_user_id,
- )
-
# Check if the user-interactive authentication flows are complete, if
# not this will raise a user-interactive auth error.
try:
@@ -547,7 +516,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (await check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -555,6 +524,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existingUid = await self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case get ready and generate the
+ # desired_username
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
+ while True:
+ try:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+ # something else went wrong.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = _map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -569,7 +612,12 @@ class RegisterRestServlet(RestServlet):
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
- desired_username = params.get("username", None)
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ else:
+ # we keep the original desired_username derived from the 3pid above
+ pass
+
guest_access_token = params.get("guest_access_token", None)
if desired_username is not None:
@@ -614,6 +662,7 @@ class RegisterRestServlet(RestServlet):
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
)
@@ -625,6 +674,14 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
+ if self.hs.config.shadow_server:
+ await self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
@@ -649,11 +706,38 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- async def _do_appservice_registration(self, username, as_token, body):
+ async def _do_appservice_registration(
+ self, username, password, display_name, as_token, body
+ ):
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
+
+ if password:
+ # Hash the password
+ #
+ # In mainline, password hashing happens later in the registration
+ # flow, but we need it here for the shadow-server AS use case
+ password = await self.auth_handler.hash(password)
+
user_id = await self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password, display_name
)
- return await self._create_registration_details(user_id, body)
+ result = await self._create_registration_details(user_id, body)
+
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ await self.registration_handler.register_email_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ await self.registration_handler.register_msisdn_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ return result
async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
@@ -704,6 +788,60 @@ class RegisterRestServlet(RestServlet):
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+ # Split the name by whitespace then hyphens, capitalizing each part then
+ # joining it back together.
+ capitalized_name = " ".join(
+ "-".join(part.capitalize() for part in space_part.split("-"))
+ for space_part in name.split()
+ )
+ return capitalized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
+ """
+ # Replace all . in the address with spaces, then split on the @.
+ # (Dots in the domain become spaces too, which lets us split it below.)
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+ # If this is a ...gouv.fr address, set the org to whatever is before
+ # gouv.fr. If there isn't anything (an @gouv.fr email), simply mark their
+ # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
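The displayname mapping added above can be traced by hand. Assuming `cap` and `_map_email_to_displayname` as defined in this diff are in scope, they behave like this:

```python
assert cap("jane doe") == "Jane Doe"
assert cap("john-doe") == "John-Doe"

# gouv.fr address: org is the label before gouv.fr
assert _map_email_to_displayname("jane.doe@interieur.gouv.fr") == "Jane Doe [Interieur]"
# matrix.org address: flagged as admin
assert _map_email_to_displayname("john@matrix.org") == "John [Tchap Admin]"
# anything else: org is the second-level domain label
assert _map_email_to_displayname("bob@example.com") == "Bob [Example]"
```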
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..96488b131a 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -16,8 +16,6 @@
import itertools
import logging
-from canonicaljson import json
-
from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
+from synapse.util import json_decoder
from ._base import client_patterns, set_timeline_upper_limit
@@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet):
filter_collection = DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"):
try:
- filter_object = json.loads(filter_id)
+ filter_object = json_decoder.decode(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..6e8300d6a5 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@
# limitations under the License.
import logging
+from typing import Dict
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from signedjson.sign import sign_json
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import UserID
from ._base import client_patterns
@@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
"""Searches for users in directory
@@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = await self.http_client.post_json_get_json(url, signed_body)
+ return 200, resp
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,5 +95,125 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
+class SingleUserInfoServlet(RestServlet):
+ """
+ Deprecated and replaced by `/users/info`
+
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(SingleUserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ async def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = await self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ return 200, res
+
+ user_id_to_info = await self.store.get_info_for_users([user_id])
+ return 200, user_id_to_info[user_id]
+
+ async def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+ dict: Deactivation and expiration information for a given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
+ return user_ids_to_info_dict[user_id]
+
+
+class UserInfoServlet(RestServlet):
+ """Bulk version of `/user/{user_id}/info` endpoint
+
+ GET /users/info HTTP/1.1
+
+ Returns a dictionary of user_id to info dictionary. Supports remote users
+ """
+
+ PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+
+ async def on_POST(self, request):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ # Extract the user_ids from the request
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, required=["user_ids"])
+
+ user_ids = body["user_ids"]
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # Separate local and remote users
+ local_user_ids = set()
+ remote_server_to_user_ids = {} # type: Dict[str, set]
+ for user_id in user_ids:
+ user = UserID.from_string(user_id)
+
+ if self.hs.is_mine(user):
+ local_user_ids.add(user_id)
+ else:
+ remote_server_to_user_ids.setdefault(user.domain, set())
+ remote_server_to_user_ids[user.domain].add(user_id)
+
+ # Retrieve info of all local users
+ user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)
+
+ # Request info of each remote user from their remote homeserver
+ for server_name, user_id_set in remote_server_to_user_ids.items():
+ # Make a request to the given server about their own users
+ res = await self.transport_layer.get_info_of_users(
+ server_name, list(user_id_set)
+ )
+
+ for user_id, info in res:
+ user_id_to_info_dict[user_id] = info
+
+ return 200, user_id_to_info_dict
+
+
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ SingleUserInfoServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
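A hypothetical exchange with the bulk `/users/info` endpoint above. The per-user fields come from `store.get_info_for_users`, which is not shown in this diff, so treat them as placeholders (the docstring says they cover deactivation and expiration):

```python
# POST /_matrix/client/unstable/users/info
request_body = {"user_ids": ["@alice:example.com", "@bob:remote.example"]}

# 200 response: one entry per requested user
response = {
    "@alice:example.com": {"deactivated": False, "expired": False},   # looked up locally
    "@bob:remote.example": {"deactivated": False, "expired": False},  # fetched over federation
}
```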
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..b1999d051b 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -57,9 +57,12 @@ class VersionsRestServlet(RestServlet):
# MSC2326.
"org.matrix.label_based_filtering": True,
# Implements support for cross signing as described in MSC1756
- "org.matrix.e2e_cross_signing": True,
+ # "org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
+ # Tchap does not currently assume this rule for r0.5.0
+ # XXX: Remove this when it does
+ "m.lazy_load_members": True,
},
},
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9b3f85b306..5db7f81c2d 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,19 +15,19 @@
import logging
from typing import Dict, Set
-from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
+from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
class RemoteKey(DirectServeJsonResource):
- """HTTP resource for retreiving the TLS certificate and NACL signature
+ """HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of
@@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource):
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
+ # If there is a cache miss, request the missing keys, then recurse (and
+ # ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await self.fetcher.get_keys(cache_misses)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json in json_results:
- key_json = json.loads(key_json.decode("utf-8"))
+ key_json = json_decoder.decode(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)
@@ -223,4 +225,4 @@ class RemoteKey(DirectServeJsonResource):
results = {"server_keys": signed_keys}
- respond_with_json_bytes(request, 200, encode_canonical_json(results))
+ respond_with_json(request, 200, results, canonical_json=True)
diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
new file mode 100644
index 0000000000..6f2a1931c5
--- /dev/null
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.config._base import ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DomainRuleChecker(object):
+ """
+ A re-implementation of the SpamChecker that prevents users in one domain from
+ inviting users in other domains to rooms, based on a configuration.
+
+ Takes a config in the format:
+
+ spam_checker:
+ module: "rulecheck.DomainRuleChecker"
+ config:
+ domain_mapping:
+ "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
+ "other_inviter_domain": [ "invitee_domain_permitted" ]
+ default: false
+
+ # Only let local users join rooms if they were explicitly invited.
+ can_only_join_rooms_with_invite: false
+
+ # Only let local users create rooms if they are inviting only one
+ # other user, and that user matches the rules above.
+ can_only_create_one_to_one_rooms: false
+
+ # Only let local users invite during room creation, regardless of the
+ # domain mapping rules above.
+ can_only_invite_during_room_creation: false
+
+ # Prevent local users from inviting users from certain domains to
+ # rooms published in the room directory.
+ domains_prevented_from_being_invited_to_published_rooms: []
+
+ # Allow third party invites
+ can_invite_by_third_party_id: true
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config):
+ self.domain_mapping = config["domain_mapping"] or {}
+ self.default = config["default"]
+
+ self.can_only_join_rooms_with_invite = config.get(
+ "can_only_join_rooms_with_invite", False
+ )
+ self.can_only_create_one_to_one_rooms = config.get(
+ "can_only_create_one_to_one_rooms", False
+ )
+ self.can_only_invite_during_room_creation = config.get(
+ "can_only_invite_during_room_creation", False
+ )
+ self.can_invite_by_third_party_id = config.get(
+ "can_invite_by_third_party_id", True
+ )
+ self.domains_prevented_from_being_invited_to_published_rooms = config.get(
+ "domains_prevented_from_being_invited_to_published_rooms", []
+ )
+
+ def check_event_for_spam(self, event):
+ """Implements synapse.events.SpamChecker.check_event_for_spam
+ """
+ return False
+
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
+ """Implements synapse.events.SpamChecker.user_may_invite
+ """
+ if self.can_only_invite_during_room_creation and not new_room:
+ return False
+
+ if not self.can_invite_by_third_party_id and third_party_invite:
+ return False
+
+ # This is a third party invite (without a bound mxid), so unless we have
+ # banned all third party invites (above) we allow it.
+ if not invitee_userid:
+ return True
+
+ inviter_domain = self._get_domain_from_id(inviter_userid)
+ invitee_domain = self._get_domain_from_id(invitee_userid)
+
+ if inviter_domain not in self.domain_mapping:
+ return self.default
+
+ if (
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
+ ):
+ return False
+
+ return invitee_domain in self.domain_mapping[inviter_domain]
+
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
+ """Implements synapse.events.SpamChecker.user_may_create_room
+ """
+
+ if cloning:
+ return True
+
+ if not self.can_invite_by_third_party_id and third_party_invite_list:
+ return False
+
+ number_of_invites = len(invite_list) + len(third_party_invite_list)
+
+ if self.can_only_create_one_to_one_rooms and number_of_invites != 1:
+ return False
+
+ return True
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias
+ """
+ return True
+
+ def user_may_publish_room(self, userid, room_id):
+ """Implements synapse.events.SpamChecker.user_may_publish_room
+ """
+ return True
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Implements synapse.events.SpamChecker.user_may_join_room
+ """
+ if self.can_only_join_rooms_with_invite and not is_invited:
+ return False
+
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ """Implements synapse.events.SpamChecker.parse_config
+ """
+ if "default" in config:
+ return config
+ else:
+ raise ConfigError("No default set for spam_config DomainRuleChecker")
+
+ @staticmethod
+ def _get_domain_from_id(mxid):
+ """Parses a string and returns the domain part of the mxid.
+
+ Args:
+ mxid (str): a valid mxid
+
+ Returns:
+ str: the domain part of the mxid
+
+ """
+ idx = mxid.find(":")
+ if idx == -1:
+ raise Exception("Invalid ID: %r" % (mxid,))
+ return mxid[idx + 1 :]
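A quick illustration of the invite rules above, using a config in the documented shape. Note that `parse_config` just validates and returns the dict; the spam-checker framework passes it to the constructor:

```python
config = DomainRuleChecker.parse_config(
    {
        "domain_mapping": {"inviter.example": ["invitee.example"]},
        "default": False,
    }
)
checker = DomainRuleChecker(config)

# Permitted: the mapping lists the invitee's domain for this inviter.
assert checker.user_may_invite(
    "@a:inviter.example", "@b:invitee.example", None, "!room:x", new_room=False
)
# Unmapped inviter domains fall through to the default (False here).
assert not checker.user_may_invite(
    "@a:other.example", "@b:invitee.example", None, "!room:x", new_room=False
)
```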
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 9b78924d96..7f63f1bfa0 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -48,8 +48,10 @@ class SpamCheckerApi(object):
twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
The filtered state events in the room.
"""
- state_ids = yield self._store.get_filtered_current_state_ids(
- room_id=room_id, state_filter=StateFilter.from_types(types)
+ state_ids = yield defer.ensureDeferred(
+ self._store.get_filtered_current_state_ids(
+ room_id=room_id, state_filter=StateFilter.from_types(types)
+ )
)
- state = yield self._store.get_events(state_ids.values())
+ state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values()
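The `defer.ensureDeferred` wrapping above is the standard bridge for calling `async def` methods from an `@defer.inlineCallbacks` generator, which can only `yield` Deferreds. A self-contained sketch of the pattern:

```python
from twisted.internet import defer

async def fetch_state():  # stands in for the now-async store methods
    return {"state": "value"}

@defer.inlineCallbacks
def caller():
    # yield-ing the coroutine directly would fail; wrap it in a Deferred first
    result = yield defer.ensureDeferred(fetch_state())
    return result
```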
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index a1d3884667..dba8d91eef 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -641,7 +641,7 @@ class StateResolutionStore(object):
allow_rejected (bool): If True return rejected events.
Returns:
- Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
return self.store.get_events(
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6814bf5fcf..ab49d227de 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,12 +19,11 @@ import random
from abc import ABCMeta
from typing import Any, Optional
-from canonicaljson import json
-
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -99,13 +98,13 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
- # Decode it to a Unicode string before feeding it to json.loads, since
+ # Decode it to a Unicode string before feeding it to the JSON decoder, since
# Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
try:
- return json.loads(db_content)
+ return json_decoder.decode(db_content)
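+ # json_decoder is a shared JSONDecoder instance defined in synapse.util.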
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index f43463df53..90a1f9e8b1 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -18,8 +18,6 @@ from typing import Optional
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
update_name (str): Name of update
"""
- @defer.inlineCallbacks
- def noop_update(progress, batch_size):
- yield self._end_background_update(update_name)
+ async def noop_update(progress, batch_size):
+ await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, noop_update)
@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
else:
runner = create_index_sqlite
- @defer.inlineCallbacks
- def updater(progress, batch_size):
+ async def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
- yield self.db_pool.runWithConnection(runner)
- yield self._end_background_update(update_name)
+ await self.db_pool.runWithConnection(runner)
+ await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, updater)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4ada6f5563..b9aef96b08 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -332,8 +332,7 @@ class DatabasePool(object):
"""
return self._db_pool.running
- @defer.inlineCallbacks
- def _check_safe_to_upsert(self):
+ async def _check_safe_to_upsert(self):
"""
Is it safe to use native UPSERT?
@@ -342,7 +341,7 @@ class DatabasePool(object):
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = yield self.simple_select_list(
+ updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@@ -517,14 +516,16 @@ class DatabasePool(object):
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
- result = yield self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- **kwargs
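+ # runWithConnection is now a coroutine, so it must be wrapped in
+ # ensureDeferred to get a Deferred this generator can yield on.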
+ result = yield defer.ensureDeferred(
+ self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
+ )
)
for after_callback, after_args, after_kwargs in after_callbacks:
@@ -536,8 +537,7 @@ class DatabasePool(object):
return result
- @defer.inlineCallbacks
- def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+ async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
@@ -548,7 +548,7 @@ class DatabasePool(object):
kwargs: named args to pass to `func`
Returns:
- Deferred: The result of func
+ The result of func
"""
parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
if not parent_context:
@@ -571,12 +571,10 @@ class DatabasePool(object):
return func(conn, *args, **kwargs)
- result = yield make_deferred_yieldable(
+ return await make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
)
- return result
-
@staticmethod
def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts.
@@ -614,8 +612,7 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- @defer.inlineCallbacks
- def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -631,7 +628,7 @@ class DatabasePool(object):
`or_ignore` is True
"""
try:
- yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+ await self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -684,8 +681,7 @@ class DatabasePool(object):
txn.executemany(sql, vals)
- @defer.inlineCallbacks
- def simple_upsert(
+ async def simple_upsert(
self,
table,
keyvalues,
@@ -714,14 +710,14 @@ class DatabasePool(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
- Deferred(None or bool): Native upserts always return None. Emulated
+ None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
attempts = 0
while True:
try:
- result = yield self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_upsert_txn,
table,
@@ -730,7 +726,6 @@ class DatabasePool(object):
insertion_values,
lock=lock,
)
- return result
except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@@ -1121,8 +1116,7 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
- @defer.inlineCallbacks
- def simple_select_many_batch(
+ async def simple_select_many_batch(
self,
table,
column,
@@ -1156,7 +1150,7 @@ class DatabasePool(object):
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
- rows = yield self.runInteraction(
+ rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
table,
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 4406e58273..0ac854aee2 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -87,12 +87,21 @@ class Databases(object):
logger.info("Database %r prepared", db_name)
+ # Closing the context manager doesn't close the connection.
+ # psycopg will close the connection when the object gets GCed, but *only*
+ # if the PID is the same as when the connection was opened [1], and
+ # it may not be if we fork in the meantime.
+ #
+ # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
+
+ db_conn.close()
+
# Sanity check that we have actually configured all the required stores.
if not main:
raise Exception("No 'main' data store configured")
if not state:
- raise Exception("No 'main' data store configured")
+ raise Exception("No 'state' data store configured")
# We use local variables here to ensure that the databases do not have
# optional types.
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5cf1a88399..02568a2391 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
- A Deferred which resolves when the state was set successfully.
+ An Awaitable which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
- def get_cache_stream_token(self, instance_name):
+ def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token(instance_name)
+ return self._cache_id_gen.get_current_token_for_writer(instance_name)
else:
return 0
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..9a786e2929 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
- inlineCallbacks=True,
)
- def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+ rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 484875f989..4826be630c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def get_auth_chain(self, event_ids, include_given=False):
+ async def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
@@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
list of events
"""
- return self.get_auth_chain_ids(
+ event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
- ).addCallback(self.get_events_as_list)
+ )
+ return await self.get_events_as_list(event_ids)
def get_auth_chain_ids(
self,
@@ -257,11 +258,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
- def get_oldest_events_in_room(self, room_id):
- return self.db_pool.runInteraction(
- "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
- )
-
def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
@@ -303,14 +299,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
- def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self.db_pool.simple_select_onecol_txn(
- txn,
- table="event_backward_extremities",
- keyvalues={"room_id": room_id},
- retcol="event_id",
- )
-
def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.
@@ -472,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
- def get_backfill_events(self, room_id, event_list, limit):
+ async def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@@ -482,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list (list)
limit (int)
"""
- return (
- self.db_pool.runInteraction(
- "get_backfill_events",
- self._get_backfill_events,
- room_id,
- event_list,
- limit,
- )
- .addCallback(self.get_events_as_list)
- .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+ event_ids = await self.db_pool.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events,
+ room_id,
+ event_list,
+ limit,
)
+ events = await self.get_events_as_list(event_ids)
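+ # Sort by depth, descending, to match the ordering the old callback
+ # chain produced.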
+ return sorted(events, key=lambda e: -e.depth)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -553,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = await self.get_events_as_list(ids)
- return events
+ return await self.get_events_as_list(ids)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7c246d3e4c..e8834b2162 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
- @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
- def get_unread_event_push_actions_by_room_for_user(
+ @cached(num_args=3, tree=True, max_entries=5000)
+ async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
- ret = yield self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
- return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..b90e6de2d5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
import attr
from prometheus_client import Counter
-from twisted.internet import defer
-
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
@@ -113,15 +111,14 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
- @defer.inlineCallbacks
- def _persist_events_and_state_updates(
+ async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
- ):
+ ) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -136,7 +133,7 @@ class PersistEventsStore:
backfilled
Returns:
- Deferred: resolves when the events have been persisted
+ Resolves when the events have been persisted
"""
# We want to calculate the stream orderings as late as possible, as
@@ -168,7 +165,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@@ -206,16 +203,15 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids)
)
- @defer.inlineCallbacks
- def _get_events_which_are_prevs(self, event_ids):
+ async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
Args:
- event_ids (Iterable[str]): event ids to filter
+ event_ids: event ids to filter
Returns:
- Deferred[List[str]]: filtered event ids
+ Filtered event ids
"""
results = []
@@ -240,14 +236,13 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
return results
- @defer.inlineCallbacks
- def _get_prevs_before_rejected(self, event_ids):
+ async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
"""Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or
@@ -259,11 +254,11 @@ class PersistEventsStore:
are separated by soft failed events.
Args:
- event_ids (Iterable[str]): Events to find prev events for. Note
- that these must have already been persisted.
+ event_ids: Events to find prev events for. Note that these must have
+ already been persisted.
Returns:
- Deferred[set[str]]
+ The previous events.
"""
# The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +299,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 35a0e09e3c..e53c6373a8 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored",
)
- @defer.inlineCallbacks
- def _background_reindex_fields_sender(self, progress, batch_size):
+ async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
- @defer.inlineCallbacks
- def _background_reindex_origin_server_ts(self, progress, batch_size):
+ async def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows_to_update)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
- @defer.inlineCallbacks
- def _cleanup_extremities_bg_update(self, progress, batch_size):
+ async def _cleanup_extremities_bg_update(self, progress, batch_size):
"""Background update to clean out extremities that should have been
deleted previously.
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set)
- num_handled = yield self.db_pool.runInteraction(
+ num_handled = await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
return num_handled
- @defer.inlineCallbacks
- def _redactions_received_ts(self, progress, batch_size):
+ async def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
- count = yield self.db_pool.runInteraction(
+ count = await self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
- yield self.db_pool.updates._end_background_update("redactions_received_ts")
+ await self.db_pool.updates._end_background_update("redactions_received_ts")
return count
- @defer.inlineCallbacks
- def _event_fix_redactions_bytes(self, progress, batch_size):
+ async def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
- yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
+ await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1
- @defer.inlineCallbacks
- def _event_store_labels(self, progress, batch_size):
+ async def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return nbrows
- num_rows = yield self.db_pool.runInteraction(
+ num_rows = await self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
- yield self.db_pool.updates._end_background_update("event_store_labels")
+ await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 755b7a2a85..4a3333c0db 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
+from typing_extensions import Literal
from twisted.internet import defer
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -137,44 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
- def get_received_ts_by_stream_pos(self, stream_ordering):
- """Given a stream ordering get an approximate timestamp of when it
- happened.
-
- This is done by simply taking the received ts of the first event that
- has a stream ordering greater than or equal to the given stream pos.
- If none exists returns the current time, on the assumption that it must
- have happened recently.
-
- Args:
- stream_ordering (int)
-
- Returns:
- Deferred[int]
- """
-
- def _get_approximate_received_ts_txn(txn):
- sql = """
- SELECT received_ts FROM events
- WHERE stream_ordering >= ?
- LIMIT 1
- """
-
- txn.execute(sql, (stream_ordering,))
- row = txn.fetchone()
- if row and row[0]:
- ts = row[0]
- else:
- ts = self.clock.time_msec()
-
- return ts
+ # Inform mypy that if allow_none is False (the default) then get_event
+ # always returns an EventBase.
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[False] = False,
+ check_room_id: Optional[str] = None,
+ ) -> EventBase:
+ ...
- return self.db_pool.runInteraction(
- "get_approximate_received_ts", _get_approximate_received_ts_txn
- )
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[True] = False,
+ check_room_id: Optional[str] = None,
+ ) -> Optional[EventBase]:
+ ...
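+ # With these overloads, `await self.get_event(event_id)` is typed as
+ # EventBase, while passing allow_none=True yields Optional[EventBase].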
- @defer.inlineCallbacks
- def get_event(
+ async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -182,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
- ):
+ ) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
@@ -207,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- Deferred[EventBase|None]
+ The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -230,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
return event
- @defer.inlineCallbacks
- def get_events(
+ async def get_events(
self,
- event_ids: List[str],
+ event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> Dict[str, EventBase]:
"""Get events from the database
Args:
@@ -256,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred : Dict from event_id to event.
+ A mapping from event_id to event.
"""
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -267,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
- @defer.inlineCallbacks
- def get_events_as_list(
+ async def get_events_as_list(
self,
- event_ids: List[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> List[EventBase]:
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@@ -295,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred[list[EventBase]]: List of events fetched from the database. The
- events are in the same order as `event_ids` arg.
+ List of events fetched from the database. The events are in the same
+ order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@@ -306,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = yield self._get_events_from_cache_or_db(
+ event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -341,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -407,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
+ prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@@ -419,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
- @defer.inlineCallbacks
- def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -435,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@@ -453,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- missing_events = yield self._get_events_from_db(
+ missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@@ -561,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
- @defer.inlineCallbacks
- def _get_events_from_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@@ -576,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
@@ -584,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
- row_map = yield self._enqueue_events(events_to_fetch)
+ row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@@ -610,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = db_to_json(row["json"])
- internal_metadata = db_to_json(row["internal_metadata"])
+ # If the event or metadata cannot be parsed, log the error and act
+ # as if the event is unknown.
+ try:
+ d = db_to_json(row["json"])
+ except ValueError:
+ logger.error("Unable to parse json from event: %s", event_id)
+ continue
+ try:
+ internal_metadata = db_to_json(row["internal_metadata"])
+ except ValueError:
+ logger.error(
+ "Unable to parse internal_metadata from event: %s", event_id
+ )
+ continue
format_version = row["format_version"]
if format_version is None:
@@ -686,8 +684,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- @defer.inlineCallbacks
- def _enqueue_events(self, events):
+ async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -696,7 +693,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
- Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@@ -719,7 +716,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
- row_map = yield events_d
+ row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@@ -878,12 +875,11 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- @defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@@ -894,15 +890,14 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
+ async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
- Deferred[set[str]]: The events we have already seen.
+ set[str]: The events we have already seen.
"""
results = set()
@@ -918,41 +913,11 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
- def _get_total_state_event_counts_txn(self, txn, room_id):
- """
- See get_total_state_event_counts.
- """
- # We join against the events table as that has an index on room_id
- sql = """
- SELECT COUNT(*) FROM state_events
- INNER JOIN events USING (room_id, event_id)
- WHERE room_id=?
- """
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_total_state_event_counts(self, room_id):
- """
- Gets the total number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.db_pool.runInteraction(
- "get_total_state_event_counts",
- self._get_total_state_event_counts_txn,
- room_id,
- )
-
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
@@ -978,8 +943,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- @defer.inlineCallbacks
- def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -990,9 +954,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
- Deferred[dict[str:int]] of complexity version to complexity.
+ dict[str, int] of complexity version to complexity.
"""
- state_events = yield self.get_current_state_event_counts(room_id)
+ state_events = await self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@@ -1222,97 +1186,6 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- @cached(num_args=5, max_entries=10)
- def get_all_new_events(
- self,
- last_backfill_id,
- last_forward_id,
- current_backfill_id,
- current_forward_id,
- limit,
- ):
- """Get all the new events that have arrived at the server either as
- new events or as backfilled events"""
- have_backfill_events = last_backfill_id != current_backfill_id
- have_forward_events = last_forward_id != current_forward_id
-
- if not have_backfill_events and not have_forward_events:
- return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
- def get_all_new_events_txn(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- if have_forward_events:
- txn.execute(sql, (last_forward_id, current_forward_id, limit))
- new_forward_events = txn.fetchall()
-
- if len(new_forward_events) == limit:
- upper_bound = new_forward_events[-1][0]
- else:
- upper_bound = current_forward_id
-
- sql = (
- "SELECT event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_forward_id, upper_bound))
- forward_ex_outliers = txn.fetchall()
- else:
- new_forward_events = []
- forward_ex_outliers = []
-
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
- if have_backfill_events:
- txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
- new_backfill_events = txn.fetchall()
-
- if len(new_backfill_events) == limit:
- upper_bound = new_backfill_events[-1][0]
- else:
- upper_bound = current_backfill_id
-
- sql = (
- "SELECT -event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_backfill_id, -upper_bound))
- backward_ex_outliers = txn.fetchall()
- else:
- new_backfill_events = []
- backward_ex_outliers = []
-
- return AllNewEventsResult(
- new_forward_events,
- new_backfill_events,
- forward_ex_outliers,
- backward_ex_outliers,
- )
-
- return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
-
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
@@ -1320,9 +1193,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
- @cachedInlineCallbacks(max_entries=5000)
- def get_event_ordering(self, event_id):
- res = yield self.db_pool.simple_select_one(
+ @cached(max_entries=5000)
+ async def get_event_ordering(self, event_id):
+ res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -1357,14 +1230,3 @@ class EventsWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
-
-
-AllNewEventsResult = namedtuple(
- "AllNewEventsResult",
- [
- "new_forward_events",
- "new_backfill_events",
- "forward_ex_outliers",
- "backward_ex_outliers",
- ],
-)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 380db3a3f3..0e3b8739c6 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -341,14 +341,15 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
- def is_user_in_group(self, user_id, group_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
+ result = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
allow_none=True,
desc="is_user_in_group",
- ).addCallback(lambda r: bool(r))
+ )
+ return bool(result)
def is_user_admin_in_group(self, group_id, user_id):
return self.db_pool.simple_select_one_onecol(
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..fadcad51e7 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
import itertools
import logging
+from typing import Iterable, Tuple
from signedjson.key import decode_verify_key_bytes
@@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore):
return self.db_pool.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ async def store_server_verify_keys(
+ self,
+ from_server: str,
+ ts_added_ms: int,
+ verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+ ) -> None:
"""Stores NACL verification keys for remote servers.
Args:
- from_server (str): Where the verification keys were looked up
- ts_added_ms (int): The time to record that the key was added
- verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ from_server: Where the verification keys were looked up
+ ts_added_ms: The time to record that the key was added
+ verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
@@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
- def _invalidate(res):
- f = self._get_server_verify_key.invalidate
- for i in invalidations:
- f((i,))
- return res
-
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"store_server_verify_keys",
self.db_pool.simple_upsert_many_txn,
table="server_signature_keys",
@@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
- ).addCallback(_invalidate)
+ )
+
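+ # Only invalidate the cache entries once the upsert has completed;
+ # this replaces the old _invalidate callback.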
+ invalidate = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ invalidate((i,))
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 59ba12820a..4e3ec02d14 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -15,8 +15,8 @@
from typing import List, Tuple
+from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_presence_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
- def get_presence_for_users(self, user_ids):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_presence_for_users(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
-
- def allow_presence_visible(self, observed_localpart, observer_userid):
- return self.db_pool.simple_insert(
- table="presence_allow_inbound",
- values={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="allow_presence_visible",
- or_ignore=True,
- )
-
- def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self.db_pool.simple_delete_one(
- table="presence_allow_inbound",
- keyvalues={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="disallow_presence_visible",
- )
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..086cfbeed4 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,9 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.roommember import ProfileInfo
+from synapse.types import UserID
+from synapse.util.caches.descriptors import cached
+
+BATCH_SIZE = 100
class ProfileWorkerStore(SQLBaseStore):
@@ -38,6 +45,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
+ @cached(max_entries=5000)
def get_profile_displayname(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -46,6 +54,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
def get_profile_avatar_url(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -54,6 +63,56 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.db_pool.runInteraction(
+ "get_latest_profile_replication_batch_number", f
+ )
+
+ def get_profile_batch(self, batchnum):
+ return self.db_pool.simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ def assign_profile_batch(self):
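+ """Allocate the next batch number to up to BATCH_SIZE profiles that
+ do not yet have one, returning the number of rows updated.
+ """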
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return self.db_pool.runInteraction("assign_profile_batch", f)
+
+ def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.db_pool.runInteraction("get_replication_hosts", f)
+
+ def update_replication_batch_for_host(self, host, last_synced_batch):
+ return self.db_pool.simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
def get_from_remote_profile_cache(self, user_id):
return self.db_pool.simple_select_one(
table="remote_profile_cache",
@@ -68,24 +127,83 @@ class ProfileWorkerStore(SQLBaseStore):
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db_pool.simple_update_one(
+ def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ return self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db_pool.simple_update_one(
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ return self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ def set_profiles_active(
+ self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ ):
+ """Given a set of users, set active and hidden flags on them.
+
+ Args:
+ users: A list of UserIDs
+ active: Whether to set the users to active or inactive
+ hide: Whether to hide the users (withhold from replication). If
+ False and active is False, users will have their profiles
+ erased
+ batchnum: The batch number, used for profile replication
+
+ Returns:
+ Deferred
+ """
+ # Convert list of localparts to list of tuples containing localparts
+ user_localparts = [(user.localpart,) for user in users]
+
+ # Generate list of value tuples for each user
+ value_names = ("active", "batch")
+ values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
+
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile information
+ value_names += ("avatar_url", "displayname")
+ values = [v + (None, None) for v in values]
+
+ return self.db_pool.runInteraction(
+ "set_profiles_active",
+ self.db_pool.simple_upsert_many_txn,
+ table="profiles",
+ key_names=("user_id",),
+ key_values=user_localparts,
+ value_names=value_names,
+ value_values=values,
)
class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
+ super(ProfileStore, self).__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
@@ -104,7 +222,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db_pool.simple_update(
+ return self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- updatevalues={
+ values={
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 6562db5c2b..a585e54812 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,9 +30,9 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
+ self._push_rules_stream_id_gen = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ ) # type: Union[StreamIdGenerator, SlavedIdTracker]
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
@@ -115,9 +115,9 @@ class PushRulesWorkerStore(
"""
raise NotImplementedError()
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_for_user(self, user_id):
- rows = yield self.db_pool.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_for_user(self, user_id):
+ rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -133,17 +133,15 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
- enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+ enabled_map = await self.get_push_rules_enabled_for_user(user_id)
use_new_defaults = user_id in self._users_new_default_push_rules
- rules = _load_rules(rows, enabled_map, use_new_defaults)
-
- return rules
+ return _load_rules(rows, enabled_map, use_new_defaults)
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_enabled_for_user(self, user_id):
- results = yield self.db_pool.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_enabled_for_user(self, user_id):
+ results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -170,18 +168,15 @@ class PushRulesWorkerStore(
)
@cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
- def bulk_get_push_rules(self, user_ids):
+ async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -194,7 +189,7 @@ class PushRulesWorkerStore(
for row in rows:
results.setdefault(row["user_name"], []).append(row)
- enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+ enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
@@ -205,14 +200,15 @@ class PushRulesWorkerStore(
return results
- @defer.inlineCallbacks
- def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+ async def copy_push_rule_from_room_to_room(
+ self, new_room_id: str, user_id: str, rule: dict
+ ) -> None:
"""Copy a single push rule from one room to another for a specific user.
Args:
- new_room_id (str): ID of the new room.
- user_id (str): ID of user the push rule belongs to.
- rule (Dict): A push rule.
+ new_room_id: ID of the new room.
+ user_id: ID of the user the push rule belongs to.
+ rule: A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -224,7 +220,7 @@ class PushRulesWorkerStore(
condition["pattern"] = new_room_id
# Add the rule for the new room
- yield self.add_push_rule(
+ await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
@@ -232,20 +228,19 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
- @defer.inlineCallbacks
- def copy_push_rules_from_room_to_room_for_user(
- self, old_room_id, new_room_id, user_id
- ):
+ async def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id: str, new_room_id: str, user_id: str
+ ) -> None:
"""Copy all of the push rules from one room to another for a specific
user.
Args:
- old_room_id (str): ID of the old room.
- new_room_id (str): ID of the new room.
- user_id (str): ID of user to copy push rules for.
+ old_room_id: ID of the old room.
+ new_room_id: ID of the new room.
+ user_id: ID of user to copy push rules for.
"""
# Retrieve push rules for this user
- user_push_rules = yield self.get_push_rules_for_user(user_id)
+ user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
@@ -254,21 +249,20 @@ class PushRulesWorkerStore(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
- yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+ await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
- inlineCallbacks=True,
)
- def bulk_get_push_rules_enabled(self, user_ids):
+ async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -332,8 +326,7 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
- @defer.inlineCallbacks
- def add_push_rule(
+ async def add_push_rule(
self,
user_id,
rule_id,
@@ -342,13 +335,14 @@ class PushRuleStore(PushRulesWorkerStore):
actions,
before=None,
after=None,
- ):
+ ) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
+ with self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
if before or after:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -362,7 +356,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
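
The stream-id generator change above is mechanical: get_next() now yields a single stream id, and the event stream ordering is read separately from the events stream's current token. A toy model of the new shape:

from contextlib import contextmanager


class ToyIdGenerator:
    def __init__(self) -> None:
        self._current = 0

    @contextmanager
    def get_next(self):
        self._current += 1
        yield self._current

    def get_current_token(self) -> int:
        return self._current


push_rules_id_gen = ToyIdGenerator()
events_id_gen = ToyIdGenerator()

with push_rules_id_gen.get_next() as stream_id:
    # Previously both values came out of one generator as a pair.
    event_stream_ordering = events_id_gen.get_current_token()
    assert (stream_id, event_stream_ordering) == (1, 0)
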
@@ -546,16 +540,15 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
- @defer.inlineCallbacks
- def delete_push_rule(self, user_id, rule_id):
+ async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
- user_id (str): The matrix ID of the push rule owner
- rule_id (str): The rule_id of the rule to be deleted
+ user_id: The matrix ID of the push rule owner
+ rule_id: The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
@@ -567,20 +560,21 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ with self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
- @defer.inlineCallbacks
- def set_push_rule_enabled(self, user_id, rule_id, enabled):
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
+ with self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -611,8 +605,9 @@ class PushRuleStore(PushRulesWorkerStore):
op="ENABLE" if enabled else "DISABLE",
)
- @defer.inlineCallbacks
- def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ async def set_push_rule_actions(
+ self, user_id, rule_id, actions, is_default_rule
+ ) -> None:
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -651,9 +646,10 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ with self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
@@ -681,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_push_rules_stream_token(self):
- """Get the position of the push rules stream.
- Returns a pair of a stream id for the push_rules stream and the
- room stream ordering it corresponds to."""
- return self._push_rules_stream_id_gen.get_current_token()
-
def get_max_push_rules_stream_id(self):
- return self.get_push_rules_stream_token()[0]
+ return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b5200fbe79..1126fd0751 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
logger = logging.getLogger(__name__)
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
Drops any rows whose data cannot be decoded
"""
for r in rows:
- dataJson = r["data"]
+ data_json = r["data"]
try:
- r["data"] = db_to_json(dataJson)
+ r["data"] = db_to_json(data_json)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
- dataJson,
+ data_json,
e.args[0],
)
continue
yield r
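
The renamed generator above drops pushers whose data column is not valid JSON instead of failing the whole query. A condensed, runnable sketch using stdlib json in place of Synapse's db_to_json:

import json
import logging

logger = logging.getLogger(__name__)


def decode_pushers_rows(rows):
    for r in rows:
        data_json = r["data"]
        try:
            r["data"] = json.loads(data_json)
        except Exception as e:
            # Log and skip the bad row rather than raising.
            logger.warning("Invalid JSON in data for pusher %d: %s, %s", r["id"], data_json, e.args[0])
            continue
        yield r


rows = [{"id": 1, "data": '{"url": "https://push.example"}'}, {"id": 2, "data": "{"}]
assert [r["id"] for r in decode_pushers_rows(rows)] == [1]
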
- @defer.inlineCallbacks
- def user_has_pusher(self, user_id):
- ret = yield self.db_pool.simple_select_one_onecol(
+ async def user_has_pusher(self, user_id):
+ ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({"user_name": user_id})
- @defer.inlineCallbacks
- def get_pushers_by(self, keyvalues):
- ret = yield self.db_pool.simple_select_list(
+ async def get_pushers_by(self, keyvalues):
+ ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
[
@@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- @defer.inlineCallbacks
- def get_all_pushers(self):
+ async def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
- return rows
+ return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore):
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
- @cachedInlineCallbacks(num_args=1, max_entries=15000)
- def get_if_user_has_pusher(self, user_id):
+ @cached(num_args=1, max_entries=15000)
+ async def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
- cached_method_name="get_if_user_has_pusher",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- def get_if_users_have_pushers(self, user_ids):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_if_users_have_pushers(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
return result
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering(
+ async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
- ):
- yield self.db_pool.simple_update_one(
+ ) -> None:
+ await self.db_pool.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(
- self, app_id, pushkey, user_id, last_stream_ordering, last_success
- ):
+ async def update_pusher_last_stream_ordering_and_success(
+ self,
+ app_id: str,
+ pushkey: str,
+ user_id: str,
+ last_stream_ordering: int,
+ last_success: int,
+ ) -> bool:
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
- app_id (str)
- pushkey (str)
- last_stream_ordering (int)
- last_success (int)
+ app_id
+ pushkey
+ user_id
+ last_stream_ordering
+ last_success
Returns:
- Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+ True if the pusher still exists; False if it has been deleted.
"""
- updated = yield self.db_pool.simple_update(
+ updated = await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
- @defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self.db_pool.simple_update(
+ async def update_pusher_failing_since(
+ self, app_id, pushkey, user_id, failing_since
+ ) -> None:
+ await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
- @defer.inlineCallbacks
- def get_throttle_params_by_room(self, pusher_id):
- res = yield self.db_pool.simple_select_list(
+ async def get_throttle_params_by_room(self, pusher_id):
+ res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
- @defer.inlineCallbacks
- def set_throttle_params(self, pusher_id, room_id, params):
+ async def set_throttle_params(self, pusher_id, room_id, params) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
- yield self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
@@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_pusher(
+ async def add_pusher(
self,
user_id,
access_token,
@@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore):
data,
last_stream_ordering,
profile_tag="",
- ):
+ ) -> None:
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
- yield self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
            # invalidate, since the user might not have had a pusher before
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
(user_id,),
)
- @defer.inlineCallbacks
- def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
+ async def delete_pusher_by_app_id_pushkey_user_id(
+ self, app_id, pushkey, user_id
+ ) -> None:
def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream(
txn, self.get_if_user_has_pusher, (user_id,)
@@ -351,6 +345,6 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..19ad1c056f 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
import abc
import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
from twisted.internet import defer
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
raise NotImplementedError()
- @cachedInlineCallbacks()
- def get_users_with_read_receipts_in_room(self, room_id):
- receipts = yield self.get_receipts_for_room(room_id, "m.read")
+ @cached()
+ async def get_users_with_read_receipts_in_room(self, room_id):
+ receipts = await self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
- @cachedInlineCallbacks(num_args=2)
- def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self.db_pool.simple_select_list(
+ @cached(num_args=2)
+ async def get_receipts_for_user(self, user_id, receipt_type):
+ rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- @defer.inlineCallbacks
- def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+ async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.db_pool.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
)
return {
@@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows
}
- @defer.inlineCallbacks
- def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def get_linearized_receipts_for_rooms(
+ self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_ids (list): List of room_ids.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+            room_ids: List of room_ids.
+            to_key: Max stream id to fetch receipts up to.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- list: A list of receipts.
+ A list of receipts.
"""
room_ids = set(room_ids)
if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
- room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
- results = yield self._get_linearized_receipts_for_rooms(
+ results = await self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
return [ev for res in results.values() for ev in res]
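
get_entities_changed on the stream change cache is now synchronous, which is why the yield above disappeared without an await replacing it. A toy stand-in showing the pre-filter it performs:

from typing import Dict, Iterable, Set


class ToyStreamChangeCache:
    def __init__(self, last_change: Dict[str, int]) -> None:
        self._last_change = last_change  # entity -> stream id of last change

    def get_entities_changed(self, entities: Iterable[str], from_key: int) -> Set[str]:
        # Unknown entities are conservatively treated as changed.
        return {e for e in entities if self._last_change.get(e, from_key + 1) > from_key}


cache = ToyStreamChangeCache({"!quiet:example.com": 2, "!busy:example.com": 9})
assert cache.get_entities_changed(
    {"!quiet:example.com", "!busy:example.com"}, from_key=5
) == {"!busy:example.com"}
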
- def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ async def get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for a single room for sending to clients.
Args:
- room_ids (str): The room id.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+            room_id: The room id.
+            to_key: Max stream id to fetch receipts up to.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- Deferred[list]: A list of receipts.
+ A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
            # since `from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
- defer.succeed([])
+ return []
- return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+ return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
- @cachedInlineCallbacks(num_args=3, tree=True)
- def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ @cached(num_args=3, tree=True)
+ async def _get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""See get_linearized_receipts_for_room
"""
@@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
+ rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -212,9 +216,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
- inlineCallbacks=True,
)
- def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -243,7 +246,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
- txn_results = yield self.db_pool.runInteraction(
+ txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -346,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _invalidate_get_users_with_receipts_in_room(
- self, room_id, receipt_type, user_id
+ self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
return
@@ -472,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts
- @defer.inlineCallbacks
- def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+ async def insert_receipt(
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: dict,
+ ) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
- return
+ return None
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
@@ -507,13 +516,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.db_pool.runInteraction(
+ linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- event_ts = yield self.db_pool.runInteraction(
+ event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -535,7 +544,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts,
)
- yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+ await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 402ae25571..5986d32b18 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,7 @@
import logging
import re
-from typing import Dict, List, Optional
-
-from twisted.internet.defer import Deferred
+from typing import Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -158,6 +156,37 @@ class RegistrationWorkerStore(SQLBaseStore):
"set_account_validity_for_user", set_account_validity_for_user_txn
)
+ async def get_expired_users(self):
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
+
+ Returns:
+            List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ # We need to use pattern matching as profiles.user_id is confusingly just the
+ # user's localpart, whereas account_validity.user_id is a full user ID
+ sql = """
+ SELECT av.user_id from account_validity AS av
+ LEFT JOIN profiles as p
+ ON av.user_id LIKE '%%' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+
+ return [UserID.from_string(row[0]) for row in rows]
+
+ res = await self.db_pool.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+
+ return res
+
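
A hypothetical standalone demo of the LIKE join above, using sqlite3. The doubled %% in the original query escapes the percent sign for drivers that use %-style parameter markers; with sqlite3's ? placeholders a single % suffices:

import sqlite3

db = sqlite3.connect(":memory:")
db.executescript(
    """
    CREATE TABLE account_validity (user_id TEXT, expiration_ts_ms BIGINT);
    CREATE TABLE profiles (user_id TEXT, active SMALLINT DEFAULT 1);
    INSERT INTO account_validity VALUES ('@alice:example.com', 1000);
    INSERT INTO profiles VALUES ('alice', 1);
    """
)
rows = db.execute(
    """
    SELECT av.user_id FROM account_validity AS av
    LEFT JOIN profiles AS p ON av.user_id LIKE '%' || p.user_id || ':%'
    WHERE expiration_ts_ms <= ? AND p.active = 1
    """,
    (2000,),
).fetchall()
assert rows == [("@alice:example.com",)]
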
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
@@ -264,6 +293,54 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
+ async def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+            A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = await self.db_pool.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self.clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
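
The post-processing in get_info_for_users reduces each (name, deactivated, expiration) row to two booleans; a self-contained rendering of that step with a fixed clock value:

now_ms = 2_000
info_rows = [
    ("@alice:example.com", 0, 1_000),  # past expiration, not deactivated
    ("@bob:example.com", 1, None),     # no expiration, deactivated
]

info = {
    user_id: {
        "expired": expiration is not None and now_ms >= expiration,
        "deactivated": deactivated == 1,
    }
    for user_id, deactivated, expiration in info_rows
}
assert info["@alice:example.com"] == {"expired": True, "deactivated": False}
assert info["@bob:example.com"] == {"expired": False, "deactivated": True}
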
async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
@@ -304,7 +381,7 @@ class RegistrationWorkerStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -563,7 +640,7 @@ class RegistrationWorkerStore(SQLBaseStore):
id_server (str)
Returns:
- Deferred
+ Awaitable
"""
        # We need to use an upsert, in case the user had already bound the
# threepid
@@ -952,6 +1029,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname=None,
admin=False,
user_type=None,
+ shadow_banned=False,
):
"""Attempts to register an account.
@@ -968,6 +1046,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
+ shadow_banned (bool): Whether the user is shadow-banned,
+ i.e. they may be told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
@@ -986,6 +1066,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
)
def _register_user(
@@ -999,6 +1080,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
):
user_id_obj = UserID.from_string(user_id)
@@ -1028,6 +1110,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
else:
@@ -1042,6 +1125,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
@@ -1077,7 +1161,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
- ) -> Deferred:
+ ) -> Awaitable:
"""Record a mapping from an external user id to a mxid
Args:
@@ -1345,43 +1429,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"validate_threepid_session_txn", validate_threepid_session_txn
)
- def upsert_threepid_validation_session(
- self,
- medium,
- address,
- client_secret,
- send_attempt,
- session_id,
- validated_at=None,
- ):
- """Upsert a threepid validation session
- Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- session_id (str): The id of this validation session
- validated_at (int|None): The unix timestamp in milliseconds of
- when the session was marked as valid
- """
- insertion_values = {
- "medium": medium,
- "address": address,
- "client_secret": client_secret,
- }
-
- if validated_at:
- insertion_values["validated_at"] = validated_at
-
- return self.db_pool.simple_upsert(
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- values={"last_send_attempt": send_attempt},
- insertion_values=insertion_values,
- desc="upsert_threepid_validation_session",
- )
-
def start_or_continue_validation_session(
self,
medium,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f4008e6221..0142a856d5 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -23,6 +23,8 @@ from typing import Any, Dict, List, Optional, Tuple
from canonicaljson import json
+from twisted.internet import defer
+
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -35,10 +37,6 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-OpsLevel = collections.namedtuple(
- "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)
@@ -344,6 +342,24 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @defer.inlineCallbacks
+ def is_room_published(self, room_id):
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id (str)
+ Returns:
+ bool: Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = yield self.get_room(room_id)
+ if not room_info:
+ defer.returnValue(False)
+
+ # Check the is_public value
+ defer.returnValue(room_info.get("is_public", False))
+
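
is_room_published is added in the old inlineCallbacks style even though the rest of this diff migrates away from it; for comparison, a hypothetical async equivalent with a stubbed get_room:

import asyncio


class ToyRoomStore:
    async def get_room(self, room_id: str) -> dict:
        return {"room_id": room_id, "is_public": True}  # stub for the real lookup

    async def is_room_published(self, room_id: str) -> bool:
        room_info = await self.get_room(room_id)
        if not room_info:
            return False
        return room_info.get("is_public", False)


assert asyncio.run(ToyRoomStore().is_room_published("!r:example.com")) is True
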
async def get_rooms_paginate(
self,
start: int,
@@ -552,6 +568,11 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
"""
+ # If the room retention feature is disabled, return a policy with no minimum nor
+ # maximum, in order not to filter out events we should filter out when sending to
+ # the client.
+ if not self.config.retention_enabled:
+ defer.returnValue({"min_lifetime": None, "max_lifetime": None})
def get_retention_policy_for_room_txn(txn):
txn.execute(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..161edbeccb 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -17,8 +17,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
- @defer.inlineCallbacks
- def _count_known_servers(self):
+ async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
- count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+ count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
- list_name="event_ids",
- inlineCallbacks=True,
+ cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
- def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
- Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+ dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -772,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
- def get_membership_from_event_ids(
+ async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
- return self.db_pool.simple_select_many_batch(
+ return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
diff --git a/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
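
A hypothetical sketch of how the new batch column and profile_replication_status table work together: a replication target is caught up by selecting profiles whose batch is newer than its last_synced_batch (sqlite3 used purely for illustration):

import sqlite3

db = sqlite3.connect(":memory:")
db.executescript(
    """
    CREATE TABLE profiles (user_id TEXT, displayname TEXT, batch BIGINT DEFAULT NULL);
    CREATE INDEX profiles_batch_idx ON profiles(batch);
    CREATE TABLE profile_replication_status (host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL);
    INSERT INTO profiles VALUES ('alice', 'Alice', 7), ('bob', 'Bob', 3);
    INSERT INTO profile_replication_status VALUES ('other.example.com', 5);
    """
)
pending = db.execute(
    """
    SELECT p.user_id, p.batch FROM profiles AS p
    WHERE p.batch > (
        SELECT last_synced_batch FROM profile_replication_status WHERE host = ?
    )
    """,
    ("other.example.com",),
).fetchall()
assert pending == [("alice", 7)]
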
diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..96051ac179
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * really ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..7542ab8cbd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
new file mode 100644
index 0000000000..260b009b48
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A shadow-banned user may be told that their requests succeeded when they were
+-- actually ignored.
+ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
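
A toy illustration of the semantics the comment describes, assuming a handler that consults the flag: the caller gets a success-shaped response while the action is silently dropped (the fake event id below is hypothetical):

sent = []


def handle_send_message(shadow_banned: bool) -> dict:
    if not shadow_banned:
        sent.append("message")
    # Shadow-banned users still see an apparently successful response.
    return {"event_id": "$hypothetical:example.com"}


handle_send_message(shadow_banned=True)
assert sent == []
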
diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
new file mode 100644
index 0000000000..15421b99ac
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This table is no longer used.
+DROP TABLE IF EXISTS presence_allow_inbound;
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..991233a9bc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
- inlineCallbacks=True,
)
- def _get_state_group_for_events(self, event_ids):
+ async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index aaf225894e..497f607703 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,15 +39,17 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
-from typing import Optional
+from typing import Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -68,8 +70,12 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
- direction, column_names, from_token, to_token, engine
-):
+ direction: str,
+ column_names: Tuple[str, str],
+ from_token: Optional[Tuple[int, int]],
+ to_token: Optional[Tuple[int, int]],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -90,21 +96,19 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
- direction (str): Whether we're paginating backwards("b") or
- forwards ("f").
- column_names (tuple[str, str]): The column names to bound. Must *not*
- be user defined as these get inserted directly into the SQL
- statement without escapes.
- from_token (tuple[int, int]|None): The start point for the pagination.
- This is an exclusive minimum bound if direction is "f", and an
- inclusive maximum bound if direction is "b".
- to_token (tuple[int, int]|None): The endpoint point for the pagination.
- This is an inclusive maximum bound if direction is "f", and an
- exclusive minimum bound if direction is "b".
+        direction: Whether we're paginating backwards ("b") or forwards ("f").
+ column_names: The column names to bound. Must *not* be user defined as
+ these get inserted directly into the SQL statement without escapes.
+ from_token: The start point for the pagination. This is an exclusive
+ minimum bound if direction is "f", and an inclusive maximum bound if
+ direction is "b".
+        to_token: The end point for the pagination. This is an inclusive
+ maximum bound if direction is "f", and an exclusive minimum bound if
+ direction is "b".
engine: The database engine to generate the clauses for
Returns:
- str: The sql expression
+ The sql expression
"""
assert direction in ("b", "f")
@@ -132,7 +136,12 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
-def _make_generic_sql_bound(bound, column_names, values, engine):
+def _make_generic_sql_bound(
+ bound: str,
+ column_names: Tuple[str, str],
+ values: Tuple[Optional[int], int],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
@@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
out manually.
Args:
- bound (str): The comparison operator to use. One of ">", "<", ">=",
+ bound: The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
- names (tuple[str, str]): The column names. Must *not* be user defined
+        column_names: The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
- values (tuple[int|None, int]): The values to bound the columns by. If
+ values: The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
- str
+ The SQL statement
"""
assert bound in (">", "<", ">=", "<=")
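
Where the engine cannot compare tuples directly, _make_generic_sql_bound expands the bound lexicographically. A sketch of that expansion (a toy function, not the real implementation, and it inlines the values rather than using placeholders):

from typing import Optional, Tuple


def manual_bound(bound: str, column_names: Tuple[str, str], values: Tuple[Optional[int], int]) -> str:
    assert bound in (">", "<", ">=", "<=")
    a, b = values
    col1, col2 = column_names
    if a is None:
        # Only bound the second column.
        return "%d %s %s" % (b, bound, col2)
    strict = bound.rstrip("=")  # the first-column comparison is always strict
    return "(%d %s %s OR (%d = %s AND %d %s %s))" % (a, strict, col1, a, col1, b, bound, col2)


assert (
    manual_bound("<", ("topological_ordering", "stream_ordering"), (10, 20))
    == "(10 < topological_ordering OR (10 = topological_ordering AND 20 < stream_ordering))"
)
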
@@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
)
-def filter_to_clause(event_filter):
+def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_room_min_stream_ordering(self):
raise NotImplementedError()
- @defer.inlineCallbacks
- def get_room_events_stream_for_rooms(
- self, room_ids, from_key, to_key, limit=0, order="DESC"
- ):
+ async def get_room_events_stream_for_rooms(
+ self,
+ room_ids: Iterable[str],
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Dict[str, Tuple[List[EventBase], str]]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_ids
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[dict[str,tuple[list[FrozenEvent], str]]]
- A map from room id to a tuple containing:
- - list of recent events in the room
- - stream ordering key for the start of the chunk of events returned.
+ A map from room id to a tuple containing:
+ - list of recent events in the room
+ - stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
- room_ids = yield self._events_stream_cache.get_entities_changed(
- room_ids, from_id
- )
+ room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
if not room_ids:
return {}
@@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
- res = yield make_deferred_yieldable(
+ res = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
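
The loop above fans out over rooms in chunks of 20, running each chunk's per-room queries concurrently. A stdlib sketch of the same pattern, with asyncio.gather standing in for Twisted's gatherResults and a stubbed per-room query:

import asyncio
from typing import Dict, List


async def fetch_room_events(room_id: str) -> List[str]:
    return []  # stand-in for the per-room database query


async def fetch_all(room_ids: List[str]) -> Dict[str, List[str]]:
    results: Dict[str, List[str]] = {}
    # Process rooms 20 at a time; each chunk runs concurrently.
    for i in range(0, len(room_ids), 20):
        chunk = room_ids[i : i + 20]
        res = await asyncio.gather(*(fetch_room_events(r) for r in chunk))
        results.update(dict(zip(chunk, res)))
    return results


assert asyncio.run(fetch_all(["!a:example.com", "!b:example.com"])) == {
    "!a:example.com": [],
    "!b:example.com": [],
}
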
@@ -361,28 +371,30 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if self._events_stream_cache.has_entity_changed(room_id, from_key)
}
- @defer.inlineCallbacks
- def get_room_events_stream_for_room(
- self, room_id, from_key, to_key, limit=0, order="DESC"
- ):
-
+ async def get_room_events_stream_for_room(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_id
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
- events (in ascending order) and the token from the start of
- the chunk of events returned.
+ The list of events (in ascending order) and the token from the start
+ of the chunk of events returned.
"""
if from_key == to_key:
return [], from_key
@@ -390,9 +402,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
- has_changed = yield self._events_stream_cache.has_entity_changed(
- room_id, from_id
- )
+ has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
if not has_changed:
return [], from_key
@@ -410,9 +420,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
+ rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -430,8 +440,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
- @defer.inlineCallbacks
- def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ async def get_membership_changes_for_user(self, user_id, from_key, to_key):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -460,9 +469,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
+ rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -470,27 +479,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
- @defer.inlineCallbacks
- def get_recent_events_for_room(self, room_id, limit, end_token):
+ async def get_recent_events_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[EventBase], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
- events and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of events and a token pointing to the start of the returned
+ events. The events returned are in ascending order.
"""
- rows, token = yield self.get_recent_event_ids_for_room(
+ rows, token = await self.get_recent_event_ids_for_room(
room_id, limit, end_token
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -498,20 +506,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
- @defer.inlineCallbacks
- def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+ async def get_recent_event_ids_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
- _EventDictReturn and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of _EventDictReturn and a token pointing to the start of the
+ returned events. The events returned are in ascending order.
"""
# Allow a zero limit here, and no-op.
if limit == 0:
@@ -519,7 +526,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.db_pool.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -532,12 +539,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
- def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+ def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
"""Gets details of the first event in a room at or before a stream ordering
Args:
- room_id (str):
- stream_ordering (int):
+ room_id:
+ stream_ordering:
Returns:
Deferred[(int, int, str)]:
@@ -574,55 +581,67 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return "t%d-%d" % (topo, token)
- def get_stream_token_for_event(self, event_id):
- """The stream token for an event
+ async def get_stream_id_for_event(self, event_id: str) -> int:
+ """The stream ID for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "s%d" stream token.
+ A stream ID.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
- ).addCallback(lambda row: "s%d" % (row,))
+ )
- def get_topological_token_for_event(self, event_id):
+ async def get_stream_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "t%d-%d" topological token.
+ A "s%d" stream token.
"""
- return self.db_pool.simple_select_one(
+ stream_id = await self.get_stream_id_for_event(event_id)
+ return "s%d" % (stream_id,)
+
+ async def get_topological_token_for_event(self, event_id: str) -> str:
+ """The stream token for an event
+ Args:
+ event_id: The id of the event to look up a stream token for.
+ Raises:
+ StoreError if the event wasn't in the database.
+ Returns:
+ A "t%d-%d" topological token.
+ """
+ row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
- ).addCallback(
- lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
+ return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
- def get_max_topological_token(self, room_id, stream_key):
+ async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
"""Get the max topological token in a room before the given stream
ordering.
Args:
- room_id (str)
- stream_key (int)
+ room_id
+ stream_key
Returns:
- Deferred[int]
+ The maximum topological token.
"""
sql = (
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self.db_pool.execute(
+ row = await self.db_pool.execute(
"get_max_topological_token", None, sql, room_id, stream_key
- ).addCallback(lambda r: r[0][0] if r else 0)
+ )
+ return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
@@ -634,16 +653,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows[0][0] if rows else 0
@staticmethod
- def _set_before_and_after(events, rows, topo_order=True):
+ def _set_before_and_after(
+ events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
+ ):
"""Inserts ordering information to events' internal metadata from
the DB rows.
Args:
- events (list[FrozenEvent])
- rows (list[_EventDictReturn])
- topo_order (bool): Whether the events were ordered topologically
- or by stream ordering. If true then all rows should have a non
- null topological_ordering.
+ events
+ rows
+ topo_order: Whether the events were ordered topologically or by stream
+ ordering. If true then all rows should have a non null
+ topological_ordering.
"""
for event, row in zip(events, rows):
stream = row.stream_ordering
@@ -656,25 +677,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (int(topo) if topo else 0, int(stream))
- @defer.inlineCallbacks
- def get_events_around(
- self, room_id, event_id, before_limit, after_limit, event_filter=None
- ):
+ async def get_events_around(
+ self,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter] = None,
+ ) -> dict:
"""Retrieve events and pagination tokens around a given event in a
room.
-
- Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
-
- Returns:
- dict
"""
- results = yield self.db_pool.runInteraction(
+ results = await self.db_pool.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -684,11 +699,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events_before = yield self.get_events_as_list(
+ events_before = await self.get_events_as_list(
list(results["before"]["event_ids"]), get_prev_content=True
)
- events_after = yield self.get_events_as_list(
+ events_after = await self.get_events_as_list(
list(results["after"]["event_ids"]), get_prev_content=True
)
@@ -700,17 +715,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
def _get_events_around_txn(
- self, txn, room_id, event_id, before_limit, after_limit, event_filter
- ):
+ self,
+ txn,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter],
+ ) -> dict:
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
+ room_id
+ event_id
+ before_limit
+ after_limit
+ event_filter
Returns:
dict
@@ -758,22 +779,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
- @defer.inlineCallbacks
- def get_all_new_events_stream(self, from_id, current_id, limit):
+ async def get_all_new_events_stream(
+ self, from_id: int, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Args:
- from_id (int): the stream_ordering of the last event we processed
- current_id (int): the stream_ordering of the most recently processed event
- limit (int): the maximum number of events to return
+ from_id: the stream_ordering of the last event we processed
+ current_id: the stream_ordering of the most recently processed event
+ limit: the maximum number of events to return
Returns:
- Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
- `next_id` is the next value to pass as `from_id` (it will either be the
- stream_ordering of the last returned event, or, if fewer than `limit` events
- were found, `current_id`.
+ A tuple of (next_id, events), where `next_id` is the next value to
+ pass as `from_id` (it will either be the stream_ordering of the
+ last returned event, or, if fewer than `limit` events were found,
+ the `current_id`).
"""
def get_all_new_events_stream_txn(txn):
@@ -795,11 +817,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.db_pool.runInteraction(
+ upper_bound, event_ids = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return upper_bound, events
@@ -817,21 +839,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_federation_out_pos",
)
- async def update_federation_out_pos(self, typ, stream_id):
+ async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
if self._need_to_reset_federation_stream_positions:
await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn
)
self._need_to_reset_federation_stream_positions = False
- return await self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
- def _reset_federation_positions_txn(self, txn):
+ def _reset_federation_positions_txn(self, txn) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -892,39 +914,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
values={"stream_id": stream_id},
)
- def has_room_changed_since(self, room_id, stream_id):
+ def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(
self,
txn,
- room_id,
- from_token,
- to_token=None,
- direction="b",
- limit=-1,
- event_filter=None,
- ):
+ room_id: str,
+ from_token: RoomStreamToken,
+ to_token: Optional[RoomStreamToken] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Returns list of events before or after a given token.
Args:
txn
- room_id (str)
- from_token (RoomStreamToken): The token used to stream from
- to_token (RoomStreamToken|None): A token which if given limits the
- results to only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
+ room_id
+ from_token: The token used to stream from
+ to_token: A token which if given limits the results to only those before
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided filters the events to
those that match the filter.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
- as a list of _EventDictReturn and a token that points to the end
- of the result set. If no events are returned then the end of the
- stream has been reached (i.e. there are no events between
- `from_token` and `to_token`), or `limit` is zero.
+ A list of _EventDictReturn and a token that points to the end of the
+ result set. If no events are returned then the end of the stream has
+ been reached (i.e. there are no events between `from_token` and
+ `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -1008,35 +1028,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, str(next_token)
- @defer.inlineCallbacks
- def paginate_room_events(
- self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
- ):
+ async def paginate_room_events(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: Optional[str] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[EventBase], str]:
"""Returns list of events before or after a given token.
Args:
- room_id (str)
- from_key (str): The token used to stream from
- to_key (str|None): A token which if given limits the results to
- only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
- those that match the filter.
+ room_id
+ from_key: The token used to stream from
+ to_key: A token which if given limits the results to only those before
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided filters the events to those that match the filter.
Returns:
- tuple[list[FrozenEvent], str]: Returns the results as a list of
- events and a token that points to the end of the result set. If no
- events are returned then the end of the stream has been reached
- (i.e. there are no events between `from_key` and `to_key`).
+ The results as a list of events and a token that points to the end
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between `from_key`
+ and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.db_pool.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -1047,7 +1070,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -1057,8 +1080,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
return self._stream_id_gen.get_current_token()
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ab6cb2c1f6..e3547e53b3 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,35 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
class UserErasureWorkerStore(SQLBaseStore):
@cached()
- def is_user_erased(self, user_id):
+ async def is_user_erased(self, user_id: str) -> bool:
"""
Check if the given user id has requested erasure
Args:
- user_id (str): full user id to check
+ user_id: full user id to check
Returns:
- Deferred[bool]: True if the user has requested erasure
+ True if the user has requested erasure
"""
- return self.db_pool.simple_select_onecol(
+ result = await self.db_pool.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
desc="is_user_erased",
- ).addCallback(operator.truth)
+ )
+ return bool(result)
- @cachedList(
- cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
- )
- def are_users_erased(self, user_ids):
+ @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+ async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
@@ -49,14 +46,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user id to check
Returns:
- Deferred[dict[str, bool]]:
+ dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -65,8 +62,7 @@ class UserErasureWorkerStore(SQLBaseStore):
)
erased_users = {row["user_id"] for row in rows}
- res = {u: u in erased_users for u in user_ids}
- return res
+ return {u: u in erased_users for u in user_ids}
class UserErasureStore(UserErasureWorkerStore):
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e2ddd01290..0bf772d4d1 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,7 +16,7 @@
import contextlib
import threading
from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, Set
from typing_extensions import Deque
@@ -158,63 +158,13 @@ class StreamIdGenerator(object):
return self._current
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
-class ChainedIdGenerator(object):
- """Used to generate new stream ids where the stream must be kept in sync
- with another stream. It generates pairs of IDs, the first element is an
- integer ID for this stream, the second element is the ID for the stream
- that this stream needs to be kept in sync with."""
-
- def __init__(self, chained_generator, db_conn, table, column):
- self.chained_generator = chained_generator
- self._table = table
- self._lock = threading.Lock()
- self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
-
- def get_next(self):
- """
- Usage:
- with stream_id_gen.get_next() as (stream_id, chained_id):
- # ... persist event ...
- """
- with self._lock:
- self._current_max += 1
- next_id = self._current_max
- chained_id = self.chained_generator.get_current_token()
-
- self._unfinished_ids.append((next_id, chained_id))
-
- @contextlib.contextmanager
- def manager():
- try:
- yield (next_id, chained_id)
- finally:
- with self._lock:
- self._unfinished_ids.remove((next_id, chained_id))
-
- return manager()
-
- def get_current_token(self):
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
+ For streams with single writers this is equivalent to
+ `get_current_token`.
"""
- with self._lock:
- if self._unfinished_ids:
- stream_id, chained_id = self._unfinished_ids[0]
- return stream_id - 1, chained_id
-
- return self._current_max, self.chained_generator.get_current_token()
-
- def advance(self, token: int):
- """Stub implementation for advancing the token when receiving updates
- over replication; raises an exception as this instance should be the
- only source of updates.
- """
-
- raise Exception(
- "Attempted to advance token on source for table %r", self._table
- )
+ return self.get_current_token()
class MultiWriterIdGenerator:
@@ -298,7 +248,7 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
- assert self.get_current_token() < next_id
+ assert self.get_current_token_for_writer(self._instance_name) < next_id
with self._lock:
self._unfinished_ids.add(next_id)
@@ -344,16 +294,18 @@ class MultiWriterIdGenerator:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id)
- def get_current_token(self, instance_name: str = None) -> int:
- """Gets the current position of a named writer (defaults to current
- instance).
-
- Returns 0 if we don't have a position for the named writer (likely due
- to it being a new writer).
+ def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
"""
- if instance_name is None:
- instance_name = self._instance_name
+ # Currently we don't support this operation, as it's not obvious how to
+ # condense the stream positions of multiple writers into a single int.
+ raise NotImplementedError()
+
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+ """
with self._lock:
return self._current_positions.get(instance_name, 0)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 393e34b9fb..7ab46f42bf 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -39,7 +39,7 @@ class EventSources(object):
self.store = hs.get_datastore()
def get_current_token(self) -> StreamToken:
- push_rules_key, _ = self.store.get_push_rules_stream_token()
+ push_rules_key = self.store.get_max_push_rules_stream_id()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
new file mode 100644
index 0000000000..2519e05ae0
--- /dev/null
+++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,947 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import email.utils
+import logging
+from typing import Dict, List, Optional, Tuple
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.config._base import ConfigError
+from synapse.events import EventBase
+from synapse.module_api import ModuleApi
+from synapse.types import Requester, StateMap, UserID, get_domain_from_id
+
+logger = logging.getLogger(__name__)
+
+ACCESS_RULES_TYPE = "im.vector.room.access_rules"
+
+
+class AccessRules:
+ DIRECT = "direct"
+ RESTRICTED = "restricted"
+ UNRESTRICTED = "unrestricted"
+
+
+VALID_ACCESS_RULES = (
+ AccessRules.DIRECT,
+ AccessRules.RESTRICTED,
+ AccessRules.UNRESTRICTED,
+)
+
+# Rules to which we need to apply the power levels restrictions.
+#
+# These are all of the rules that neither:
+# * forbid users from joining based on a server blacklist (which means that there
+# is no need to apply power level restrictions), nor
+# * target direct chats (since we allow both users to be room admins in this case).
+#
+# The power-level restrictions, when they are applied, prevent the following:
+# * the default power level for users (users_default) being set to anything other than 0.
+# * a non-default power level being assigned to any user which would be forbidden from
+# joining a restricted room.
+RULES_WITH_RESTRICTED_POWER_LEVELS = (AccessRules.UNRESTRICTED,)
+
+
+class RoomAccessRules(object):
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+    Don't forget to consider whether users from your own domain should be allowed
+    to be invited.
+ """
+
+ def __init__(
+ self, config: Dict, module_api: ModuleApi,
+ ):
+ self.id_server = config["id_server"]
+ self.module_api = module_api
+
+ self.domains_forbidden_when_restricted = config.get(
+ "domains_forbidden_when_restricted", []
+ )
+
+ @staticmethod
+ def parse_config(config: Dict) -> Dict:
+ """Parses and validates the options specified in the homeserver config.
+
+ Args:
+ config: The config dict.
+
+ Returns:
+ The config dict.
+
+ Raises:
+ ConfigError: If there was an issue with the provided module configuration.
+ """
+ if "id_server" not in config:
+            raise ConfigError("No id_server configured for RoomAccessRules")
+
+ return config
+
+ async def on_create_room(
+ self, requester: Requester, config: Dict, is_requester_admin: bool,
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room.
+
+        Checks if an im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+
+        Checks if an m.room.power_levels event is being set during room creation.
+ If yes, make sure the event is allowed. Otherwise, set power_level_content_override
+ in the config dict to our modified version of the default room power levels.
+
+ Args:
+ requester: The user who is making the createRoom request.
+ config: The createRoom config dict provided by the user.
+ is_requester_admin: Whether the requester is a Synapse admin.
+
+ Returns:
+ Whether the request is allowed.
+
+ Raises:
+ SynapseError: If the createRoom config dict is invalid or its contents blocked.
+ """
+ is_direct = config.get("is_direct")
+ preset = config.get("preset")
+ access_rule = None
+ join_rule = None
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and deny the request if not.
+ for event in config.get("initial_state", []):
+ if event["type"] == ACCESS_RULES_TYPE:
+ access_rule = event["content"].get("rule")
+
+                # Make sure the event has valid content.
+ if access_rule is None:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule name is valid.
+ if access_rule not in VALID_ACCESS_RULES:
+ raise SynapseError(400, "Invalid access rule")
+
+ if (is_direct and access_rule != AccessRules.DIRECT) or (
+ access_rule == AccessRules.DIRECT and not is_direct
+ ):
+ raise SynapseError(400, "Invalid access rule")
+
+ if event["type"] == EventTypes.JoinRules:
+ join_rule = event["content"].get("join_rule")
+
+ if access_rule is None:
+ # If there's no access rules event in the initial state, create one with the
+ # default setting.
+ if is_direct:
+ default_rule = AccessRules.DIRECT
+ else:
+ # If the default value for non-direct chat changes, we should make another
+ # case here for rooms created with either a "public" join_rule or the
+ # "public_chat" preset to make sure those keep defaulting to "restricted"
+ default_rule = AccessRules.RESTRICTED
+
+ if not config.get("initial_state"):
+ config["initial_state"] = []
+
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
+ }
+ )
+
+ access_rule = default_rule
+
+ # Check that the preset in use is compatible with the access rule, whether it's
+ # user-defined or the default.
+ #
+ # Direct rooms may not have their join_rules set to JoinRules.PUBLIC.
+ if (
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule == AccessRules.DIRECT:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed(
+ config.get("power_level_content_override", {}), access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content override")
+
+ use_default_power_levels = True
+ if config.get("power_level_content_override"):
+ use_default_power_levels = False
+
+ # Second loop for events we need to know the current rule to process.
+ for event in config.get("initial_state", []):
+ if event["type"] == EventTypes.PowerLevels:
+ allowed = self._is_power_level_content_allowed(
+ event["content"], access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content")
+
+ use_default_power_levels = False
+
+ # If power levels were not overridden by the user, override with DINUM's preferred
+ # defaults instead
+ if use_default_power_levels:
+ config["power_level_content_override"] = self._get_default_power_levels(
+ requester.user.to_string()
+ )
+
+ return True
+
+ # If power levels are not overridden by the user during room creation, the following
+ # rules are used instead. Changes from Synapse's default power levels are noted.
+ #
+ # The same power levels are currently applied regardless of room preset.
+ @staticmethod
+ def _get_default_power_levels(user_id: str) -> Dict:
+ return {
+ "users": {user_id: 100},
+ "users_default": 0,
+ "events": {
+ EventTypes.Name: 50,
+ EventTypes.PowerLevels: 100,
+ EventTypes.RoomHistoryVisibility: 100,
+ EventTypes.CanonicalAlias: 50,
+ EventTypes.RoomAvatar: 50,
+ EventTypes.Tombstone: 100,
+ EventTypes.ServerACL: 100,
+ EventTypes.RoomEncryption: 100,
+ },
+ "events_default": 0,
+ "state_default": 100, # Admins should be the only ones to perform other tasks
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 50, # All rooms should require mod to invite, even private
+ }
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(
+ self, medium: str, address: str, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited.
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+
+ Args:
+ medium: The medium of the threepid.
+ address: The address of the threepid.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room the threepid is being invited to.
+
+ Returns:
+ Whether the threepid invite is allowed.
+ """
+ rule = self._get_rule_from_state(state_events)
+
+ if medium != "email":
+ return False
+
+ if rule != AccessRules.RESTRICTED:
+ # Only "restricted" requires filtering 3PID invites. We don't need to do
+ # anything for "direct" here, because only "restricted" requires filtering
+ # based on the HS the address is mapped to.
+ return True
+
+ parsed_address = email.utils.parseaddr(address)[1]
+ if parsed_address != address:
+ # Avoid reproducing the security issue described here:
+ # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2
+ # It's probably not worth it but let's just be overly safe here.
+ return False
+
+ # Get the HS this address belongs to from the identity server.
+ res = yield self.module_api.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
+ {"medium": medium, "address": address},
+ )
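+        # `res` is expected to map the address to its homeserver, e.g.
+        # {"hs": "example.com"} (domain illustrative).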
+
+ # Look for a domain that's not forbidden from being invited.
+ if not res.get("hs"):
+ return False
+ if res.get("hs") in self.domains_forbidden_when_restricted:
+ return False
+
+ return True
+
+ async def check_event_allowed(
+ self, event: EventBase, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed.
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+
+ Args:
+ event: The event to check.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room the event originated from.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ if event.type == ACCESS_RULES_TYPE:
+ return await self._on_rules_change(event, state_events)
+
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ if event.type == EventTypes.PowerLevels:
+ return self._is_power_level_content_allowed(
+ event.content, rule, on_room_creation=False
+ )
+
+ if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
+ return await self._on_membership_or_invite(event, rule, state_events)
+
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
+ if event.type == EventTypes.RoomAvatar:
+ return self._on_room_avatar_change(event, rule)
+
+ if event.type == EventTypes.Name:
+ return self._on_room_name_change(event, rule)
+
+ if event.type == EventTypes.Topic:
+ return self._on_room_topic_change(event, rule)
+
+ return True
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, state_events: StateMap[EventBase], new_visibility: str
+ ) -> bool:
+ """Implements
+ synapse.events.ThirdPartyEventRules.check_visibility_can_be_modified
+
+        Determines whether a room can be published to, or removed from, the public
+        room list. A room is published if its visibility is set to "public". Otherwise,
+        its visibility is "private". A room with an access rule other than "restricted"
+ may not be published.
+
+ Args:
+ room_id: The ID of the room.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ Whether the room is allowed to be published to, or removed from, the public
+ rooms directory.
+ """
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ # Allow adding a room to the public rooms list only if it is restricted
+ if new_visibility == "public":
+ return rule == AccessRules.RESTRICTED
+
+ # By default a room is created as "restricted", meaning it is allowed to be
+ # published to the public rooms directory.
+ return True
+
+ async def _on_rules_change(
+ self, event: EventBase, state_events: StateMap[EventBase]
+    ) -> bool:
+ """Checks whether an im.vector.room.access_rules event is forbidden or allowed.
+
+ Args:
+ event: The im.vector.room.access_rules event.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room before the event was sent.
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ new_rule = event.content.get("rule")
+
+ # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES:
+ return False
+
+ # Make sure we don't apply "direct" if the room has more than two members.
+ if new_rule == AccessRules.DIRECT:
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ if len(existing_members) > 2 or len(threepid_tokens) > 1:
+ return False
+
+ if new_rule != AccessRules.RESTRICTED:
+ # Block this change if this room is currently listed in the public rooms
+ # directory
+ if await self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ event.room_id
+ ):
+ return False
+
+ prev_rules_event = state_events.get((ACCESS_RULES_TYPE, ""))
+
+ # Now that we know the new rule doesn't break the "direct" case, we can allow any
+ # new rule in rooms that had none before.
+ if prev_rules_event is None:
+ return True
+
+ prev_rule = prev_rules_event.content.get("rule")
+
+ # Currently, we can only go from "restricted" to "unrestricted".
+ return (
+ prev_rule == AccessRules.RESTRICTED and new_rule == AccessRules.UNRESTRICTED
+ )
+
+ async def _on_membership_or_invite(
+ self, event: EventBase, rule: str, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Applies the correct rule for incoming m.room.member and
+ m.room.third_party_invite events.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+ state_events: A dict mapping (event type, state key) to state event.
+ The state of the room before the event was sent.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ if rule == AccessRules.RESTRICTED:
+ ret = self._on_membership_or_invite_restricted(event)
+ elif rule == AccessRules.UNRESTRICTED:
+ ret = self._on_membership_or_invite_unrestricted(event, state_events)
+ elif rule == AccessRules.DIRECT:
+ ret = self._on_membership_or_invite_direct(event, state_events)
+ else:
+ # We currently apply the default (restricted) if we don't know the rule, we
+ # might want to change that in the future.
+ ret = self._on_membership_or_invite_restricted(event)
+
+ if event.type == "m.room.member":
+ # If this is an admin leaving, and they are the last admin in the room,
+ # raise the power levels of the room so that the room is 'frozen'.
+ #
+ # We have to freeze the room by puppeting an admin user, which we can
+ # only do for local users
+ if (
+ self._is_local_user(event.sender)
+ and event.membership == Membership.LEAVE
+ ):
+ await self._freeze_room_if_last_admin_is_leaving(event, state_events)
+
+ return ret
+
+ async def _freeze_room_if_last_admin_is_leaving(
+ self, event: EventBase, state_events: StateMap[EventBase]
+ ):
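+        """Freezes a room by raising the power level required to send events to
+        100, if the sender of the given event is the last admin leaving the room.
+
+        Args:
+            event: The m.room.member (leave) event being checked.
+            state_events: A dict mapping (event type, state key) to state event.
+                The state of the room before the event was sent.
+        """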
+ power_level_state_event = state_events.get(
+ (EventTypes.PowerLevels, "")
+        )  # type: Optional[EventBase]
+ if not power_level_state_event:
+ return
+ power_level_content = power_level_state_event.content
+
+ # Do some validation checks on the power level state event
+ if (
+ not isinstance(power_level_content, dict)
+ or "users" not in power_level_content
+ or not isinstance(power_level_content["users"], dict)
+ ):
+ # We can't use this power level event to determine whether the room should be
+ # frozen. Bail out.
+ return
+
+ user_id = event.get("sender")
+ if not user_id:
+ return
+
+ # Get every admin user defined in the room's state
+ admin_users = {
+ user
+ for user, power_level in power_level_content["users"].items()
+ if power_level >= 100
+ }
+
+ if user_id not in admin_users:
+ # This user is not an admin, ignore them
+ return
+
+        if any(
+            event_type == EventTypes.Member
+            and state_event.membership in [Membership.JOIN, Membership.INVITE]
+            and state_key in admin_users
+            and state_key != user_id
+            for (event_type, state_key), state_event in state_events.items()
+        ):
+ # There's another admin user in, or invited to, the room
+ return
+
+ # Freeze the room by raising the required power level to send events to 100
+ logger.info("Freezing room '%s'", event.room_id)
+
+ # Modify the existing power levels to raise all required types to 100
+ #
+ # This changes a power level state event's content from something like:
+ # {
+ # "redact": 50,
+ # "state_default": 50,
+ # "ban": 50,
+ # "notifications": {
+ # "room": 50
+ # },
+ # "events": {
+ # "m.room.avatar": 50,
+ # "m.room.encryption": 50,
+ # "m.room.canonical_alias": 50,
+ # "m.room.name": 50,
+ # "im.vector.modular.widgets": 50,
+ # "m.room.topic": 50,
+ # "m.room.tombstone": 50,
+ # "m.room.history_visibility": 100,
+ # "m.room.power_levels": 100
+ # },
+ # "users_default": 0,
+ # "events_default": 0,
+ # "users": {
+ # "@admin:example.com": 100,
+ # },
+ # "kick": 50,
+ # "invite": 0
+ # }
+ #
+ # to
+ #
+ # {
+ # "redact": 100,
+ # "state_default": 100,
+ # "ban": 100,
+ # "notifications": {
+ # "room": 50
+ # },
+ # "events": {}
+ # "users_default": 0,
+ # "events_default": 100,
+ # "users": {
+ # "@admin:example.com": 100,
+ # },
+ # "kick": 100,
+ # "invite": 100
+ # }
+ new_content = {}
+ for key, value in power_level_content.items():
+ # Do not change "users_default", as that key specifies the default power
+ # level of new users
+ if isinstance(value, int) and key != "users_default":
+ value = 100
+ new_content[key] = value
+
+ # Set some values in case they are missing from the original
+ # power levels event content
+ new_content.update(
+ {
+ # Clear out any special-cased event keys
+ "events": {},
+ # Ensure state_default and events_default keys exist and are 100.
+ # Otherwise a lower PL user could potentially send state events that
+ # aren't explicitly mentioned elsewhere in the power level dict
+ "state_default": 100,
+ "events_default": 100,
+ # Membership events default to 50 if they aren't present. Set them
+ # to 100 here, as they would be set to 100 if they were present anyways
+ "ban": 100,
+ "kick": 100,
+ "invite": 100,
+ "redact": 100,
+ }
+ )
+
+ await self.module_api.create_and_send_event_into_room(
+ {
+ "room_id": event.room_id,
+ "sender": user_id,
+ "type": EventTypes.PowerLevels,
+ "content": new_content,
+ "state_key": "",
+ }
+ )
+
+ def _on_membership_or_invite_restricted(self, event: EventBase) -> bool:
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ "restricted" currently means that users can only invite users if their server is
+ included in a limited list of domains.
+
+ Args:
+ event: The event to check.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+        # We're not applying the rules on m.room.third_party_invite events here because
+ # the filtering on threepids is done in check_threepid_can_be_invited, which is
+ # called before check_event_allowed.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return True
+
+ # We only need to process "join" and "invite" memberships, in order to be backward
+ # compatible, e.g. if a user from a blacklisted server joined a restricted room
+ # before the rules started being enforced on the server, that user must be able to
+ # leave it.
+ if event.membership not in [Membership.JOIN, Membership.INVITE]:
+ return True
+
+ invitee_domain = get_domain_from_id(event.state_key)
+ return invitee_domain not in self.domains_forbidden_when_restricted
+
+ def _on_membership_or_invite_unrestricted(
+ self, event: EventBase, state_events: StateMap[EventBase]
+ ) -> bool:
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ "unrestricted" currently means that forbidden users cannot join without an invite.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ # If this is a join from a forbidden user and they don't have an invite to the
+ # room, then deny it
+ if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+ # Check if this user is from a forbidden server
+ target_domain = get_domain_from_id(event.state_key)
+ if target_domain in self.domains_forbidden_when_restricted:
+ # If so, they'll need an invite to join this room. Check if one exists
+ if not self._user_is_invited_to_room(event.state_key, state_events):
+ return False
+
+ return True
+
+ def _on_membership_or_invite_direct(
+ self, event: EventBase, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ "direct" currently means that no member is allowed apart from the two initial
+ members the room was created for (i.e. the room's creator and their first invitee).
+
+ Args:
+ event: The event to check.
+ state_events: A dict mapping (event type, state key) to state event.
+ The state of the room before the event was sent.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ # Get the room memberships and 3PID invite tokens from the room's state.
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+        # There should never be more than one 3PID invite in the room state: if the
+        # second original user left and is being re-invited by email address, then,
+        # since we know they have a Matrix account bound to the address (they joined
+        # with it the first time), Synapse will successfully look it up before
+        # attempting to store an invite on the IS.
+ if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite:
+ # If we already have a 3PID invite in flight, don't accept another one, unless
+ # the new one has the same invite token as its state key. This is because 3PID
+ # invite revocations must be allowed, and a revocation is basically a new 3PID
+ # invite event with an empty content and the same token as the invite it
+ # revokes.
+ return event.state_key in threepid_tokens
+
+ if len(existing_members) == 2:
+            # If the user was one of the two initial users of the room, Synapse would
+            # have looked them up successfully and thus sent an m.room.member event here
+            # instead of an m.room.third_party_invite.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return False
+
+ # We can only have m.room.member events here. The rule in this case is to only
+ # allow the event if its target is one of the initial two members in the room,
+ # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members
+
+ # We're alone in the room (and always have been) and there's one 3PID invite in
+ # flight.
+ if len(existing_members) == 1 and len(threepid_tokens) == 1:
+            # We can only have m.room.member events here. In this case, we can only allow
+            # the event if it's either an m.room.member from the joined user (we can
+            # assume that the only m.room.member event is a join, otherwise we wouldn't
+            # be able to send an event to the room) or an invite event whose target is
+            # the invited user.
+ target = event.state_key
+ is_from_threepid_invite = self._is_invite_from_threepid(
+ event, threepid_tokens[0]
+ )
+ return is_from_threepid_invite or target == existing_members[0]
+
+ return True
+
+ def _is_power_level_content_allowed(
+ self, content: Dict, access_rule: str, on_room_creation: bool = True
+ ) -> bool:
+ """Check if a given power levels event is permitted under the given access rule.
+
+ It shouldn't be allowed if it either changes the default PL to a non-0 value or
+ gives a non-0 PL to a user that would have been forbidden from joining the room
+ under a more restrictive access rule.
+
+ Args:
+ content: The content of the m.room.power_levels event to check.
+ access_rule: The access rule in place in this room.
+ on_room_creation: True if this call is happening during a room's
+ creation, False otherwise.
+
+ Returns:
+ Whether the content of the power levels event is valid.
+ """
+ # Only enforce these rules during room creation
+ #
+ # We want to allow admins to modify or fix the power levels in a room if they
+ # have a special circumstance, but still want to encourage a certain pattern during
+ # room creation.
+ if on_room_creation:
+            # Deny if the power level required to invite is below 50
+ if content.get("invite", 50) < 50:
+ return False
+
+ # If "other" state requirements are <PL100
+ if content.get("state_default", 100) < 100:
+ return False
+
+ # Check if we need to apply the restrictions with the current rule.
+ if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
+ return True
+
+ # If users_default is explicitly set to a non-0 value, deny the event.
+ users_default = content.get("users_default", 0)
+ if users_default:
+ return False
+
+ users = content.get("users", {})
+ for user_id, power_level in users.items():
+ server_name = get_domain_from_id(user_id)
+ # Check the domain against the blacklist. If found, and the PL isn't 0, deny
+ # the event.
+ if (
+ server_name in self.domains_forbidden_when_restricted
+ and power_level != 0
+ ):
+ return False
+
+ return True
+
+ def _on_join_rule_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a join rule change is allowed.
+
+ A join rule change is always allowed unless the new join rule is "public" and
+ the current access rule is "direct".
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ Whether the change is allowed.
+ """
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
+ return rule != AccessRules.DIRECT
+
+ return True
+
+ def _on_room_avatar_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room avatar is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ def _on_room_name_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room name is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ def _on_room_topic_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room topic is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ @staticmethod
+ def _get_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]:
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events: A dict mapping (event type, state key) to state event.
+
+ Returns:
+            The name of the rule (either "direct", "restricted" or "unrestricted").
+            Defaults to "restricted" if the room has no access rules event; may be
+            None if the event exists but doesn't specify a rule.
+ """
+ access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
+ if access_rules is None:
+ return AccessRules.RESTRICTED
+
+ return access_rules.content.get("rule")
+
+ @staticmethod
+ def _get_join_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]:
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+            state_events: A dict mapping (event type, state key) to state event.
+
+ Returns:
+ The name of the join rule (either "public", or "invite") if found, else None.
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return None
+
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
+ def _get_members_and_tokens_from_state(
+ state_events: StateMap[EventBase],
+ ) -> Tuple[List[str], List[str]]:
+ """Retrieves the list of users that have a m.room.member event in the room,
+ as well as 3PID invites tokens in the room.
+
+ Args:
+ state_events: A dict mapping (event type, state key) to state event.
+
+ Returns:
+ A tuple containing the:
+ * targets of the m.room.member events in the state.
+ * 3PID invite tokens in the state.
+ """
+ existing_members = []
+ threepid_invite_tokens = []
+ for key, state_event in state_events.items():
+ if key[0] == EventTypes.Member and state_event.content:
+ existing_members.append(state_event.state_key)
+ if key[0] == EventTypes.ThirdPartyInvite and state_event.content:
+ # Don't include revoked invites.
+ threepid_invite_tokens.append(state_event.state_key)
+
+ return existing_members, threepid_invite_tokens
+
+ @staticmethod
+ def _is_invite_from_threepid(invite: EventBase, threepid_invite_token: str) -> bool:
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite: The m.room.member event with "invite" membership.
+ threepid_invite_token: The state key from the 3PID invite.
+
+ Returns:
+ Whether the invite is due to the given 3PID invite.
+ """
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
+
+ return token == threepid_invite_token
+
+ def _is_local_user(self, user_id: str) -> bool:
+ """Checks whether a given user ID belongs to this homeserver, or a remote
+
+ Args:
+ user_id: A user ID to check.
+
+ Returns:
+ True if the user belongs to this homeserver, False otherwise.
+ """
+ user = UserID.from_string(user_id)
+
+ # Extract the localpart and ask the module API for a user ID from the localpart
+ # The module API will append the local homeserver's server_name
+ local_user_id = self.module_api.get_qualified_user_id(user.localpart)
+
+ # If the user ID we get based on the localpart is the same as the original user ID,
+ # then they were a local user
+ return user_id == local_user_id
+
+ def _user_is_invited_to_room(
+ self, user_id: str, state_events: StateMap[EventBase]
+ ) -> bool:
+ """Checks whether a given user has been invited to a room
+
+        A user has an invite for a room if its state contains an `m.room.member`
+ event with membership "invite" and their user ID as the state key.
+
+ Args:
+ user_id: The user to check.
+ state_events: The state events from the room.
+
+ Returns:
+ True if the user has been invited to the room, or False if they haven't.
+ """
+ for (event_type, state_key), state_event in state_events.items():
+ if (
+ event_type == EventTypes.Member
+ and state_key == user_id
+ and state_event.membership == Membership.INVITE
+ ):
+ return True
+
+ return False
diff --git a/synapse/types.py b/synapse/types.py
index 9e580f4295..e867cadbad 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -22,6 +22,7 @@ from typing import Any, Dict, Tuple, Type, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
+from six.moves import filter
from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
@@ -51,7 +52,15 @@ JsonDict = Dict[str, Any]
class Requester(
namedtuple(
- "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
+ "Requester",
+ [
+ "user",
+ "access_token_id",
+ "is_guest",
+ "shadow_banned",
+ "device_id",
+ "app_service",
+ ],
)
):
"""
@@ -62,6 +71,7 @@ class Requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
+ shadow_banned (bool): True if the user making this request has been shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
"""
@@ -77,6 +87,7 @@ class Requester(
"user_id": self.user.to_string(),
"access_token_id": self.access_token_id,
"is_guest": self.is_guest,
+ "shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
}
@@ -101,13 +112,19 @@ class Requester(
user=UserID.from_string(input["user_id"]),
access_token_id=input["access_token_id"],
is_guest=input["is_guest"],
+ shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
)
def create_requester(
- user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
+ user_id,
+ access_token_id=None,
+ is_guest=False,
+ shadow_banned=False,
+ device_id=None,
+ app_service=None,
):
"""
Create a new ``Requester`` object
@@ -117,6 +134,7 @@ def create_requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
+ shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
@@ -125,7 +143,9 @@ def create_requester(
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
- return Requester(user_id, access_token_id, is_guest, device_id, app_service)
+ return Requester(
+ user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+ )
def get_domain_from_id(string):
@@ -276,6 +296,19 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+def strip_invalid_mxid_characters(localpart):
+ """Removes any invalid characters from an mxid
+
+ Args:
+        localpart (str): the localpart to be stripped
+
+    Returns:
+        str: the localpart with all invalid characters removed
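+
+    Example (illustrative; note that disallowed characters, including
+    uppercase letters, are removed rather than mapped):
+
+        >>> strip_invalid_mxid_characters("Alice_01!")
+        'lice_01'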
+ """
+ filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
+ return "".join(filtered)
+
+
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index b3f76428b6..b2a22dbd5c 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -25,8 +25,18 @@ from synapse.logging import context
logger = logging.getLogger(__name__)
-# Create a custom encoder to reduce the whitespace produced by JSON encoding.
-json_encoder = json.JSONEncoder(separators=(",", ":"))
+
+def _reject_invalid_json(val):
+ """Do not allow Infinity, -Infinity, or NaN values in JSON."""
+    raise ValueError("Invalid JSON value: '%s'" % val)
+
+
+# Create a custom encoder to reduce the whitespace produced by JSON encoding and
+# ensure that valid JSON is produced.
+json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
+
+# Create a custom decoder to reject Python extensions to JSON.
+json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
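+
+# Illustrative behaviour of the two objects above:
+#   json_encoder.encode({"a": 1})     -> '{"a":1}' (no extra whitespace)
+#   json_decoder.decode('{"a": NaN}') -> raises ValueError, since NaN is a
+#                                        Python extension rather than valid JSON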
def unwrapFirstError(failure):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index c2d72a82cf..49d9fddcf0 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -285,16 +285,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
- def __init__(
- self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
- ):
+ def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
self.orig = orig
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
-
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
@@ -364,7 +357,7 @@ class CacheDescriptor(_CacheDescriptorBase):
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
- @cachedInlineCallbacks(cache_context=True)
+ @cached(cache_context=True)
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
@@ -382,17 +375,11 @@ class CacheDescriptor(_CacheDescriptorBase):
max_entries=1000,
num_args=None,
tree=False,
- inlineCallbacks=False,
cache_context=False,
iterable=False,
):
- super(CacheDescriptor, self).__init__(
- orig,
- num_args=num_args,
- inlineCallbacks=inlineCallbacks,
- cache_context=cache_context,
- )
+ super().__init__(orig, num_args=num_args, cache_context=cache_context)
self.max_entries = max_entries
self.tree = tree
@@ -465,9 +452,7 @@ class CacheDescriptor(_CacheDescriptorBase):
observer = defer.succeed(cached_result_d)
except KeyError:
- ret = defer.maybeDeferred(
- preserve_fn(self.function_to_call), obj, *args, **kwargs
- )
+ ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
def onErr(f):
cache.invalidate(cache_key)
@@ -510,9 +495,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
of results.
"""
- def __init__(
- self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
- ):
+ def __init__(self, orig, cached_method_name, list_name, num_args=None):
"""
Args:
orig (function)
@@ -521,12 +504,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
num_args (int): number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
- inlineCallbacks (bool): Whether orig is a generator that should
- be wrapped by defer.inlineCallbacks
"""
- super(CacheListDescriptor, self).__init__(
- orig, num_args=num_args, inlineCallbacks=inlineCallbacks
- )
+ super().__init__(orig, num_args=num_args)
self.list_name = list_name
@@ -631,7 +610,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
cached_defers.append(
defer.maybeDeferred(
- preserve_fn(self.function_to_call), **args_to_call
+ preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback)
)
@@ -695,21 +674,7 @@ def cached(
)
-def cachedInlineCallbacks(
- max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
-):
- return lambda orig: CacheDescriptor(
- orig,
- max_entries=max_entries,
- num_args=num_args,
- tree=tree,
- inlineCallbacks=True,
- cache_context=cache_context,
- iterable=iterable,
- )
-
-
-def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -725,8 +690,6 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
- inlineCallbacks (bool): Should the function be wrapped in an
- `defer.inlineCallbacks`?
Example:
@@ -744,5 +707,4 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
- inlineCallbacks=inlineCallbacks,
)
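With `inlineCallbacks` support removed, methods wrapped by these descriptors are expected to be coroutine functions (or to return Deferreds) themselves; `defer.maybeDeferred(preserve_fn(self.orig), ...)` takes care of the rest. A hedged sketch of the resulting usage, where `db_fetch_user` is an illustrative stand-in rather than a real Synapse helper:

```python
from synapse.util.caches.descriptors import cached, cachedList

async def db_fetch_user(user_id):
    ...  # stand-in for a real database query

class ExampleStore:
    @cached(num_args=1)
    async def get_user(self, user_id):
        # Previously a generator decorated with @cachedInlineCallbacks;
        # now a plain coroutine function, invoked via preserve_fn(self.orig).
        return await db_fetch_user(user_id)

    @cachedList(cached_method_name="get_user", list_name="user_ids")
    async def get_users(self, user_ids):
        # Note: no inlineCallbacks=True kwarg any more.
        return {user_id: await db_fetch_user(user_id) for user_id in user_ids}
```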
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 2e2b40a426..61d96a6c28 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -24,9 +24,7 @@ from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-# Note: The : character is allowed here for older clients, but will be removed in a
-# future release. Context: https://github.com/matrix-org/synapse/issues/6766
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")
+client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
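For illustration, the practical effect of dropping `:` from the character class:

```python
import re

client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")

assert client_secret_regex.match("still_a-valid.secret=123")
assert client_secret_regex.match("legacy:client:secret") is None  # ':' is now rejected
```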
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 43c2e0ac23..cfdaa1c5d9 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ Deferred[bool]: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,11 +73,11 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
def canonicalise_email(address: str) -> str:
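Since `check_3pid_allowed` now returns a Deferred rather than a plain bool, call sites have to yield (or await) it; testing the return value directly would always be truthy. A hedged sketch of a caller, with an illustrative error path:

```python
from twisted.internet import defer

@defer.inlineCallbacks
def _assert_3pid_allowed(hs, medium, address):
    # Yield the Deferred to obtain the underlying bool.
    allowed = yield check_3pid_allowed(hs, medium, address)
    if not allowed:
        # Illustrative only; real callers raise a SynapseError here.
        raise Exception("3PID %s/%s not allowed on this homeserver" % (medium, address))
```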
diff --git a/sytest-blacklist b/sytest-blacklist
index 79b2d4402a..7ceeaca8d6 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -36,3 +36,29 @@ Inbound federation of state requires event_id as a mandatory paramater
# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands
Can upload self-signing keys
+
+# Blacklisted due to https://github.com/matrix-org/synapse-dinsic/issues/43
+Inviting an AS-hosted user asks the AS server
+Accesing an AS-hosted room alias asks the AS server
+Events in rooms with AS-hosted room aliases are sent to AS server
+
+# flaky test
+If remote user leaves room we no longer receive device updates
+
+# flaky test
+Can re-join room if re-invited
+
+# flaky test
+Forgotten room messages cannot be paginated
+
+# flaky test
+Local device key changes get to remote servers
+
+# flaky test
+Old leaves are present in gapped incremental syncs
+
+# flaky test on workers
+Old members are included in gappy incr LL sync if they start speaking
+
+# flaky test on workers
+Presence changes to UNAVAILABLE are reported to remote room members
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
new file mode 100644
index 0000000000..42ee5f56d9
--- /dev/null
+++ b/tests/config/test_base.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+import tempfile
+
+from synapse.config import ConfigError
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+
+class BaseConfigTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+
+ def test_loading_missing_templates(self):
+ # Use a temporary directory that exists on the system, but that isn't likely to
+ # contain template files
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Attempt to load an HTML template from our custom template directory
+ template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
+
+ # If no errors, we should've gotten the default template instead
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"error_description": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_custom_templates(self):
+ # Use a temporary directory that exists on the system
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Create a temporary bogus template file
+ with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
+ # Get temporary file's filename
+ template_filename = os.path.basename(tmp_template.name)
+
+ # Write a custom HTML template
+ contents = b"{{ test_variable }}"
+ tmp_template.write(contents)
+ tmp_template.flush()
+
+ # Attempt to load the template from our custom template directory
+ template = (
+ self.hs.config.read_templates([template_filename], tmp_dir)
+ )[0]
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"test_variable": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_template_from_nonexistent_custom_directory(self):
+ with self.assertRaises(ConfigError):
+ self.hs.config.read_templates(
+ ["some_filename.html"], "a_nonexistent_directory"
+ )
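A hedged sketch of the fallback behaviour these tests exercise: `read_templates` looks in the custom directory first, falls back to Synapse's bundled defaults when the file is absent, and raises `ConfigError` only when the custom directory itself does not exist (the path below is illustrative):

```python
# Falls back to the bundled default when the custom dir has no such file:
template = hs.config.read_templates(["sso_error.html"], "/etc/synapse/templates")[0]
html = template.render({"error_description": "Something went wrong"})
```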
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index b8ca118716..9bd515080c 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -79,9 +79,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -110,9 +112,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -148,9 +152,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable(None)
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
# Artificially raise the complexity
@@ -204,9 +210,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -234,9 +242,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
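The switch from `return_value` to `side_effect` matters when `make_awaitable` is implemented as a coroutine function: a coroutine can only be awaited once, so a single shared awaitable breaks as soon as the mock is called twice. A minimal sketch, assuming that coroutine-based helper:

```python
from mock import Mock

async def make_awaitable(result):
    # Coroutine-based helper: each call produces a one-shot awaitable.
    return result

# Sharing one coroutine via return_value fails on the second await with
# "RuntimeError: cannot reuse already awaited coroutine":
broken = Mock(return_value=make_awaitable({"v1": 9999}))

# A side_effect lambda builds a fresh awaitable per call, so the mock can be
# awaited as many times as the code under test needs:
fixed = Mock(side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}))
```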
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..0ab0356109
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.address = "test@test"
+ self.is_server_name = "testis"
+ self.is_server_url = "https://testis"
+ self.rewritten_is_url = "https://int.testis"
+
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
+ config["rewrite_identity_server_urls"] = {
+ self.is_server_url: self.rewritten_is_url
+ }
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed({})
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_blacklisting_http_client.get_json.return_value = defer.succeed({})
+ mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ # TODO: This class does not use a singleton to get its http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+ mock_blacklisting_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+
+ def test_rewritten_id_server(self):
+ """
+ Tests that, when validating a 3PID association while rewriting the IS's server
+ name:
+ * the bind request is done against the rewritten hostname
+ * the original, non-rewritten, server name is stored in the database
+ """
+ handler = self.hs.get_handlers().identity_handler
+ post_json_get_json = handler.blacklisting_http_client.post_json_get_json
+ store = self.hs.get_datastore()
+
+ creds = {"sid": "123", "client_secret": "some_secret"}
+
+ # Make sure processing the mocked response goes through.
+ data = self.get_success(
+ handler.bind_threepid(
+ client_secret=creds["client_secret"],
+ sid=creds["sid"],
+ mxid=self.user_id,
+ id_server=self.is_server_name,
+ use_v2=False,
+ )
+ )
+ self.assertEqual(data.get("address"), self.address)
+
+ # Check that the request was done against the rewritten server name.
+ post_json_get_json.assert_called_once_with(
+ "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,),
+ {
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ headers={},
+ )
+
+ # Check that the original server name is saved in the database instead of the
+ # rewritten one.
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
+ self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 05ea40a7de..306dcfe944 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,6 +19,7 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from tests import unittest
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d70e1fc608..d7f0c19c4c 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -64,14 +64,16 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler()
self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+ )
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@@ -112,10 +114,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+ )
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
# Setting displayname a second time is forbidden
@@ -157,8 +166,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield defer.ensureDeferred(self.store.create_profile("caroline"))
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname("caroline", "Caroline", 1)
+ )
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@@ -170,8 +181,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png", 1
+ )
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -211,8 +224,10 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png", 1
+ )
)
self.assertEquals(
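The `defer.ensureDeferred` wrappers are needed because these store methods are now coroutine functions: inside an `@defer.inlineCallbacks` test, a yielded coroutine is not awaited, so it must first be converted into a Deferred. A minimal sketch of the pattern, with a stand-in for the store method:

```python
from twisted.internet import defer

async def set_profile_displayname(localpart, displayname, batchnum):
    ...  # stand-in for the now-async store method

@defer.inlineCallbacks
def test_example():
    # yield set_profile_displayname("frank", "Frank", 1)  # coroutine, not awaited
    yield defer.ensureDeferred(set_profile_displayname("frank", "Frank", 1))
```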
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..6dfea58cff 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -20,8 +20,14 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha.register import (
+ _map_email_to_displayname,
+ register_servlets,
+)
from synapse.types import RoomAlias, UserID, create_requester
+from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -36,6 +42,10 @@ class RegistrationHandlers(object):
class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
+ servlets = [
+ register_servlets,
+ ]
+
def make_homeserver(self, reactor, clock):
hs_config = self.default_config()
@@ -475,6 +485,104 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_email_to_displayname_mapping(self):
+ """Test that custom emails are mapped to new user displaynames correctly"""
+ self._check_mapping(
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+ )
+
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+ # Multibyte unicode characters
+ self._check_mapping(
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
+ )
+
+ def _check_mapping(self, i, expected):
+ result = _map_email_to_displayname(i)
+ self.assertEqual(result, expected)
+
+ @override_config(
+ {
+ "bind_new_user_emails_to_sydent": "https://is.example.com",
+ "registrations_require_3pid": ["email"],
+ "account_threepid_delegates": {},
+ "email": {
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "http://localhost",
+ }
+ )
+ def test_user_email_bound_via_sydent_internal_api(self):
+ """Tests that emails are bound after registration if this option is set"""
+ # Register user with an email address
+ email = "alice@example.com"
+
+ # Mock Synapse's threepid validator
+ get_threepid_validation_session = Mock(
+ return_value=defer.succeed(
+ {"medium": "email", "address": email, "validated_at": 0}
+ )
+ )
+ self.store.get_threepid_validation_session = get_threepid_validation_session
+ delete_threepid_session = Mock(return_value=defer.succeed(None))
+ self.store.delete_threepid_session = delete_threepid_session
+
+ # Mock Synapse's http json post method to check for the internal bind call
+ post_json_get_json = Mock(return_value=defer.succeed(None))
+ self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
+
+ # Retrieve a UIA session ID
+ channel = self.uia_register(
+ 401, {"username": "alice", "password": "nobodywillguessthis"}
+ )
+ session_id = channel.json_body["session"]
+
+ # Register our email address using the fake validation session above
+ channel = self.uia_register(
+ 200,
+ {
+ "username": "alice",
+ "password": "nobodywillguessthis",
+ "auth": {
+ "session": session_id,
+ "type": "m.login.email.identity",
+ "threepid_creds": {"sid": "blabla", "client_secret": "blablabla"},
+ },
+ },
+ )
+ self.assertEqual(channel.json_body["user_id"], "@alice:test")
+
+ # Check that a bind attempt was made to our fake identity server
+ post_json_get_json.assert_called_with(
+ "https://is.example.com/_matrix/identity/internal/bind",
+ {"address": "alice@example.com", "medium": "email", "mxid": "@alice:test"},
+ )
+
+ # Check that we stored a mapping of this bind
+ bound_threepids = self.get_success(
+ self.store.user_get_bound_threepids("@alice:test")
+ )
+ self.assertListEqual(bound_threepids, [{"medium": "email", "address": email}])
+
+ def uia_register(self, expected_response: int, body: dict) -> FakeChannel:
+ """Make a register request."""
+ request, channel = self.make_request(
+ "POST", "register", body
+ ) # type: SynapseRequest, FakeChannel
+ self.render(request)
+
+ self.assertEqual(request.code, expected_response)
+ return channel
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 0e666492f6..4b627dac00 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -21,8 +21,14 @@ from tests import unittest
# The expected number of state events in a fresh public room.
EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+
# The expected number of state events in a fresh private room.
-EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
+#
+# Note: this is 2 higher than on mainline, as the dinsic branch sends
+# an "im.vector.room.access_rules" state event into new private rooms,
+# and an encryption state event, since all private rooms are encrypted
+# by default
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
class StatsRoomTests(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 64afd581bc..e01de158e5 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
- self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..46c3810e70 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client.v2_alpha import account, account_validity, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
@@ -549,3 +549,136 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ account.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+
+ # Set accounts to expire after a week
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ return config
+
+ def prepare(self, reactor, clock, hs):
+ super(UserInfoTestCase, self).prepare(reactor, clock, hs)
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def test_user_info(self):
+ """Test /users/info for local users from the Client-Server API"""
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request info about each user from user_three
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/client/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ access_token=user_three_token,
+ shorthand=False,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def test_user_info_federation(self):
+ """Test that /users/info can be called from the Federation API, and
+ and that we can query remote users from the Client-Server API
+ """
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request information about our local users from the perspective of a remote server
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/federation/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def setup_test_users(self):
+ """Create an admin user and three test users, each with a different state"""
+
+ # Create an admin user to expire other users with
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_token = self.login("admin", "adminpassword")
+
+ # Create three users
+ user_one = self.register_user("alice", "pass")
+ user_one_token = self.login("alice", "pass")
+ user_two = self.register_user("bob", "pass")
+ user_three = self.register_user("carl", "pass")
+ user_three_token = self.login("carl", "pass")
+
+ # Deactivate user_one
+ self.deactivate(user_one, user_one_token)
+
+ # Expire user_two
+ self.expire(user_two, admin_token)
+
+ # Do nothing to user_three
+
+ return user_one, user_two, user_three, user_three_token
+
+ def expire(self, user_id_to_expire, admin_tok):
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ request_data = {
+ "user_id": user_id_to_expire,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def deactivate(self, user_id, tok):
+ request_data = {
+ "auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
+ "erase": False,
+ }
+ request, channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
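For reference, the response shape the assertions above rely on: `/users/info` returns a JSON object keyed by user ID, one entry per queried user (values mirror the three test users):

```python
# POST /_matrix/client/unstable/users/info
# {"user_ids": ["@alice:test", "@bob:test", "@carl:test"]}
#
# 200 response body:
{
    "@alice:test": {"deactivated": True, "expired": False},  # deactivated
    "@bob:test": {"deactivated": False, "expired": True},    # expired
    "@carl:test": {"deactivated": False, "expired": False},  # untouched
}
```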
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 69945a8f98..db260d599e 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -101,7 +101,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=self.tls_factory,
+ tls_client_options_factory=FederationPolicyForHTTPS(config),
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 807cd65dd6..9c778a0e45 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -12,16 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+from synapse.events import EventBase
from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import create_requester
from tests.unittest import HomeserverTestCase
class ModuleApiTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
+ self.event_creation_handler = homeserver.get_event_creation_handler()
def test_can_register_user(self):
"""Tests that an external module can register a user"""
@@ -52,3 +64,137 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the displayname was assigned
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
+
+ def test_sending_events_into_room(self):
+ """Tests that a module can send events into a room"""
+ # Mock out create_and_send_nonmember_event to check whether events are being sent
+ self.event_creation_handler.create_and_send_nonmember_event = Mock(
+ spec=[],
+ side_effect=self.event_creation_handler.create_and_send_nonmember_event,
+ )
+
+ # Create a user and room to play with
+ user_id = self.register_user("summer", "monkey")
+ tok = self.login("summer", "monkey")
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ # Create and send a non-state event
+ content = {"body": "I am a puppet", "msgtype": "m.text"}
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.message",
+ "content": content,
+ "sender": user_id,
+ }
+ event = self.get_success(
+ self.module_api.create_and_send_event_into_room(event_dict)
+ ) # type: EventBase
+ self.assertEqual(event.sender, user_id)
+ self.assertEqual(event.type, "m.room.message")
+ self.assertEqual(event.room_id, room_id)
+ self.assertFalse(hasattr(event, "state_key"))
+ self.assertDictEqual(event.content, content)
+
+ # Check that the event was sent
+ self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+ create_requester(user_id), event_dict, ratelimit=False,
+ )
+
+ # Create and send a state event
+ content = {
+ "events_default": 0,
+ "users": {user_id: 100},
+ "state_default": 50,
+ "users_default": 0,
+ "events": {"test.event.type": 25},
+ }
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.power_levels",
+ "content": content,
+ "sender": user_id,
+ "state_key": "",
+ }
+ event = self.get_success(
+ self.module_api.create_and_send_event_into_room(event_dict)
+ ) # type: EventBase
+ self.assertEqual(event.sender, user_id)
+ self.assertEqual(event.type, "m.room.power_levels")
+ self.assertEqual(event.room_id, room_id)
+ self.assertEqual(event.state_key, "")
+ self.assertDictEqual(event.content, content)
+
+ # Check that the event was sent
+ self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+ create_requester(user_id),
+ {
+ "type": "m.room.power_levels",
+ "content": content,
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": "",
+ },
+ ratelimit=False,
+ )
+
+ # Check that we can't send membership events
+ content = {
+ "membership": "leave",
+ }
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "content": content,
+ "sender": user_id,
+ "state_key": user_id,
+ }
+ self.get_failure(
+ self.module_api.create_and_send_event_into_room(event_dict), Exception
+ )
+
+ def test_public_rooms(self):
+ """Tests that a room can be added and removed from the public rooms list,
+ as well as have its public rooms directory state queried.
+ """
+ # Create a user and room to play with
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ # The room should not currently be in the public rooms directory
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertFalse(is_in_public_rooms)
+
+ # Let's try adding it to the public rooms directory
+ self.get_success(
+ self.module_api.public_room_list_manager.add_room_to_public_room_list(
+ room_id
+ )
+ )
+
+ # And checking whether it's in there...
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertTrue(is_in_public_rooms)
+
+ # Let's remove it again
+ self.get_success(
+ self.module_api.public_room_list_manager.remove_room_from_public_room_list(
+ room_id
+ )
+ )
+
+ # Should be gone
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertFalse(is_in_public_rooms)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..2f56cacc7a 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -346,8 +346,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_mention(self):
"""
@@ -418,8 +418,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_atroom(self):
"""
@@ -497,5 +497,5 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..4224b0a92e 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
import json
+from mock import Mock
+
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,24 +39,111 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ request, channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ params = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_data = json.dumps(params)
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ request, channel = self.make_request(
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
+
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed((200, "{}"))
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ # TODO: This class does not use a singleton to get its http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.http_client = mock_http_client
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_enabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
+ # Replace the blacklisting SimpleHttpClient with our mock
+ self.hs.get_room_member_handler().simple_http_client = Mock(
+ spec=["get_json", "post_json_get_json"]
+ )
+ self.hs.get_room_member_handler().simple_http_client.get_json.return_value = defer.succeed(
+ (200, "{}")
+ )
+
params = {
"id_server": "testis",
"medium": "email",
@@ -58,7 +152,44 @@ class IdentityTestCase(unittest.HomeserverTestCase):
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ b"POST", request_url, request_data, access_token=self.tok
)
self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ get_json = self.hs.get_handlers().identity_handler.http_client.get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 0b191d13c6..ab91baeacc 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
"default_policy": {
@@ -205,6 +206,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
}
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..de7856fba9
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,1066 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import random
+import string
+from typing import Optional
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.rest import admin
+from synapse.rest.client.v1 import directory, login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULES_TYPE,
+ AccessRules,
+ RoomAccessRules,
+)
+from synapse.types import JsonDict, create_requester
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ # TODO: This class does not use a singleton to get its http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+ mock_http_client
+ )
+
+ self.third_party_event_rules = self.hs.get_third_party_event_rules()
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=AccessRules.UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right."""
+ room_id = self.create_room(rule=AccessRules.UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=AccessRules.DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400)
+
+ def test_create_room_default_power_level_rules(self):
+ """Tests that a room created with no power level overrides instead uses the dinum
+ defaults
+ """
+ room_id = self.create_room(direct=True, rule=AccessRules.DIRECT)
+ power_levels = self.helper.get_state(room_id, "m.room.power_levels", self.tok)
+
+ # Inviting another user should require PL50, even in private rooms
+ self.assertEqual(power_levels["invite"], 50)
+ # Sending arbitrary state events should require PL100
+ self.assertEqual(power_levels["state_default"], 100)
+
+ def test_create_room_fails_on_incorrect_power_level_rules(self):
+ """Tests that a room created with power levels lower than that required are rejected"""
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ modified_power_levels["invite"] = 0
+ modified_power_levels["state_default"] = 50
+
+ self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ initial_state=[
+ {"type": "m.room.power_levels", "content": modified_power_levels}
+ ],
+ expected_code=400,
+ )
+
+ def test_existing_room_can_change_power_levels(self):
+ """Tests that a room created with default power levels can have their power levels
+ dropped after room creation
+ """
+ # Creates a room with the default power levels
+ room_id = self.create_room(
+ direct=True, rule=AccessRules.DIRECT, expected_code=200,
+ )
+
+ # Attempt to drop invite and state_default power levels after the fact
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ room_power_levels["invite"] = 0
+ room_power_levels["state_default"] = 50
+ self.helper.send_state(
+ room_id, "m.room.power_levels", room_power_levels, self.tok
+ )
+
+ def test_public_room(self):
+ """Tests that it's only possible to have a room listed in the public room list
+ if the access rule is restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), AccessRules.RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), AccessRules.RESTRICTED
+ )
+
+ # List preset_room_id in the public room list
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # List init_state_room_id in the public room list
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, AccessRules.DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.DIRECT, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule of direct
+ # should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=AccessRules.DIRECT,
+ expected_code=400,
+ )
+
+ # Changing join rule to public in a direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+
+ Also tests that the room can be published to, and removed from, the public room
+ list.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can invite a user whose HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ # We are allowed to publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # We are allowed to remove the room from the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "private"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+
+ Tests that a user from a HS that's in the list of forbidden domains (to use
+ in restricted mode) can be invited.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # Disable the 3pid invite ratelimiter
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
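+ # (The original ratelimiter values are restored at the end of this test.)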
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+ # We can't send a second 3PID invite to a room that already has a pending 3PID
+ # invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0]
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ # We can invite users from forbidden and non-forbidden servers alike.
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a power level event that redefines the default PL, even if it
+ # doesn't set a non-default PL for a user that would be forbidden in restricted
+ # mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't send a power level event that leaves the default PL alone but sets a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ # We can change the rule from restricted to unrestricted.
+ # Note that this changes self.restricted_room to an unrestricted room
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't publish a room to the public room list and then change its rule to
+ # unrestricted
+
+ # Create a restricted room
+ test_room_id = self.create_room(rule=AccessRules.RESTRICTED)
+
+ # Publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Attempt to switch the room to "unrestricted"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Attempt to switch the room to "direct"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ mistake the revocation for a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
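+ # Revoke the invite by overwriting the state event with empty content.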
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
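+ # Since the previous invite was revoked, sending a new 3PID invite with a fresh
+ # token should succeed.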
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def test_check_event_allowed(self):
+ """Tests that RoomAccessRules.check_event_allowed behaves accordingly.
+
+ It tests that:
+ * forbidden users cannot join restricted rooms.
+ * forbidden users can only join unrestricted rooms if they have an invite.
+ """
+ event_creator = self.hs.get_event_creation_handler()
+
+ # Test that forbidden users cannot join restricted rooms
+ requester = create_requester(self.user_id)
+ allowed_requester = create_requester("@user:allowed_domain")
+ forbidden_requester = create_requester("@user:forbidden_domain")
+
+ # Create a join event for a forbidden user
+ forbidden_join_event, forbidden_join_event_context = self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Create a join event for an allowed user
+ allowed_join_event, allowed_join_event_context = self.get_success(
+ event_creator.create_event(
+ allowed_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": allowed_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": allowed_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Assert a join event from a forbidden user to a restricted room is rejected
+ can_join = self.get_success(
+ self.third_party_event_rules.check_event_allowed(
+ forbidden_join_event, forbidden_join_event_context
+ )
+ )
+ self.assertFalse(can_join)
+
+ # But a join event from a non-forbidden user to a restricted room is allowed
+ can_join = self.get_success(
+ self.third_party_event_rules.check_event_allowed(
+ allowed_join_event, allowed_join_event_context
+ )
+ )
+ self.assertTrue(can_join)
+
+ # Test that forbidden users can only join unrestricted rooms if they have an invite
+
+ # Recreate the forbidden join event for the unrestricted room instead
+ forbidden_join_event, forbidden_join_event_context = self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ # A forbidden user without an invite should not be able to join an unrestricted room
+ can_join = self.get_success(
+ self.third_party_event_rules.check_event_allowed(
+ forbidden_join_event, forbidden_join_event_context
+ )
+ )
+ self.assertFalse(can_join)
+
+ # However, if we then invite this user...
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=requester.user.to_string(),
+ targ=forbidden_requester.user.to_string(),
+ tok=self.tok,
+ )
+
+ # And create another join event, making sure that its context states it's coming
+ # in after the above invite was made...
+ forbidden_join_event, forbidden_join_event_context = self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Then the forbidden user should be able to join!
+ can_join = self.get_success(
+ self.third_party_event_rules.check_event_allowed(
+ forbidden_join_event, forbidden_join_event_context
+ )
+ )
+ self.assertTrue(can_join)
+
+ def test_freezing_a_room(self):
+ """Tests that the power levels in a room change to prevent new events from
+ non-admin users when the last admin of a room leaves.
+ """
+
+ def freeze_room_with_id_and_power_levels(
+ room_id: str, custom_power_levels_content: Optional[JsonDict] = None,
+ ):
+ # Invite a user to the room, they join with PL 0
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ if not custom_power_levels_content:
+ # Retrieve the room's current power levels event content
+ power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+ else:
+ power_levels = custom_power_levels_content
+
+ # Send the (possibly custom) power levels into the room
+ self.helper.send_state(
+ room_id=room_id,
+ event_type="m.room.power_levels",
+ body=power_levels,
+ tok=self.tok,
+ )
+
+ # Ensure that the invitee leaving the room does not change the power levels
+ self.helper.leave(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Retrieve the new power levels of the room
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+
+ # Ensure they have not changed
+ self.assertDictEqual(power_levels, new_power_levels)
+
+ # Invite the user back again
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Now the admin leaves the room
+ self.helper.leave(
+ room=room_id, user=self.user_id, tok=self.tok,
+ )
+
+ # Check the power levels again
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok,
+ )
+
+ # Ensure that the new power levels prevent anyone but admins from sending
+ # certain events
+ self.assertEquals(new_power_levels["state_default"], 100)
+ self.assertEquals(new_power_levels["events_default"], 100)
+ self.assertEquals(new_power_levels["kick"], 100)
+ self.assertEquals(new_power_levels["invite"], 100)
+ self.assertEquals(new_power_levels["ban"], 100)
+ self.assertEquals(new_power_levels["redact"], 100)
+ self.assertDictEqual(new_power_levels["events"], {})
+ self.assertDictEqual(new_power_levels["users"], {self.user_id: 100})
+
+ # Ensure new users entering the room aren't going to immediately become admins
+ self.assertEquals(new_power_levels["users_default"], 0)
+
+ # Test that freezing a room with the default power level state event content works
+ room1 = self.create_room()
+ freeze_room_with_id_and_power_levels(room1)
+
+ # Test that freezing a room with a power level state event that is missing
+ # `state_default` and `events_default` keys behaves as expected
+ room2 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room2,
+ {
+ "ban": 50,
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "invite": 0,
+ "kick": 50,
+ "redact": 50,
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `events_default` keys
+ },
+ )
+
+ # Test that freezing a room with a power level state event that is *additionally*
+ # missing `ban`, `invite`, `kick` and `redact` keys behaves as expected
+ room3 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room3,
+ {
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `events_default` keys
+ # Explicitly remove `ban`, `invite`, `kick` and `redact` keys
+ },
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ expected_code=200,
+ ):
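+ """Creates a room with the given preset and access rule, asserting that the
+ request returns the expected status code, and returns the room ID on success.
+ """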
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
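+ """Returns the access rule currently set in the given room's state."""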
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
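+ """Attempts to change the room's access rule, asserting the expected status code."""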
+ data = {"rule": new_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
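+ """Attempts to change the room's join rule, asserting the expected status code."""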
+ data = {"join_rule": new_join_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
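+ """Sends a 3PID invite to the given address, asserting the expected status code."""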
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
new file mode 100644
index 0000000000..d03e121664
--- /dev/null
+++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+from typing import Dict
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester, StateMap
+
+from tests import unittest
+
+thread_local = threading.local()
+
+
+class ThirdPartyRulesTestModule:
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ # keep a record of the "current" rules module, so that the test can patch
+ # it if desired.
+ thread_local.rules_module = self
+ self.module_api = module_api
+
+ async def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return True
+
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+
+def current_rules_module() -> ThirdPartyRulesTestModule:
+ return thread_local.rules_module
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["third_party_event_rules"] = {
+ "module": __name__ + ".ThirdPartyRulesTestModule",
+ "config": {},
+ }
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ # Create a user and room to play with during the tests
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_third_party_rules(self):
+ """Tests that a forbidden event is forbidden from being sent, but an allowed one
+ can be sent.
+ """
+ # patch the rules module with a Mock which will return False for some event
+ # types
+ async def check(ev, state):
+ return ev.type != "foo.bar.forbidden"
+
+ callback = Mock(spec=[], side_effect=check)
+ current_rules_module().check_event_allowed = callback
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ callback.assert_called_once()
+
+ # there should be various state events in the state arg: do some basic checks
+ state_arg = callback.call_args[0][1]
+ for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+ self.assertIn(k, state_arg)
+ ev = state_arg[k]
+ self.assertEqual(ev.type, k[0])
+ self.assertEqual(ev.state_key, k[1])
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_modify_event(self):
+ """Tests that the module can successfully tweak an event before it is persisted.
+ """
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ ev.content = {"x": "y"}
+ return True
+
+ current_rules_module().check_event_allowed = check
+
+ # now send the event
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {"x": "x"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ event_id = channel.json_body["event_id"]
+
+ # ... and check that it got modified
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "y")
+
+ def test_send_event(self):
+ """Tests that the module can send an event into a room via the module api"""
+ content = {
+ "msgtype": "m.text",
+ "body": "Hello!",
+ }
+ event_dict = {
+ "room_id": self.room_id,
+ "type": "m.room.message",
+ "content": content,
+ "sender": self.user_id,
+ }
+ event = self.get_success(
+ current_rules_module().module_api.create_and_send_event_into_room(
+ event_dict
+ )
+ ) # type: EventBase
+
+ self.assertEquals(event.sender, self.user_id)
+ self.assertEquals(event.room_id, self.room_id)
+ self.assertEquals(event.type, "m.room.message")
+ self.assertEquals(event.content, content)
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
deleted file mode 100644
index 7167fc56b6..0000000000
--- a/tests/rest/client/third_party_rules.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-
-from tests import unittest
-
-
-class ThirdPartyRulesTestModule(object):
- def __init__(self, config):
- pass
-
- def check_event_allowed(self, event, context):
- if event.type == "foo.bar.forbidden":
- return False
- else:
- return True
-
- @staticmethod
- def parse_config(config):
- return config
-
-
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
- config["third_party_event_rules"] = {
- "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
- "config": {},
- }
-
- self.hs = self.setup_test_homeserver(config=config)
- return self.hs
-
- def test_third_party_rules(self):
- """Tests that a forbidden event is forbidden from being sent, but an allowed one
- can be sent.
- """
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- room_id = self.helper.create_room_as(user_id, tok=tok)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"200", channel.result)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 152a5182fa..0a51aeff92 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -14,11 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import os
import re
from email.parser import Parser
+from typing import Optional
import pkg_resources
@@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from tests import unittest
+from tests.unittest import override_config
class PasswordResetTestCase(unittest.HomeserverTestCase):
@@ -668,16 +669,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
- def _request_token(self, email, client_secret):
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link(self):
+ """Tests a valid next_link parameter value with no whitelist (good case)"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/good/site",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_exotic_protocol(self):
+ """Tests using a esoteric protocol as a next_link parameter value.
+ Someone may be hosting a client on IPFS etc.
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_file_uri(self):
+ """Tests next_link parameters cannot be file URI"""
+ # Attempt to use a next_link value that points to the local disk
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="file:///host/path",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
+ def test_next_link_domain_whitelist(self):
+ """Tests next_link parameters must fit the whitelist if provided"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/some/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.org/some/also/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://bad.example.org/some/bad/page",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": []})
+ def test_empty_next_link_domain_whitelist(self):
+ """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
+ disallowed
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/page",
+ expect_code=400,
+ )
+
+ def _request_token(
+ self,
+ email: str,
+ client_secret: str,
+ next_link: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> str:
+ """Request a validation token to add an email address to a user's account
+
+ Args:
+ email: The email address to validate
+ client_secret: A secret string
+ next_link: A link to redirect the user to after validation
+ expect_code: Expected return code of the call
+
+ Returns:
+ The ID of the new threepid validation session
+ """
+ body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+ if next_link:
+ body["next_link"] = next_link
+
request, channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ "POST", b"account/3pid/email/requestToken", body,
)
self.render(request)
- self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(expect_code, channel.code, channel.result)
- return channel.json_body["sid"]
+ return channel.json_body.get("sid")
def _request_token_invalid_email(
self, email, expected_errcode, expected_error, client_secret="foobar",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 53a43038f0..ecf697e5e0 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime
import json
import os
+from mock import Mock
+
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -87,14 +91,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
- def test_POST_bad_username(self):
- request_data = json.dumps({"username": 777, "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
-
- self.assertEquals(channel.result["code"], b"400", channel.result)
- self.assertEquals(channel.json_body["error"], "Invalid username")
-
def test_POST_user_valid(self):
user_id = "@kermit:test"
device_id = "frogfone"
@@ -160,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -186,7 +182,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -303,6 +299,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(channel.json_body.get("sid"))
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
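+ # Mock out the homeserver's HTTP client so the test can inspect the profile
+ # replication requests it makes to the identity server.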
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+ self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -312,6 +349,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
logout.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -437,6 +475,155 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Create a user
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+ self.get_success(
+ self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+ )
+
+ # Check that a full profile for this user is replicated
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+ # Expire the user
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.reactor.advance(60 * 60 * 1000)
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -587,7 +774,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(request.code, 200, channel.result)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..1accc70dc9
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+from tests.server import make_request, render
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+ def test_allowed(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+ }
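+ # With this mapping, users on source_one may invite users on target_one and
+ # target_two, and users on source_two may only invite users on target_two.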
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_one", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_two", "test:target_two", None, "room", False
+ )
+ )
+
+ # User can invite internal user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test1:target_one", None, "room", False, True
+ )
+ )
+
+ # User can invite external user to a non-published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, False
+ )
+ )
+
+ def test_disallowed(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ "source_four": [],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_one", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_one", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_four", "test:target_one", None, "room", False
+ )
+ )
+
+ # The user can still invite an external user to a published room here, as this
+ # config doesn't prevent any domains from being invited to published rooms
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, True
+ )
+ )
+
+ def test_default_allow(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_default_deny(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_config_parse(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ self.assertEquals(config, DomainRuleChecker.parse_config(config))
+
+ def test_config_parse_failure(self):
+ config = {
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ }
+ }
+ self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["localhost"]
+
+ config["spam_checker"] = {
+ "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+ "config": {
+ "default": True,
+ "domain_mapping": {},
+ "can_only_join_rooms_with_invite": True,
+ "can_only_create_one_to_one_rooms": True,
+ "can_only_invite_during_room_creation": True,
+ "can_invite_by_third_party_id": False,
+ },
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+ self.admin_access_token = self.login("admin_user", "pass")
+
+ self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+ self.normal_access_token = self.login("normal_user", "pass")
+
+ self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+ def test_admin_can_create_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_normal_user_cannot_create_empty_room(self):
+ channel = self._create_room(self.normal_access_token)
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_cannot_create_room_with_multiple_invites(self):
+ channel = self._create_room(
+ self.normal_access_token,
+ content={"invite": [self.other_user_id, self.admin_user_id]},
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly counts both normal and third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [self.other_user_id],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly rejects third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_can_create_room_with_single_invite(self):
+ channel = self._create_room(
+ self.normal_access_token, content={"invite": [self.other_user_id]}
+ )
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_cannot_join_public_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+ )
+
+ def test_can_join_invited_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ def test_cannot_invite(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ def test_cannot_3pid_invite(self):
+ """Test that unbound 3pid invites get rejected.
+ """
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/invite" % (room_id),
+ {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+ access_token=self.normal_access_token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result["body"])
+
+ def _create_room(self, token, content=None):
+ # Default to an empty body, avoiding a mutable default argument
+ content = content or {}
+
+ path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+ request, channel = make_request(
+ self.hs.get_reactor(),
+ "POST",
+ path,
+ content=json.dumps(content).encode("utf8"),
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ return channel
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 2858d13558..23db821fb7 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 98b74890d5..17fbde284a 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
)
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -207,7 +208,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -349,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
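
Reviewer note: set_appservice_state (like get_events_as_list above) has been converted to a coroutine, and @defer.inlineCallbacks generators cannot yield a coroutine object directly — hence every call site gains a defer.ensureDeferred wrapper, which turns the coroutine into a Deferred the generator machinery understands. The pattern in isolation (names here are illustrative):

    from twisted.internet import defer

    async def set_state(service, state):
        # Stand-in for the now-async store method.
        return None

    @defer.inlineCallbacks
    def test_body(service):
        # A bare `yield set_state(...)` would hand inlineCallbacks a raw
        # coroutine, which it rejects; ensureDeferred bridges the two worlds.
        yield defer.ensureDeferred(set_state(service, "DOWN"))
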
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index efcaeef1e7..13bcac743a 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_insert(
- table="tablename", values={"columname": "Value"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename", values={"columname": "Value"}
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_insert(
- table="tablename",
- # Use OrderedDict() so we can assert on the SQL generated
- values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename",
+ # Use OrderedDict() so we can assert on the SQL generated
+ values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ )
)
self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 3fab5a5248..43639ca286 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
"3"
] = 300000
+
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
3,
)
# Oldest room to expire
- self.pump(1)
+ self.pump(1.01)
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
self.assertEqual(
len(
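
Reviewer note: every Requester(...) construction in these test files gains a fourth positional False, matching a new boolean field inserted into the Requester tuple after is_guest (in mainline Synapse that slot is shadow_banned; the name is an assumption for this fork — only the position and the False default are evident from the call sites). The layout these call sites assume:

    from collections import namedtuple

    # Field name `shadow_banned` is an assumption; see note above.
    Requester = namedtuple(
        "Requester",
        ["user", "access_token_id", "is_guest", "shadow_banned", "device_id", "app_service"],
    )

    requester = Requester(user, None, False, False, None, None)

The pump(1) -> pump(1.01) tweak simply advances the fake clock strictly past the expiry boundary, so the oldest exclusion entry reliably expires instead of landing exactly on the cutoff.
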
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b85004e5..949846fe33 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 857db071d4..238bad5b45 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store.db_pool.simple_insert(
- "events",
- {
- "stream_ordering": so,
- "received_ts": ts,
- "event_id": "event%i" % so,
- "type": "",
- "room_id": "",
- "content": "",
- "processed": True,
- "outlier": False,
- "topological_ordering": 0,
- "depth": 0,
- },
+ return defer.ensureDeferred(
+ self.store.db_pool.simple_insert(
+ "events",
+ {
+ "stream_ordering": so,
+ "received_ts": ts,
+ "event_id": "event%i" % so,
+ "type": "",
+ "room_id": "",
+ "content": "",
+ "processed": True,
+ "outlier": False,
+ "topological_ordering": 0,
+ "depth": 0,
+ },
+ )
)
# start with the base case where there are no events in the table
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index e845410dae..7a05194653 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
@@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token("first"), 3)
- self.assertEqual(first_id_gen.get_current_token("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
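
Reviewer note: get_current_token("master") becomes get_current_token_for_writer("master") — with multiple writers there is no single "current token", so the per-writer lookup gets an explicit name (a plain get_current_token presumably remains for the overall stream position). The relationship these tests exercise, in miniature:

    # Positions per writer instance...
    assert first_id_gen.get_positions() == {"first": 3, "second": 7}
    # ...and the same values via the renamed per-writer accessor:
    assert first_id_gen.get_current_token_for_writer("first") == 3
    assert first_id_gen.get_current_token_for_writer("second") == 7
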
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..745fa15e26 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -34,9 +34,13 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_users_paginate(self):
- yield self.store.register_user(self.user.to_string(), "pass")
- yield self.store.create_profile(self.user.localpart)
- yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ yield defer.ensureDeferred(
+ self.store.register_user(self.user.to_string(), "pass")
+ )
+ yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.user.localpart, self.displayname, 1)
+ )
users, total = yield self.store.get_users_paginate(
0, 10, name="bc", guests=False
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..16a32cb819 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,9 +33,11 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
+ )
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
@@ -43,10 +45,12 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.u_frank.localpart, "http://my.site/here", 1
+ )
)
self.assertEquals(
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index a6012c973d..918387733b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+from synapse.api.errors import NotFoundError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = store.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
-
- # Purge everything before this topological token
- purge = defer.ensureDeferred(
- storage.purge_events.purge_history(self.room_id, event, True)
+ event = self.get_success(
+ store.get_topological_token_for_event(last["event_id"])
)
- self.pump()
- self.assertEqual(self.successResultOf(purge), None)
- # Try and get the events
- get_first = store.get_event(first["event_id"])
- get_second = store.get_event(second["event_id"])
- get_third = store.get_event(third["event_id"])
- get_last = store.get_event(last["event_id"])
- self.pump()
+ # Purge everything before this topological token
+ self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.failureResultOf(get_first)
- self.failureResultOf(get_second)
- self.failureResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self):
"""
@@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = storage.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
+ event = self.get_success(
+ storage.get_topological_token_for_event(last["event_id"])
+ )
event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
)
@@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase):
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
- self.pump()
-
- # Nothing is deleted.
- self.successResultOf(get_first)
- self.successResultOf(get_second)
- self.successResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_success(storage.get_event(first["event_id"]))
+ self.get_success(storage.get_event(second["event_id"]))
+ self.get_success(storage.get_event(third["event_id"]))
+ self.get_success(storage.get_event(last["event_id"]))
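
Reviewer note: the purge tests drop the manual pump()/successResultOf()/failureResultOf() choreography in favour of the HomeserverTestCase helpers, which also let the failure path assert a concrete exception type (NotFoundError) instead of "some failure". Roughly what those helpers do (a sketch, not the real implementations):

    from twisted.internet.defer import ensureDeferred

    def get_success(self, d):
        # Accept a coroutine or Deferred, spin the memory reactor, and
        # unwrap the result (failing the test on error).
        d = ensureDeferred(d)
        self.pump()
        return self.successResultOf(d)

    def get_failure(self, d, exc_type):
        # As above, but assert the Deferred failed with `exc_type`.
        d = ensureDeferred(d)
        self.pump()
        return self.failureResultOf(d, exc_type)
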
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 17c9da4838..d98fe8754d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index f2fa42bfb9..4a4548433f 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -42,7 +42,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
user_id = UserID("us", "test")
- our_user = Requester(user_id, None, False, None, None)
+ our_user = Requester(user_id, None, False, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room_deferred = ensureDeferred(
room_creator.create_room(
diff --git a/tests/test_server.py b/tests/test_server.py
index d628070e48..655c918a15 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -178,7 +178,6 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
- self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"])
class OptionsResourceTests(unittest.TestCase):
diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..d4a722a30f 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ GroupID,
+ RoomAlias,
+ UserID,
+ map_username_to_mxid_localpart,
+ strip_invalid_mxid_characters,
+)
from tests import unittest
@@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase):
self.assertEqual(
map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
)
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+ def test_return_type(self):
+ unstripped = strip_invalid_mxid_characters("test")
+ stripped = strip_invalid_mxid_characters("test@")
+
+ self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+ self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+ def test_strip(self):
+ stripped = strip_invalid_mxid_characters("test@")
+ self.assertEqual(stripped, "test", stripped)
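
Reviewer note: the new tests only pin down the contract — input and output are plain strings, and characters illegal in an mxid localpart (such as "@") are removed. An implementation consistent with that contract (an assumption for illustration; the real function is imported from synapse.types):

    import re

    # Characters permitted in a Matrix user-id localpart; anything else is
    # dropped. The exact character class here is an assumption.
    _INVALID_LOCALPART_CHARS = re.compile(r"[^a-z0-9._=/-]")

    def strip_invalid_mxid_characters(localpart):
        return _INVALID_LOCALPART_CHARS.sub("", localpart)

    strip_invalid_mxid_characters("test@")  # -> "test"
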
diff --git a/tests/unittest.py b/tests/unittest.py
index d0bba3ddef..7b80999a74 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -250,7 +250,11 @@ class HomeserverTestCase(TestCase):
async def get_user_by_req(request, allow_guest=False, rights="access"):
return create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
+ UserID.from_string(self.helper.auth_user_id),
+ 1,
+ False,
+ False,
+ None,
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -540,7 +544,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
event, context = self.get_success(
event_creator.create_event(
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..0363735d4f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
assert current_context().request == "c1"
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
assert current_context().request == "c1"
return self.mock(args1, arg2)
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
return self.mock(args1, arg2)
obj = Cls()
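
Reviewer note: cachedList has dropped its inlineCallbacks=True mode, so decorated batch-lookup methods are now written as native coroutines — async def with await replacing the generator's yield. The before/after shape in miniature (run_on_reactor here just forces a reactor tick so the function behaves asynchronously, as in the tests above):

    # Before: a generator run through inlineCallbacks by the decorator.
    #   @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
    #   def list_fn(self, args1, arg2):
    #       yield run_on_reactor()
    #       return self.mock(args1, arg2)
    #
    # After: a plain coroutine; the decorator now expects an async function.
    @descriptors.cachedList("fn", "args1")
    async def list_fn(self, args1, arg2):
        await run_on_reactor()
        return self.mock(args1, arg2)
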
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 4f4da29a98..8491f7cc83 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -28,9 +28,6 @@ class StringUtilsTestCase(unittest.TestCase):
"_--something==_",
"...--==-18913",
"8Dj2odd-e9asd.cd==_--ddas-secret-",
- # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
- # To be removed in a future release
- "SECRET:1234567890",
]
bad = [
diff --git a/tests/utils.py b/tests/utils.py
index a61cbdef44..d543f3ed32 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -173,6 +173,8 @@ def default_config(name, parse=False):
"update_user_directory": False,
"caches": {"global_factor": 1},
"listeners": [{"port": 0, "type": "http"}],
+ # Enable encryption by default in private rooms
+ "encryption_enabled_by_default_for_room_type": "invite",
}
if parse:
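
Reviewer note: the shared test config now enables encryption by default for invite-only (private) rooms, mirroring the homeserver option of the same name (whose accepted values are "all", "invite", and "off"); tests that create private rooms will therefore see an m.room.encryption event at creation time unless they override this. A quick check against the helper (sketch):

    config = default_config("test")
    # "invite" means rooms created with invite-only join rules get E2EE
    # enabled at creation; "off" is the permissive default.
    assert config["encryption_enabled_by_default_for_room_type"] == "invite"
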
diff --git a/tox.ini b/tox.ini
index e5413eb110..050e36bc82 100644
--- a/tox.ini
+++ b/tox.ini
@@ -113,14 +113,14 @@ commands =
[testenv:packaging]
skip_install=True
deps =
- check-manifest
+ check-manifest==0.41
commands =
check-manifest
[testenv:check_codestyle]
skip_install = True
deps =
- flake8
+ flake8==3.8.3
flake8-comprehensions
# We pin so that our tests don't start failing on new releases of black.
black==19.10b0
@@ -138,7 +138,8 @@ commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev sc
skip_install = True
deps = towncrier>=18.6.0rc1
commands =
- python -m towncrier.check --compare-with=origin/develop
+ python -m towncrier.check --compare-with=origin/dinsic
+basepython = python3.6
[testenv:check-sampleconfig]
commands = {toxinidir}/scripts-dev/generate_sample_config --check
@@ -169,7 +170,7 @@ commands=
skip_install = True
deps =
{[base]deps}
- mypy==0.750
+ mypy==0.782
mypy-zope
env =
MYPYPATH = stubs/
@@ -190,6 +191,7 @@ commands = mypy \
synapse/handlers/message.py \
synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
+ synapse/handlers/room.py \
synapse/handlers/room_member.py \
synapse/handlers/room_member_worker.py \
synapse/handlers/saml_handler.py \
|