diff --git a/.buildkite/docker-compose-env b/.buildkite/docker-compose-env
new file mode 100644
index 0000000000..85b102d07f
--- /dev/null
+++ b/.buildkite/docker-compose-env
@@ -0,0 +1,12 @@
+CI
+BUILDKITE
+BUILDKITE_BUILD_NUMBER
+BUILDKITE_BRANCH
+BUILDKITE_JOB_ID
+BUILDKITE_BUILD_URL
+BUILDKITE_PROJECT_SLUG
+BUILDKITE_COMMIT
+BUILDKITE_PULL_REQUEST
+BUILDKITE_TAG
+CODECOV_TOKEN
+TRIAL_FLAGS
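
Reviewer note: the entries above have no `=value` part on purpose. If I'm reading Compose's env-file handling right, a bare name in an `env_file` is resolved from the environment of the shell that invokes `docker-compose`, which is how the Buildkite job variables reach the test containers. A sketch against one of the compose files added later in this diff:

```sh
# Sketch: a bare name in docker-compose-env inherits its value from the
# invoking shell, so the container sees whatever Buildkite exported.
export BUILDKITE_BRANCH="dinsic"
export BUILDKITE_BUILD_CHECKOUT_PATH="$(pwd)"
docker-compose -f .buildkite/docker-compose.py37.pg11.yaml \
    run testenv env | grep '^BUILDKITE_BRANCH='
# expected output: BUILDKITE_BRANCH=dinsic
```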
diff --git a/.buildkite/docker-compose.py35.pg95.yaml b/.buildkite/docker-compose.py35.pg95.yaml
new file mode 100644
index 0000000000..c6e8280e65
--- /dev/null
+++ b/.buildkite/docker-compose.py35.pg95.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:9.5
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ command: -c fsync=off
+
+ testenv:
+ image: python:3.5
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py37.pg11.yaml b/.buildkite/docker-compose.py37.pg11.yaml
new file mode 100644
index 0000000000..411c37f213
--- /dev/null
+++ b/.buildkite/docker-compose.py37.pg11.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:11
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ command: -c fsync=off
+
+ testenv:
+ image: python:3.7
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py37.pg95.yaml b/.buildkite/docker-compose.py37.pg95.yaml
new file mode 100644
index 0000000000..54ca794072
--- /dev/null
+++ b/.buildkite/docker-compose.py37.pg95.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:9.5
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ command: -c fsync=off
+
+ testenv:
+ image: python:3.7
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py38.pg12.yaml b/.buildkite/docker-compose.py38.pg12.yaml
new file mode 100644
index 0000000000..934a34cf02
--- /dev/null
+++ b/.buildkite/docker-compose.py38.pg12.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:12
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ command: -c fsync=off
+
+ testenv:
+ image: python:3.8
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.sytest.py37.redis.yaml b/.buildkite/docker-compose.sytest.py37.redis.yaml
new file mode 100644
index 0000000000..b9e80cc557
--- /dev/null
+++ b/.buildkite/docker-compose.sytest.py37.redis.yaml
@@ -0,0 +1,22 @@
+version: '3.1'
+
+services:
+
+ redis:
+ image: redis:5.0
+
+ sytest:
+ image: matrixdotorg/sytest-synapse:py37
+ depends_on:
+ - redis
+ env_file: docker-compose-env
+ environment:
+ POSTGRES: "1"
+ WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
+ REDIS: "redis"
+ working_dir: "/src"
+ entrypoint: ""
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}/logs:/logs
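
Reviewer note: this compose file is consumed by the (currently commented-out) Redis SyTest step further down in pipeline.yml. A hedged sketch of driving it by hand, with the bootstrap command taken from that step:

```sh
# Sketch: run the worker-mode SyTest setup locally. /bootstrap.sh and the
# --redis-host flag are taken from the commented-out pipeline step below;
# the empty entrypoint in the compose file lets us pass the command directly.
export BUILDKITE_BUILD_CHECKOUT_PATH="$(pwd)"
mkdir -p logs
docker-compose -f .buildkite/docker-compose.sytest.py37.redis.yaml \
    run sytest /bootstrap.sh synapse --redis-host redis
```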
diff --git a/.buildkite/docker-compose.yaml b/.buildkite/docker-compose.yaml
new file mode 100644
index 0000000000..73d5ccdd5e
--- /dev/null
+++ b/.buildkite/docker-compose.yaml
@@ -0,0 +1,23 @@
+version: '3.1'
+
+services:
+
+ postgres:
+ image: postgres:${POSTGRES_VERSION?}
+ environment:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
+ command: -c fsync=off
+
+ testenv:
+ image: python:${PYTHON_VERSION?}
+ depends_on:
+ - postgres
+ env_file: docker-compose-env
+ environment:
+ SYNAPSE_POSTGRES_HOST: postgres
+ SYNAPSE_POSTGRES_USER: postgres
+ SYNAPSE_POSTGRES_PASSWORD: postgres
+ working_dir: /src
+ volumes:
+ - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
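
Reviewer note: unlike the per-version files above, this one is templated. The `${PYTHON_VERSION?}` and `${POSTGRES_VERSION?}` references make Compose abort with an error if either variable is unset, so a pipeline step can never silently fall back to a default image. A sketch of reproducing one matrix cell locally (tox targets as in the pipeline steps below; everything else is assumed):

```sh
# Sketch: run the "python 3.8 / postgres 12" cell by hand. The ${VAR?}
# references in the compose file fail fast if these are not exported.
export PYTHON_VERSION=3.8 POSTGRES_VERSION=12
export BUILDKITE_BUILD_CHECKOUT_PATH="$(pwd)"
docker-compose -f .buildkite/docker-compose.yaml run testenv bash -c \
    "apt-get update && apt-get install -y xmlsec1 \
     && python -m pip install tox && tox -e py38-postgres,combine"
```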
diff --git a/.buildkite/merge_base_branch.sh b/.buildkite/merge_base_branch.sh
index 361440fd1a..d0a7aef8cb 100755
--- a/.buildkite/merge_base_branch.sh
+++ b/.buildkite/merge_base_branch.sh
@@ -12,7 +12,7 @@ if [[ -z $BUILDKITE_PULL_REQUEST_BASE_BRANCH ]]; then
# It probably hasn't had a PR opened yet. Since all PRs land on develop, we
# can probably assume it's based on it and will be merged into it.
- GITBASE="develop"
+ GITBASE="dinsic"
else
# Get the reference, using the GitHub API
GITBASE=$BUILDKITE_PULL_REQUEST_BASE_BRANCH
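
Reviewer note: the one-line change above only swaps the fallback base branch for this fork; with no PR metadata available, the script now assumes `dinsic` instead of `develop`. A paraphrase of the surrounding logic (the fetch/merge at the end is an assumption based on the script's name, not shown in this hunk):

```sh
# Sketch of merge_base_branch.sh's shape; only the GITBASE fallback is
# visible in the hunk above, the rest is assumed.
if [[ -z $BUILDKITE_PULL_REQUEST_BASE_BRANCH ]]; then
    GITBASE="dinsic"    # no PR opened yet: assume the fork's mainline branch
else
    GITBASE=$BUILDKITE_PULL_REQUEST_BASE_BRANCH
fi
git fetch origin "$GITBASE"    # assumed: merge the base branch before testing
git merge --no-edit FETCH_HEAD
```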
diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
new file mode 100644
index 0000000000..8decfe33ed
--- /dev/null
+++ b/.buildkite/pipeline.yml
@@ -0,0 +1,530 @@
+# This is just a dummy entry (the `x-yaml-aliases` key is not an official pipeline key, and will be ignored by Buildkite)
+# that we use only to store YAML anchors (`&xxx`) that we plan to reference later in the file (using `*xxx`)
+# without having to copy/paste the same values over and over.
+# Note: keys like `agent`, `env`, … used here are totally arbitrary; the only point is to define various separate `&xxx` anchors there.
+#
+x-yaml-aliases:
+ commands:
+ - &trial_setup |
+ # Install additional packages that are not part of buildpack-deps / python images.
+ apt-get update && apt-get install -y xmlsec1
+ python -m pip install tox
+
+env:
+ COVERALLS_REPO_TOKEN: wsJWOby6j0uCYFiCes3r0XauxO27mx8lD
+
+steps:
+ - label: "\U0001F9F9 Check Style"
+ command:
+ - "python -m pip install tox"
+ - "tox -e check_codestyle"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: "\U0001F9F9 packaging"
+ command:
+ - "python -m pip install tox"
+ - "tox -e packaging"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: "\U0001F9F9 isort"
+ command:
+ - "python -m pip install tox"
+ - "tox -e check_isort"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: "\U0001F9F9 check-sample-config"
+ command:
+ - "python -m pip install tox"
+ - "tox -e check-sampleconfig"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: ":mypy: mypy"
+ command:
+ - "python -m pip install tox"
+ - "tox -e mypy"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.7"
+ mount-buildkite-agent: false
+
+ - wait
+
+ ################################################################################
+ #
+ # `trial` tests
+ #
+ ################################################################################
+
+ - label: ":python: 3.5 / SQLite / Old Deps"
+ command:
+ - ".buildkite/scripts/test_old_deps.sh"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "ubuntu:xenial" # We use xenial to get an old sqlite and python
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.5 / SQLite"
+ command:
+ - *trial_setup
+ - "tox -e py35,combine"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.5"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.6 / SQLite"
+ command:
+ - *trial_setup
+ - "tox -e py36,combine"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.7 / SQLite"
+ command:
+ - *trial_setup
+ - "tox -e py37,combine"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.7"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.5 / :postgres: 9.5"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ PYTHON_VERSION: "3.5"
+ POSTGRES_VERSION: "9.5"
+ command:
+ - *trial_setup
+ - "python -m tox -e py35-postgres,combine"
+ plugins:
+ - docker-compose#v3.7.0:
+ run: testenv
+ config:
+ - .buildkite/docker-compose.yaml
+ - artifacts#v1.3.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.7 / :postgres: 11"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ PYTHON_VERSION: "3.7"
+ POSTGRES_VERSION: "11"
+ command:
+ - *trial_setup
+ - "tox -e py37-postgres,combine"
+ plugins:
+ - docker-compose#v3.7.0:
+ run: testenv
+ config:
+ - .buildkite/docker-compose.yaml
+ - artifacts#v1.3.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.8 / :postgres: 12"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ PYTHON_VERSION: "3.8"
+ POSTGRES_VERSION: "12"
+ command:
+ - *trial_setup
+ - "tox -e py38-postgres,combine"
+ plugins:
+ - docker-compose#v3.7.0:
+ run: testenv
+ config:
+ - .buildkite/docker-compose.yaml
+ - artifacts#v1.3.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ ################################################################################
+ #
+ # Sytest
+ #
+ ################################################################################
+
+ - label: "SyTest - :python: 3.5 / SQLite / Monolith"
+ agents:
+ queue: "medium"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3.5 / :postgres: 9.6 / Monolith"
+ agents:
+ queue: "medium"
+ env:
+ POSTGRES: "1"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3.5 / :postgres: 9.6 / Workers"
+ agents:
+ queue: "medium"
+ env:
+ MULTI_POSTGRES: "1" # Test with split out databases
+ POSTGRES: "1"
+ WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+ # - matrix-org/coveralls#v1.0:
+ # parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+
+ - label: "SyTest - :python: 3.8 / :postgres: 12 / Monolith"
+ agents:
+ queue: "medium"
+ env:
+ POSTGRES: "1"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3.7 / :postgres: 11 / Workers"
+ agents:
+ queue: "medium"
+ env:
+ MULTI_POSTGRES: "1" # Test with split out databases
+ POSTGRES: "1"
+ WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+ # - matrix-org/coveralls#v1.0:
+ # parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+# TODO: Enable once Synapse v1.13.0 is merged in
+# - label: "SyTest - :python: 3.7 / :postgres: 11 / Workers / :redis: Redis"
+# agents:
+# queue: "medium"
+# command:
+# - bash -c "cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers && ./.buildkite/merge_base_branch.sh && /bootstrap.sh synapse --redis-host redis"
+# plugins:
+# - matrix-org/download#v1.1.0:
+# urls:
+# - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.sytest.py37.redis.yaml
+# - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+# - docker-compose#v2.1.0:
+# run: sytest
+# config:
+# - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.sytest.py37.redis.yaml
+# - artifacts#v1.2.0:
+# upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+# - matrix-org/annotate:
+# path: "logs/annotate.md"
+# style: "error"
+## - matrix-org/coveralls#v1.0:
+## parallel: "true"
+# retry:
+# automatic:
+# - exit_status: -1
+# limit: 2
+# - exit_status: 2
+# limit: 2
+
+ ################################################################################
+ #
+ # synapse_port_db
+ #
+ ################################################################################
+
+ - label: "synapse_port_db / :python: 3.5 / :postgres: 9.5"
+ agents:
+ queue: "medium"
+ command:
+ - "bash .buildkite/scripts/test_synapse_port_db.sh"
+ plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py35.pg95.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py35.pg95.yaml
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+
+ - label: "synapse_port_db / :python: 3.7 / :postgres: 11"
+ agents:
+ queue: "medium"
+ command:
+ - "bash .buildkite/scripts/test_synapse_port_db.sh"
+ plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py37.pg11.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py37.pg11.yaml
+ - artifacts#v1.2.0:
+ upload: [ "_trial_temp/*/*.log" ]
+# - matrix-org/coveralls#v1.0:
+# parallel: "true"
+
+# - wait: ~
+# continue_on_failure: true
+#
+# - label: Trigger webhook
+# command: "curl -k https://coveralls.io/webhook?repo_token=$COVERALLS_REPO_TOKEN -d \"payload[build_num]=$BUILDKITE_BUILD_NUMBER&payload[status]=done\""
+
+ ################################################################################
+ #
+ # Complement Test Suite
+ #
+ ################################################################################
+
+ - command:
+ # Build a docker image from the checked out Synapse source
+ - "docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile ."
+ # We use the complement:latest image to provide Complement's dependencies, but want
+ # to actually run against the latest version of Complement, so download it here.
+ - "wget https://github.com/matrix-org/complement/archive/anoa/knock_room_v7.tar.gz"
+ - "tar -xzf knock_room_v7.tar.gz"
+ # Build a second docker image on top of the above image. This one sets up Synapse with a generated config file,
+ # signing and SSL keys so Synapse can run and federate
+ - "docker build -t complement-synapse -f complement-anoa-knock_room_v7/dockerfiles/Synapse.Dockerfile complement-anoa-knock_room_v7/dockerfiles"
+ # Finally, compile and run the tests.
+ - "cd complement-anoa-knock_room_v7"
+ - "COMPLEMENT_BASE_IMAGE=complement-synapse:latest go test -v -tags synapse_blacklist,msc2403 ./tests"
+ label: "\U0001F9EA Complement"
+ agents:
+ queue: "medium"
+ plugins:
+ - docker#v3.7.0:
+ # The dockerfile for this image is at https://github.com/matrix-org/complement/blob/master/dockerfiles/ComplementCIBuildkite.Dockerfile.
+ image: "matrixdotorg/complement:latest"
+ mount-buildkite-agent: false
+ # Complement needs to know if it is running under CI
+ environment:
+ - "CI=true"
+ publish: [ "8448:8448" ]
+ # Complement uses Docker so pass through the docker socket. This means Complement shares
+      # the host's Docker.
+ volumes:
+ - "/var/run/docker.sock:/var/run/docker.sock"
\ No newline at end of file
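
Reviewer note: the Complement step is self-contained enough to replay outside CI. A sketch under the same assumptions as the step above (Docker and Go installed, run from a Synapse checkout, and the `anoa/knock_room_v7` branch still published):

```sh
# Sketch: replay the Complement step from the pipeline on a workstation.
docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
wget https://github.com/matrix-org/complement/archive/anoa/knock_room_v7.tar.gz
tar -xzf knock_room_v7.tar.gz
docker build -t complement-synapse \
    -f complement-anoa-knock_room_v7/dockerfiles/Synapse.Dockerfile \
    complement-anoa-knock_room_v7/dockerfiles
cd complement-anoa-knock_room_v7
COMPLEMENT_BASE_IMAGE=complement-synapse:latest \
    go test -v -tags synapse_blacklist,msc2403 ./tests
```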
diff --git a/CHANGES.md b/CHANGES.md
index b4e1d25fe0..fcd782fa94 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,9 @@
+Unreleased
+==========
+
+Note that this release includes a change in Synapse to use Redis as a cache, as well as a pub/sub mechanism, if Redis support is enabled. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically.
+
+
Synapse 1.26.0 (2021-01-27)
===========================
diff --git a/MANIFEST.in b/MANIFEST.in
index 120ce5b776..0a9cf4f51c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,5 @@
include synctl
+include sytest-blacklist
include LICENSE
include VERSION
include *.rst
@@ -51,3 +52,11 @@ prune demo/etc
prune docker
prune snap
prune stubs
+
+exclude jenkins*
+recursive-exclude jenkins *.sh
+
+# FIXME: we shouldn't have these templates here
+recursive-include res/templates-dinsic *.css
+recursive-include res/templates-dinsic *.html
+recursive-include res/templates-dinsic *.txt
diff --git a/UPGRADE.rst b/UPGRADE.rst
index d09dbd4e21..eea0322695 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -85,6 +85,42 @@
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.27.0
+====================
+
+Changes to HTML templates
+-------------------------
+
+The HTML templates for SSO and email notifications now have `Jinja2's autoescape <https://jinja.palletsprojects.com/en/2.11.x/api/#autoescaping>`_
+enabled for files ending in ``.html``, ``.htm``, and ``.xml``. If you have customised
+these templates and see issues when viewing them you might need to update them.
+It is expected that most configurations will need no changes.
+
+If you have customised the *names* of these templates, it is recommended
+to verify they end in ``.html`` to ensure autoescape is enabled.
+
+The above applies to the following templates:
+
+* ``add_threepid.html``
+* ``add_threepid_failure.html``
+* ``add_threepid_success.html``
+* ``notice_expiry.html``
+* ``notif_mail.html`` (which, by default, includes ``room.html`` and ``notif.html``)
+* ``password_reset.html``
+* ``password_reset_confirmation.html``
+* ``password_reset_failure.html``
+* ``password_reset_success.html``
+* ``registration.html``
+* ``registration_failure.html``
+* ``registration_success.html``
+* ``sso_account_deactivated.html``
+* ``sso_auth_bad_user.html``
+* ``sso_auth_confirm.html``
+* ``sso_auth_success.html``
+* ``sso_error.html``
+* ``sso_login_idp_picker.html``
+* ``sso_redirect_confirm.html``
+
Upgrading to v1.26.0
====================
diff --git a/changelog.d/1.feature b/changelog.d/1.feature
new file mode 100644
index 0000000000..845642e445
--- /dev/null
+++ b/changelog.d/1.feature
@@ -0,0 +1 @@
+Forbid changing the name, avatar or topic of a direct room.
diff --git a/changelog.d/10.bugfix b/changelog.d/10.bugfix
new file mode 100644
index 0000000000..51f89f46dd
--- /dev/null
+++ b/changelog.d/10.bugfix
@@ -0,0 +1 @@
+Don't apply retention policy based filtering on state events.
diff --git a/changelog.d/11.feature b/changelog.d/11.feature
new file mode 100644
index 0000000000..362e4b1efd
--- /dev/null
+++ b/changelog.d/11.feature
@@ -0,0 +1 @@
+Allow server admins to configure a custom global rate-limiting for third party invites.
\ No newline at end of file
diff --git a/changelog.d/12.feature b/changelog.d/12.feature
new file mode 100644
index 0000000000..8e6e7a28af
--- /dev/null
+++ b/changelog.d/12.feature
@@ -0,0 +1 @@
+Add a `/user/:user_id/info` CS servlet to give user deactivated/expired information.
\ No newline at end of file
diff --git a/changelog.d/13.feature b/changelog.d/13.feature
new file mode 100644
index 0000000000..c2d2e93abf
--- /dev/null
+++ b/changelog.d/13.feature
@@ -0,0 +1 @@
+Hide expired users from the user directory, and optionally re-add them on renewal.
\ No newline at end of file
diff --git a/changelog.d/14.feature b/changelog.d/14.feature
new file mode 100644
index 0000000000..020d0bac1e
--- /dev/null
+++ b/changelog.d/14.feature
@@ -0,0 +1 @@
+User display names now have capitalised letters after `-` symbols.
\ No newline at end of file
diff --git a/changelog.d/15.misc b/changelog.d/15.misc
new file mode 100644
index 0000000000..4cc4a5175f
--- /dev/null
+++ b/changelog.d/15.misc
@@ -0,0 +1 @@
+Fix the ordering on `scripts/generate_signing_key.py`'s import statement.
diff --git a/changelog.d/17.misc b/changelog.d/17.misc
new file mode 100644
index 0000000000..58120ab5c7
--- /dev/null
+++ b/changelog.d/17.misc
@@ -0,0 +1 @@
+Blacklist some flaky sytests until they're fixed.
\ No newline at end of file
diff --git a/changelog.d/18.feature b/changelog.d/18.feature
new file mode 100644
index 0000000000..f5aa29a6e8
--- /dev/null
+++ b/changelog.d/18.feature
@@ -0,0 +1 @@
+Add option `limit_profile_requests_to_known_users` to require that a user share a room with another user in order to query their profile information.
\ No newline at end of file
diff --git a/changelog.d/19.feature b/changelog.d/19.feature
new file mode 100644
index 0000000000..95a44a4a89
--- /dev/null
+++ b/changelog.d/19.feature
@@ -0,0 +1 @@
+Add `max_avatar_size` and `allowed_avatar_mimetypes` to restrict the size of user avatars and their file type respectively.
\ No newline at end of file
diff --git a/changelog.d/2.bugfix b/changelog.d/2.bugfix
new file mode 100644
index 0000000000..4fe5691468
--- /dev/null
+++ b/changelog.d/2.bugfix
@@ -0,0 +1 @@
+Don't treat 3PID revocation as a new 3PID invite.
diff --git a/changelog.d/20.bugfix b/changelog.d/20.bugfix
new file mode 100644
index 0000000000..8ba53c28f9
--- /dev/null
+++ b/changelog.d/20.bugfix
@@ -0,0 +1 @@
+Validate `client_secret` parameter against the regex provided by the C-S spec.
\ No newline at end of file
diff --git a/changelog.d/21.bugfix b/changelog.d/21.bugfix
new file mode 100644
index 0000000000..630d7812f7
--- /dev/null
+++ b/changelog.d/21.bugfix
@@ -0,0 +1 @@
+Fix resetting user passwords via a phone number.
diff --git a/changelog.d/28.bugfix b/changelog.d/28.bugfix
new file mode 100644
index 0000000000..38d7455971
--- /dev/null
+++ b/changelog.d/28.bugfix
@@ -0,0 +1 @@
+Fix a bug that, in some cases, caused account validity renewal emails to be sent even if the feature is turned off.
diff --git a/changelog.d/29.misc b/changelog.d/29.misc
new file mode 100644
index 0000000000..720e0ddcfb
--- /dev/null
+++ b/changelog.d/29.misc
@@ -0,0 +1 @@
+Improve performance when making `.well-known` requests by sharing the SSL options between requests.
diff --git a/changelog.d/3.bugfix b/changelog.d/3.bugfix
new file mode 100644
index 0000000000..cc4bcefa80
--- /dev/null
+++ b/changelog.d/3.bugfix
@@ -0,0 +1 @@
+Fix encoding on password reset HTML responses in Python 2.
diff --git a/changelog.d/30.misc b/changelog.d/30.misc
new file mode 100644
index 0000000000..ae68554be3
--- /dev/null
+++ b/changelog.d/30.misc
@@ -0,0 +1 @@
+Improve performance when making HTTP requests to sygnal, sydent, etc, by sharing the SSL context object between connections.
diff --git a/changelog.d/32.bugfix b/changelog.d/32.bugfix
new file mode 100644
index 0000000000..b6e7b90710
--- /dev/null
+++ b/changelog.d/32.bugfix
@@ -0,0 +1 @@
+Fix a bug when using the default display name during registration.
diff --git a/changelog.d/39.feature b/changelog.d/39.feature
new file mode 100644
index 0000000000..426b7ef27e
--- /dev/null
+++ b/changelog.d/39.feature
@@ -0,0 +1 @@
+Merge Synapse v1.12.4 `master` into the `dinsic` branch.
\ No newline at end of file
diff --git a/changelog.d/4.bugfix b/changelog.d/4.bugfix
new file mode 100644
index 0000000000..fe717920a6
--- /dev/null
+++ b/changelog.d/4.bugfix
@@ -0,0 +1 @@
+Fix handling of filtered strings in Python 3.
diff --git a/changelog.d/45.feature b/changelog.d/45.feature
new file mode 100644
index 0000000000..d45ac34ac1
--- /dev/null
+++ b/changelog.d/45.feature
@@ -0,0 +1 @@
+Merge Synapse mainline releases v1.13.0 through v1.14.0 into the `dinsic` branch.
\ No newline at end of file
diff --git a/changelog.d/46.feature b/changelog.d/46.feature
new file mode 100644
index 0000000000..7872d956e3
--- /dev/null
+++ b/changelog.d/46.feature
@@ -0,0 +1 @@
+Add a bulk version of the User Info API. Deprecate the single-use version.
\ No newline at end of file
diff --git a/changelog.d/47.misc b/changelog.d/47.misc
new file mode 100644
index 0000000000..1d6596d788
--- /dev/null
+++ b/changelog.d/47.misc
@@ -0,0 +1 @@
+Improve performance of `mark_expired_users_as_inactive` background job.
\ No newline at end of file
diff --git a/changelog.d/48.feature b/changelog.d/48.feature
new file mode 100644
index 0000000000..b7939f3f51
--- /dev/null
+++ b/changelog.d/48.feature
@@ -0,0 +1 @@
+Prevent `/register` from raising `M_USER_IN_USE` until UI Auth has been completed. Have `/register/available` always return true.
diff --git a/changelog.d/5.bugfix b/changelog.d/5.bugfix
new file mode 100644
index 0000000000..53f57f46ca
--- /dev/null
+++ b/changelog.d/5.bugfix
@@ -0,0 +1 @@
+Fix room retention policy management in worker mode.
diff --git a/changelog.d/50.feature b/changelog.d/50.feature
new file mode 100644
index 0000000000..0801622c8a
--- /dev/null
+++ b/changelog.d/50.feature
@@ -0,0 +1 @@
+Merge Synapse mainline v1.15.1 into the `dinsic` branch.
\ No newline at end of file
diff --git a/changelog.d/5083.feature b/changelog.d/5083.feature
new file mode 100644
index 0000000000..2ffdd37eef
--- /dev/null
+++ b/changelog.d/5083.feature
@@ -0,0 +1 @@
+Add `auth_profile_reqs` option to require an `access_token` for `GET /profile` endpoints on the CS API.
diff --git a/changelog.d/5098.misc b/changelog.d/5098.misc
new file mode 100644
index 0000000000..9cd83bf226
--- /dev/null
+++ b/changelog.d/5098.misc
@@ -0,0 +1 @@
+Add workarounds for pep-517 install errors.
diff --git a/changelog.d/51.feature b/changelog.d/51.feature
new file mode 100644
index 0000000000..e5c9990ad6
--- /dev/null
+++ b/changelog.d/51.feature
@@ -0,0 +1 @@
+Add `bind_new_user_emails_to_sydent` option for automatically binding users' emails after registration.
diff --git a/changelog.d/5214.feature b/changelog.d/5214.feature
new file mode 100644
index 0000000000..6c0f15c901
--- /dev/null
+++ b/changelog.d/5214.feature
@@ -0,0 +1 @@
+Allow server admins to define and enforce a password policy (MSC2000).
diff --git a/changelog.d/53.feature b/changelog.d/53.feature
new file mode 100644
index 0000000000..96c628e824
--- /dev/null
+++ b/changelog.d/53.feature
@@ -0,0 +1 @@
+Merge mainline Synapse v1.18.0 into the `dinsic` branch.
\ No newline at end of file
diff --git a/changelog.d/5416.misc b/changelog.d/5416.misc
new file mode 100644
index 0000000000..155e8c7cd3
--- /dev/null
+++ b/changelog.d/5416.misc
@@ -0,0 +1 @@
+Add unique index to the profile_replication_status table.
diff --git a/changelog.d/5420.feature b/changelog.d/5420.feature
new file mode 100644
index 0000000000..745864b903
--- /dev/null
+++ b/changelog.d/5420.feature
@@ -0,0 +1 @@
+Add configuration option to hide new users from the user directory.
diff --git a/changelog.d/56.misc b/changelog.d/56.misc
new file mode 100644
index 0000000000..f66c55af21
--- /dev/null
+++ b/changelog.d/56.misc
@@ -0,0 +1 @@
+Temporarily revert commit a3fbc23.
diff --git a/changelog.d/5610.feature b/changelog.d/5610.feature
new file mode 100644
index 0000000000..b99514f97e
--- /dev/null
+++ b/changelog.d/5610.feature
@@ -0,0 +1 @@
+Implement new custom event rules for power levels.
diff --git a/changelog.d/57.misc b/changelog.d/57.misc
new file mode 100644
index 0000000000..1bbe8611cd
--- /dev/null
+++ b/changelog.d/57.misc
@@ -0,0 +1 @@
+Add `user_id` back to presence in workers too (https://github.com/matrix-org/synapse/commit/0bbbd10513008d30c17eb1d1e7ba1d091fb44ec7).
diff --git a/changelog.d/5702.bugfix b/changelog.d/5702.bugfix
new file mode 100644
index 0000000000..43b6e39b13
--- /dev/null
+++ b/changelog.d/5702.bugfix
@@ -0,0 +1 @@
+Fix 3PID invite to invite association detection in the Tchap room access rules.
diff --git a/changelog.d/5760.feature b/changelog.d/5760.feature
new file mode 100644
index 0000000000..90302d793e
--- /dev/null
+++ b/changelog.d/5760.feature
@@ -0,0 +1 @@
+Force the access rule to be "restricted" if the join rule is "public".
diff --git a/changelog.d/58.misc b/changelog.d/58.misc
new file mode 100644
index 0000000000..64098a68a4
--- /dev/null
+++ b/changelog.d/58.misc
@@ -0,0 +1 @@
+Don't push if a user account has expired.
diff --git a/changelog.d/59.feature b/changelog.d/59.feature
new file mode 100644
index 0000000000..aa07f762d1
--- /dev/null
+++ b/changelog.d/59.feature
@@ -0,0 +1 @@
+Freeze a room when the last administrator in the room leaves.
\ No newline at end of file
diff --git a/changelog.d/6.bugfix b/changelog.d/6.bugfix
new file mode 100644
index 0000000000..43ab65cc95
--- /dev/null
+++ b/changelog.d/6.bugfix
@@ -0,0 +1 @@
+Don't forbid membership events whose membership isn't 'join' or 'invite' in restricted rooms, so that users who got into these rooms before the access rules started being enforced can leave them.
diff --git a/changelog.d/60.misc b/changelog.d/60.misc
new file mode 100644
index 0000000000..d2625a4f65
--- /dev/null
+++ b/changelog.d/60.misc
@@ -0,0 +1 @@
+Make all rooms noisy by default.
diff --git a/changelog.d/61.misc b/changelog.d/61.misc
new file mode 100644
index 0000000000..0c3ba98628
--- /dev/null
+++ b/changelog.d/61.misc
@@ -0,0 +1 @@
+Change the minimum power levels for invites and other state events in new rooms.
\ No newline at end of file
diff --git a/changelog.d/62.misc b/changelog.d/62.misc
new file mode 100644
index 0000000000..1e26456595
--- /dev/null
+++ b/changelog.d/62.misc
@@ -0,0 +1 @@
+Type hinting and other cleanups for `synapse.third_party_rules.access_rules`.
\ No newline at end of file
diff --git a/changelog.d/63.feature b/changelog.d/63.feature
new file mode 100644
index 0000000000..b45f38fa94
--- /dev/null
+++ b/changelog.d/63.feature
@@ -0,0 +1 @@
+Make AccessRules use the public rooms directory instead of checking a room's join rules on rule change.
diff --git a/changelog.d/64.bugfix b/changelog.d/64.bugfix
new file mode 100644
index 0000000000..60c077af94
--- /dev/null
+++ b/changelog.d/64.bugfix
@@ -0,0 +1 @@
+Ensure a `RoomAccessRules` test doesn't accidentally modify a room's access rule and then test that room assuming its access rule has not changed.
diff --git a/changelog.d/65.bugfix b/changelog.d/65.bugfix
new file mode 100644
index 0000000000..71b498cbc8
--- /dev/null
+++ b/changelog.d/65.bugfix
@@ -0,0 +1 @@
+Fix `nextLink` parameters being checked on validation endpoints even if they weren't provided by the client.
\ No newline at end of file
diff --git a/changelog.d/66.bugfix b/changelog.d/66.bugfix
new file mode 100644
index 0000000000..9547cfeddd
--- /dev/null
+++ b/changelog.d/66.bugfix
@@ -0,0 +1 @@
+Create a mapping between user ID and threepid when binding via the internal Sydent bind API.
\ No newline at end of file
diff --git a/changelog.d/67.misc b/changelog.d/67.misc
new file mode 100644
index 0000000000..0a2095e4d4
--- /dev/null
+++ b/changelog.d/67.misc
@@ -0,0 +1 @@
+Merge mainline Synapse v1.21.2 into 'dinsic'.
\ No newline at end of file
diff --git a/changelog.d/6739.feature b/changelog.d/6739.feature
new file mode 100644
index 0000000000..9c41140194
--- /dev/null
+++ b/changelog.d/6739.feature
@@ -0,0 +1 @@
+Implement "room knocking" as per [MSC2403](https://github.com/matrix-org/matrix-doc/pull/2403). Contributed by Sorunome and anoa.
\ No newline at end of file
diff --git a/changelog.d/68.misc b/changelog.d/68.misc
new file mode 100644
index 0000000000..99cc5f7483
--- /dev/null
+++ b/changelog.d/68.misc
@@ -0,0 +1 @@
+Override any missing default power level keys with DINUM's defaults when creating a room.
\ No newline at end of file
diff --git a/changelog.d/71.bugfix b/changelog.d/71.bugfix
new file mode 100644
index 0000000000..cad69c7bd2
--- /dev/null
+++ b/changelog.d/71.bugfix
@@ -0,0 +1 @@
+Fix user info for remote users.
diff --git a/changelog.d/72.bugfix b/changelog.d/72.bugfix
new file mode 100644
index 0000000000..7ebd16f437
--- /dev/null
+++ b/changelog.d/72.bugfix
@@ -0,0 +1 @@
+Update the version of mypy to 0.790.
diff --git a/changelog.d/9.misc b/changelog.d/9.misc
new file mode 100644
index 0000000000..24fd12c978
--- /dev/null
+++ b/changelog.d/9.misc
@@ -0,0 +1 @@
+Add SyTest to the BuildKite CI.
diff --git a/changelog.d/9045.misc b/changelog.d/9045.misc
new file mode 100644
index 0000000000..7f1886a0de
--- /dev/null
+++ b/changelog.d/9045.misc
@@ -0,0 +1 @@
+Add tests to `test_user.UsersListTestCase` for List Users Admin API.
\ No newline at end of file
diff --git a/changelog.d/9062.feature b/changelog.d/9062.feature
new file mode 100644
index 0000000000..8b950fa062
--- /dev/null
+++ b/changelog.d/9062.feature
@@ -0,0 +1 @@
+Add admin API for getting and deleting forward extremities for a room.
diff --git a/changelog.d/9084.bugfix b/changelog.d/9084.bugfix
new file mode 100644
index 0000000000..415dd8b259
--- /dev/null
+++ b/changelog.d/9084.bugfix
@@ -0,0 +1 @@
+Don't blacklist connections to the configured proxy. Contributed by @Bubu.
diff --git a/changelog.d/9121.bugfix b/changelog.d/9121.bugfix
new file mode 100644
index 0000000000..a566878ec0
--- /dev/null
+++ b/changelog.d/9121.bugfix
@@ -0,0 +1 @@
+Fix spurious errors in logs when deleting a non-existent pusher.
diff --git a/changelog.d/9163.bugfix b/changelog.d/9163.bugfix
new file mode 100644
index 0000000000..c51cf6ca80
--- /dev/null
+++ b/changelog.d/9163.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled).
diff --git a/changelog.d/9164.bugfix b/changelog.d/9164.bugfix
new file mode 100644
index 0000000000..1c54a256c1
--- /dev/null
+++ b/changelog.d/9164.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding.
diff --git a/changelog.d/9165.bugfix b/changelog.d/9165.bugfix
new file mode 100644
index 0000000000..58db22f484
--- /dev/null
+++ b/changelog.d/9165.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where invalid data could cause errors when calculating the presentable room name for push.
diff --git a/changelog.d/9176.misc b/changelog.d/9176.misc
new file mode 100644
index 0000000000..9c41d7b0f9
--- /dev/null
+++ b/changelog.d/9176.misc
@@ -0,0 +1 @@
+Speed up chain cover calculation when persisting a batch of state events at once.
diff --git a/changelog.d/9188.misc b/changelog.d/9188.misc
new file mode 100644
index 0000000000..7820d09cd0
--- /dev/null
+++ b/changelog.d/9188.misc
@@ -0,0 +1 @@
+Speed up batch insertion when using PostgreSQL.
diff --git a/changelog.d/9190.misc b/changelog.d/9190.misc
new file mode 100644
index 0000000000..1b0cc56a92
--- /dev/null
+++ b/changelog.d/9190.misc
@@ -0,0 +1 @@
+Improve performance of concurrent use of `StreamIDGenerators`.
diff --git a/changelog.d/9191.misc b/changelog.d/9191.misc
new file mode 100644
index 0000000000..b4bc6be13a
--- /dev/null
+++ b/changelog.d/9191.misc
@@ -0,0 +1 @@
+Add some missing source directories to the automatic linting script.
\ No newline at end of file
diff --git a/changelog.d/9198.misc b/changelog.d/9198.misc
new file mode 100644
index 0000000000..a6cb77fbb2
--- /dev/null
+++ b/changelog.d/9198.misc
@@ -0,0 +1 @@
+Precompute joined hosts and store in Redis.
diff --git a/changelog.d/9199.removal b/changelog.d/9199.removal
new file mode 100644
index 0000000000..fbd2916cbf
--- /dev/null
+++ b/changelog.d/9199.removal
@@ -0,0 +1 @@
+The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`.
diff --git a/changelog.d/9200.misc b/changelog.d/9200.misc
new file mode 100644
index 0000000000..5f239ff9da
--- /dev/null
+++ b/changelog.d/9200.misc
@@ -0,0 +1 @@
+Clean-up template loading code.
diff --git a/changelog.d/9209.feature b/changelog.d/9209.feature
new file mode 100644
index 0000000000..ec926e8eb4
--- /dev/null
+++ b/changelog.d/9209.feature
@@ -0,0 +1 @@
+Add an admin API endpoint for shadow-banning users.
diff --git a/changelog.d/9218.bugfix b/changelog.d/9218.bugfix
new file mode 100644
index 0000000000..577fff5497
--- /dev/null
+++ b/changelog.d/9218.bugfix
@@ -0,0 +1 @@
+Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.
diff --git a/changelog.d/9222.misc b/changelog.d/9222.misc
new file mode 100644
index 0000000000..37490717b3
--- /dev/null
+++ b/changelog.d/9222.misc
@@ -0,0 +1 @@
+Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting.
\ No newline at end of file
diff --git a/changelog.d/9223.misc b/changelog.d/9223.misc
new file mode 100644
index 0000000000..9d44b621c9
--- /dev/null
+++ b/changelog.d/9223.misc
@@ -0,0 +1 @@
+Add type hints to handlers code.
diff --git a/changelog.d/9227.misc b/changelog.d/9227.misc
new file mode 100644
index 0000000000..a6cb77fbb2
--- /dev/null
+++ b/changelog.d/9227.misc
@@ -0,0 +1 @@
+Precompute joined hosts and store in Redis.
diff --git a/changelog.d/9229.bugfix b/changelog.d/9229.bugfix
new file mode 100644
index 0000000000..3ed32291de
--- /dev/null
+++ b/changelog.d/9229.bugfix
@@ -0,0 +1 @@
+Fix a bug where `None` was passed to Synapse modules instead of an empty dictionary if an empty module `config` block was provided in the homeserver config.
\ No newline at end of file
diff --git a/changelog.d/9232.misc b/changelog.d/9232.misc
new file mode 100644
index 0000000000..9d44b621c9
--- /dev/null
+++ b/changelog.d/9232.misc
@@ -0,0 +1 @@
+Add type hints to handlers code.
diff --git a/changelog.d/9235.bugfix b/changelog.d/9235.bugfix
new file mode 100644
index 0000000000..7809c8673b
--- /dev/null
+++ b/changelog.d/9235.bugfix
@@ -0,0 +1 @@
+Fix a bug in the `make_room_admin` admin API where it failed if the admin with the greatest power level was not in the room. Contributed by Pankaj Yadav.
diff --git a/changelog.d/9238.feature b/changelog.d/9238.feature
new file mode 100644
index 0000000000..143a3e14f5
--- /dev/null
+++ b/changelog.d/9238.feature
@@ -0,0 +1 @@
+Add ratelimits to the 3PID `/requestToken` API.
diff --git a/changelog.d/9244.doc b/changelog.d/9244.doc
new file mode 100644
index 0000000000..2ad81429fc
--- /dev/null
+++ b/changelog.d/9244.doc
@@ -0,0 +1 @@
+Add notes on integrating with Facebook for SSO login.
diff --git a/changelog.d/9254.misc b/changelog.d/9254.misc
new file mode 100644
index 0000000000..b79b9abbd3
--- /dev/null
+++ b/changelog.d/9254.misc
@@ -0,0 +1 @@
+Fix Debian package building on Ubuntu 16.04 LTS (Xenial).
diff --git a/changelog.d/9255.misc b/changelog.d/9255.misc
new file mode 100644
index 0000000000..f723b8ec4f
--- /dev/null
+++ b/changelog.d/9255.misc
@@ -0,0 +1 @@
+Minor performance improvement during TLS handshake.
diff --git a/changelog.d/9258.feature b/changelog.d/9258.feature
new file mode 100644
index 0000000000..0028f42d26
--- /dev/null
+++ b/changelog.d/9258.feature
@@ -0,0 +1 @@
+Add ratelimits to invites in rooms and to specific users.
diff --git a/changelog.d/9265.bugfix b/changelog.d/9265.bugfix
new file mode 100644
index 0000000000..34f7bd8ddd
--- /dev/null
+++ b/changelog.d/9265.bugfix
@@ -0,0 +1 @@
+Prevent password hashes from getting dropped if a client failed threepid validation during a User Interactive Auth stage. Removes a workaround for an ancient bug in Riot Web <v0.7.4.
\ No newline at end of file
diff --git a/changelog.d/9270.misc b/changelog.d/9270.misc
new file mode 100644
index 0000000000..908e5ee78b
--- /dev/null
+++ b/changelog.d/9270.misc
@@ -0,0 +1 @@
+Restore PyPy compatibility by not calling CPython-specific GC methods when under PyPy.
diff --git a/changelog.d/9283.feature b/changelog.d/9283.feature
new file mode 100644
index 0000000000..54f133a064
--- /dev/null
+++ b/changelog.d/9283.feature
@@ -0,0 +1 @@
+Add phone home stats for encrypted messages.
diff --git a/changelog.d/9372.feature b/changelog.d/9372.feature
new file mode 100644
index 0000000000..3cb01004c9
--- /dev/null
+++ b/changelog.d/9372.feature
@@ -0,0 +1 @@
+The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients, with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients on top of existing support for the uppercase `HTTPS_PROXY` form, and takes precedence if both are present. Contributed by Timothy Leung.
diff --git a/changelog.d/9657.feature b/changelog.d/9657.feature
new file mode 100644
index 0000000000..c56a615a8b
--- /dev/null
+++ b/changelog.d/9657.feature
@@ -0,0 +1 @@
+Add support for credentials for proxy authentication in the `HTTPS_PROXY` environment variable.
diff --git a/contrib/systemd/README.md b/contrib/systemd/README.md
deleted file mode 100644
index 5d42b3464f..0000000000
--- a/contrib/systemd/README.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# Setup Synapse with Systemd
-This is a setup for managing synapse with a user contributed systemd unit
-file. It provides a `matrix-synapse` systemd unit file that should be tailored
-to accommodate your installation in accordance with the installation
-instructions provided in [installation instructions](../../INSTALL.md).
-
-## Setup
-1. Under the service section, ensure the `User` variable matches which user
-you installed synapse under and wish to run it as.
-2. Under the service section, ensure the `WorkingDirectory` variable matches
-where you have installed synapse.
-3. Under the service section, ensure the `ExecStart` variable matches the
-appropriate locations of your installation.
-4. Copy the `matrix-synapse.service` to `/etc/systemd/system/`
-5. Start Synapse: `sudo systemctl start matrix-synapse`
-6. Verify Synapse is running: `sudo systemctl status matrix-synapse`
-7. *optional* Enable Synapse to start at system boot: `sudo systemctl enable matrix-synapse`
diff --git a/debian/build_virtualenv b/debian/build_virtualenv
index cbdde93f96..cf19084a9f 100755
--- a/debian/build_virtualenv
+++ b/debian/build_virtualenv
@@ -33,11 +33,13 @@ esac
# Use --builtin-venv to use the better `venv` module from CPython 3.4+ rather
# than the 2/3 compatible `virtualenv`.
+# Pin pip to 20.3.4 to fix breakage in 21.0 on py3.5 (xenial)
+
dh_virtualenv \
--install-suffix "matrix-synapse" \
--builtin-venv \
--python "$SNAKE" \
- --upgrade-pip \
+ --upgrade-pip-to="20.3.4" \
--preinstall="lxml" \
--preinstall="mock" \
--extra-pip-arg="--no-cache-dir" \
diff --git a/debian/changelog b/debian/changelog
index 1c6308e3a2..1a421a85bd 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,8 +1,18 @@
-matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium
+matrix-synapse-py3 (1.26.0+nmu1) UNRELEASED; urgency=medium
+ * Fix build on Ubuntu 16.04 LTS (Xenial).
+
+ -- Dan Callahan <danc@element.io> Thu, 28 Jan 2021 16:21:03 +0000
+
+matrix-synapse-py3 (1.26.0) stable; urgency=medium
+
+ [ Richard van der Hoff ]
* Remove dependency on `python3-distutils`.
- -- Richard van der Hoff <richard@matrix.org> Fri, 15 Jan 2021 12:44:19 +0000
+ [ Synapse Packaging team ]
+ * New synapse release 1.26.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Wed, 27 Jan 2021 12:43:35 -0500
matrix-synapse-py3 (1.25.0) stable; urgency=medium
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index e529293803..0d74630370 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -27,6 +27,7 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
wget
# fetch and unpack the package
+# TODO: Upgrade to 1.2.2 once xenial is dropped
RUN mkdir /dh-virtualenv
RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 9e560003a9..f34cec1ff7 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -9,6 +9,7 @@
* [Response](#response)
* [Undoing room shutdowns](#undoing-room-shutdowns)
- [Make Room Admin API](#make-room-admin-api)
+- [Forward Extremities Admin API](#forward-extremities-admin-api)
# List Room API
@@ -511,3 +512,55 @@ optionally be specified, e.g.:
"user_id": "@foo:example.com"
}
```
+
+# Forward Extremities Admin API
+
+Enables querying and deleting forward extremities from rooms. When a lot of forward
+extremities accumulate in a room, performance can become degraded. For details, see
+[#1760](https://github.com/matrix-org/synapse/issues/1760).
+
+## Check for forward extremities
+
+To check the status of forward extremities for a room:
+
+```
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response as follows will be returned:
+
+```json
+{
+ "count": 1,
+ "results": [
+ {
+ "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh",
+ "state_group": 439,
+ "depth": 123,
+ "received_ts": 1611263016761
+ }
+ ]
+}
+```
+
+## Deleting forward extremities
+
+**WARNING**: Please ensure you know what you're doing and have read
+the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
+Under no circumstances should this API be executed as an automated maintenance task!
+
+If a room has lots of forward extremities, the excess can be
+deleted as follows:
+
+```
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response as follows will be returned, indicating the number of forward extremities
+that were deleted.
+
+```json
+{
+ "deleted": 1
+}
+```
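
Reviewer note: both endpoints documented above are plain admin API calls, so they are easy to exercise by hand. A sketch with curl, assuming the usual admin API authentication (the access token of a server admin); host, token and room ID are placeholders:

```sh
# Sketch: query, then delete, the forward extremities of a room. The room ID
# must be URL-encoded ("!" -> %21, ":" -> %3A).
TOKEN="<admin_access_token>"
ROOM="%21someroom%3Aexample.com"
curl -H "Authorization: Bearer $TOKEN" \
    "https://example.com/_synapse/admin/v1/rooms/$ROOM/forward_extremities"
curl -X DELETE -H "Authorization: Bearer $TOKEN" \
    "https://example.com/_synapse/admin/v1/rooms/$ROOM/forward_extremities"
```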
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index b3d413cf57..1eb674939e 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -760,3 +760,33 @@ The following fields are returned in the JSON response body:
- ``total`` - integer - Number of pushers.
See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
+
+Shadow-banning users
+====================
+
+Shadow-banning is a useful tool for moderating malicious or egregiously abusive users.
+A shadow-banned user receives successful responses to their client-server API requests,
+but the events are not propagated into rooms. This can be an effective tool as it
+(hopefully) takes longer for the user to realise they are being moderated before
+pivoting to another account.
+
+Shadow-banning a user should be used as a tool of last resort and may lead to confusing
+or broken behaviour for the client. A shadow-banned user will not receive any
+notification and it is generally more appropriate to ban or kick abusive users.
+A shadow-banned user will be unable to contact anyone on the server.
+
+The API is::
+
+ POST /_synapse/admin/v1/users/<user_id>/shadow_ban
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+ be local.
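
Reviewer note: as documented above, the endpoint takes no body and returns an empty JSON dict. A sketch of invoking it (host, token and user ID are placeholders):

```sh
# Sketch: shadow-ban a local user; the MXID is URL-encoded ("@" -> %40,
# ":" -> %3A).
curl -X POST -H "Authorization: Bearer <admin_access_token>" \
    "https://example.com/_synapse/admin/v1/users/%40user%3Aserver.com/shadow_ban"
# expected response: {}
```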
diff --git a/docs/openid.md b/docs/openid.md
index f01f46d326..4ba3559e38 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -44,7 +44,7 @@ as follows:
To enable the OpenID integration, you should then add a section to the `oidc_providers`
setting in your configuration file (or uncomment one of the existing examples).
-See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
+See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
the text below for example configurations for specific providers.
## Sample configs
@@ -52,11 +52,11 @@ the text below for example configurations for specific providers.
Here are a few configs for providers that should work with Synapse.
### Microsoft Azure Active Directory
-Azure AD can act as an OpenID Connect Provider. Register a new application under
+Azure AD can act as an OpenID Connect Provider. Register a new application under
*App registrations* in the Azure AD management console. The RedirectURI for your
application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback`
-Go to *Certificates & secrets* and register a new client secret. Make note of your
+Go to *Certificates & secrets* and register a new client secret. Make note of your
Directory (tenant) ID as it will be used in the Azure links.
Edit your Synapse config file and change the `oidc_config` section:
@@ -118,7 +118,7 @@ oidc_providers:
```
### [Keycloak][keycloak-idp]
-[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
+[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
@@ -194,7 +194,7 @@ Synapse config:
```yaml
oidc_providers:
- - idp_id: auth0
+ - idp_id: auth0
idp_name: Auth0
issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
client_id: "your-client-id" # TO BE FILLED
@@ -310,3 +310,46 @@ oidc_providers:
localpart_template: '{{ user.nickname }}'
display_name_template: '{{ user.name }}'
```
+
+### Facebook
+
+Like GitHub, Facebook provides a custom OAuth2 API rather than an OIDC-compliant
+one, so it requires a little more configuration.
+
+0. You will need a Facebook developer account. You can register for one
+ [here](https://developers.facebook.com/async/registration/).
+1. On the [apps](https://developers.facebook.com/apps/) page of the developer
+   console, click "Create App", and choose "Build Connected Experiences".
+2. Once the app is created, add "Facebook Login" and choose "Web". You don't
+ need to go through the whole form here.
+3. In the left-hand menu, open "Products"/"Facebook Login"/"Settings".
+ * Add `[synapse public baseurl]/_synapse/oidc/callback` as an OAuth Redirect
+ URL.
+4. In the left-hand menu, open "Settings/Basic". Here you can copy the "App ID"
+ and "App Secret" for use below.
+
+Synapse config:
+
+```yaml
+ - idp_id: facebook
+ idp_name: Facebook
+ idp_brand: "org.matrix.facebook" # optional: styling hint for clients
+ discover: false
+ issuer: "https://facebook.com"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ scopes: ["openid", "email"]
+ authorization_endpoint: https://facebook.com/dialog/oauth
+ token_endpoint: https://graph.facebook.com/v9.0/oauth/access_token
+ user_profile_method: "userinfo_endpoint"
+ userinfo_endpoint: "https://graph.facebook.com/v9.0/me?fields=id,name,email,picture"
+ user_mapping_provider:
+ config:
+ subject_claim: "id"
+ display_name_template: "{{ user.name }}"
+```
+
+Relevant documents:
+ * https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow
+ * Using Facebook's Graph API: https://developers.facebook.com/docs/graph-api/using-graph-api/
+ * Reference to the User endpoint: https://developers.facebook.com/docs/graph-api/reference/user
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 15e9746696..71468dd46b 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -399,6 +399,74 @@ limit_remote_rooms:
#
#allow_per_room_profiles: false
+# Whether to show the users on this homeserver in the user directory. Defaults to
+# 'true'.
+#
+#show_users_in_user_directory: false
+
+# Message retention policy at the server level.
+#
+# Room admins and mods can define a retention period for their rooms using the
+# 'm.room.retention' state event, and server admins can cap this period by setting
+# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+#
+# If this feature is enabled, Synapse will regularly look for and purge events
+# which are older than the room's maximum retention period. Synapse will also
+# filter events received over federation so that events that should have been
+# purged are ignored and not stored again.
+#
+retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+  # Server admins can define the settings of the background jobs purging the
+  # events whose lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # whose 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but that
+ # frequent purge shouldn't be performed by a job iterating over every room the
+ # server knows about, which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
@@ -817,6 +885,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+# - one that ratelimits third-party invite requests based on the account
+# that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -824,6 +894,9 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
+# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+# - two for ratelimiting how often invites can be sent in a room or to a
+# specific user.
#
# The defaults are as shown below.
#
@@ -846,6 +919,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# per_second: 0.17
# burst_count: 3
#
+#rc_third_party_invite:
+# per_second: 0.2
+# burst_count: 10
+#
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
@@ -857,7 +934,18 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# remote:
# per_second: 0.01
# burst_count: 3
-
+#
+#rc_3pid_validation:
+# per_second: 0.003
+# burst_count: 5
+#
+#rc_invites:
+# per_room:
+# per_second: 0.3
+# burst_count: 10
+# per_user:
+# per_second: 0.003
+# burst_count: 5
# Ratelimiting settings for incoming federation
#
@@ -920,6 +1008,30 @@ media_store_path: "DATADIR/media_store"
#
#max_upload_size: 50M
+# The largest allowed size for a user avatar. If not defined, no
+# restriction will be imposed.
+#
+# Note that this only applies when an avatar is changed globally.
+# Per-room avatar changes are not affected. See allow_per_room_profiles
+# for disabling that functionality.
+#
+# Note that user avatar changes will not work if this is set without
+# using Synapse's local media repo.
+#
+#max_avatar_size: 10M
+
+# Allowed mimetypes for a user avatar. If not defined, no restriction will
+# be imposed.
+#
+# Note that this only applies when an avatar is changed globally.
+# Per-room avatar changes are not affected. See allow_per_room_profiles
+# for disabling that functionality.
+#
+# Note that user avatar changes will not work if this is set without
+# using Synapse's local media repo.
+#
+#allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
@@ -1125,70 +1237,6 @@ url_preview_accept_language:
#
#enable_registration: false
-# Optional account validity configuration. This allows for accounts to be denied
-# any request after a given period.
-#
-# Once this feature is enabled, Synapse will look for registered users without an
-# expiration date at startup and will add one to every account it found using the
-# current settings at that time.
-# This means that, if a validity period is set, and Synapse is restarted (it will
-# then derive an expiration date from the current validity period), and some time
-# after that the validity period changes and Synapse is restarted, the users'
-# expiration dates won't be updated unless their account is manually renewed. This
-# date will be randomly selected within a range [now + period - d ; now + period],
-# where d is equal to 10% of the validity period.
-#
-account_validity:
- # The account validity feature is disabled by default. Uncomment the
- # following line to enable it.
- #
- #enabled: true
-
- # The period after which an account is valid after its registration. When
- # renewing the account, its validity period will be extended by this amount
- # of time. This parameter is required when using the account validity
- # feature.
- #
- #period: 6w
-
- # The amount of time before an account's expiry date at which Synapse will
- # send an email to the account's email address with a renewal link. By
- # default, no such emails are sent.
- #
- # If you enable this setting, you will also need to fill out the 'email'
- # configuration section. You should also check that 'public_baseurl' is set
- # correctly.
- #
- #renew_at: 1w
-
- # The subject of the email sent out with the renewal link. '%(app)s' can be
- # used as a placeholder for the 'app_name' parameter from the 'email'
- # section.
- #
- # Note that the placeholder must be written '%(app)s', including the
- # trailing 's'.
- #
- # If this is not set, a default value is used.
- #
- #renew_email_subject: "Renew your %(app)s account"
-
- # Directory in which Synapse will try to find templates for the HTML files to
- # serve to the user when trying to renew an account. If not set, default
- # templates from within the Synapse package will be used.
- #
- #template_dir: "res/templates"
-
- # File within 'template_dir' giving the HTML to be displayed to the user after
- # they successfully renewed their account. If not set, default text is used.
- #
- #account_renewed_html_path: "account_renewed.html"
-
- # File within 'template_dir' giving the HTML to be displayed when the user
- # tries to renew an account with an invalid renewal token. If not set,
- # default text is used.
- #
- #invalid_token_html_path: "invalid_token.html"
-
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.
@@ -1211,9 +1259,32 @@ account_validity:
#
#disable_msisdn_registration: true
+# Derive the user's matrix ID from a type of 3PID used when registering.
+# This overrides any matrix ID the user proposes when calling /register.
+# The 3PID type should be present in registrations_require_3pid to avoid
+# users failing to register if they don't specify the right kind of 3pid.
+#
+#register_mxid_from_3pid: email
+
+# Uncomment to set the display name of new users to their email address,
+# rather than using the default heuristic.
+#
+#register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+# Use an Identity Server to establish which 3PIDs are allowed to register.
+# If set, this overrides allowed_local_3pids below.
+#
+#check_is_for_allowed_local_3pids: matrix.org
+#
+# If you are using an IS, you can also check whether that IS registers
+# pending invites for the given 3PID (and then allow it to sign up on
+# the platform):
+#
+#allow_invited_3pids: false
+#
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\.org'
@@ -1222,6 +1293,11 @@ account_validity:
# - medium: msisdn
# pattern: '\+44'
+# If true, stop users from trying to change the 3PIDs associated with
+# their accounts.
+#
+#disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -1252,6 +1328,30 @@ account_validity:
#
#default_identity_server: https://matrix.org
+# If specified, user IDs, display names and avatar URLs will be replicated
+# to this server whenever they change.
+# This is an experimental API currently implemented by sydent to support
+# cross-homeserver user directories.
+#
+#replicate_user_profiles_to: example.com
+
+# If specified, attempt to replay registrations, profile changes & 3pid
+# bindings on the given target homeserver via the AS API. The HS is authed
+# via a given AS token.
+#
+#shadow_server:
+# hs_url: https://shadow.example.com
+# hs: shadow.example.com
+# as_token: 12u394refgbdhivsia
+
+# If enabled, don't let users set their own display names/avatars
+# other than for the very first time (unless they are a server admin).
+# Useful when provisioning users based on the contents of a 3rd party
+# directory and to avoid ambiguities.
+#
+#disable_set_displayname: false
+#disable_set_avatar_url: false
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -1377,6 +1477,97 @@ account_threepid_delegates:
#
#auto_join_rooms_for_guests: false
+# Rewrite identity server URLs with a map from one URL to another. Applies to URLs
+# provided by clients (which have https:// prepended) and those specified
+# in `account_threepid_delegates`. URLs should not feature a trailing slash.
+#
+#rewrite_identity_server_urls:
+# "https://somewhere.example.com": "https://somewhereelse.example.com"
+
+# When a user registers an account with an email address, it can be useful to
+# bind that email address to their mxid on an identity server. Typically, this
+# requires the user to validate their email address with the identity server.
+# However if Synapse itself is handling email validation on registration, the
+# user ends up needing to validate their email twice, which leads to poor UX.
+#
+# It is possible to force Sydent, one identity server implementation, to bind
+# threepids using its internal, unauthenticated bind API:
+# https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api
+#
+# Configure the address of a Sydent server here to have Synapse attempt
+# to automatically bind users' emails following registration. The
+# internal bind API must be reachable from Synapse, but should NOT be
+# exposed to any third party, as it allows the creation of bindings
+# without validation.
+#
+#bind_new_user_emails_to_sydent: https://example.com:8091
+
+
+## Account Validity ##
+#
+# Optional account validity configuration. This allows for accounts to be denied
+# any request after a given period.
+#
+# Once this feature is enabled, Synapse will look for registered users without an
+# expiration date at startup, and will add one to every account it finds, using
+# the current settings at that time.
+# This means that, if a validity period is set and Synapse is restarted, Synapse
+# derives an expiration date from the current validity period; if the validity
+# period is later changed and Synapse is restarted again, existing users'
+# expiration dates won't be updated unless their accounts are manually renewed.
+# The expiration date is randomly selected within the range
+# [now + period - d ; now + period], where d is equal to 10% of the validity period.
+#
+account_validity:
+ # The account validity feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # The period after which an account is valid after its registration. When
+ # renewing the account, its validity period will be extended by this amount
+ # of time. This parameter is required when using the account validity
+ # feature.
+ #
+ #period: 6w
+
+ # The amount of time before an account's expiry date at which Synapse will
+ # send an email to the account's email address with a renewal link. By
+ # default, no such emails are sent.
+ #
+ # If you enable this setting, you will also need to fill out the 'email' and
+ # 'public_baseurl' configuration sections.
+ #
+ #renew_at: 1w
+
+ # The subject of the email sent out with the renewal link. '%(app)s' can be
+ # used as a placeholder for the 'app_name' parameter from the 'email'
+ # section.
+ #
+ # Note that the placeholder must be written '%(app)s', including the
+ # trailing 's'.
+ #
+ # If this is not set, a default value is used.
+ #
+ #renew_email_subject: "Renew your %(app)s account"
+
+ # Directory in which Synapse will try to find templates for the HTML files to
+ # serve to the user when trying to renew an account. If not set, default
+ # templates from within the Synapse package will be used.
+ #
+ #template_dir: "res/templates"
+
+ # File within 'template_dir' giving the HTML to be displayed to the user after
+ # they successfully renewed their account. If not set, default text is used.
+ #
+ #account_renewed_html_path: "account_renewed.html"
+
+ # File within 'template_dir' giving the HTML to be displayed when the user
+ # tries to renew an account with an invalid renewal token. If not set,
+ # default text is used.
+ #
+ #invalid_token_html_path: "invalid_token.html"
+
## Metrics ###
@@ -1416,7 +1607,9 @@ metrics_flags:
## API Configuration ##
-# A list of event types that will be included in the room_invite_state
+# A list of event types from a room that will be given to users when they
+# are invited to a room. This allows clients to display information about the
+# room that they've been invited to, without actually being in the room yet.
#
#room_invite_state_types:
# - "m.room.join_rules"
@@ -1893,10 +2086,6 @@ cas_config:
#
#server_url: "https://cas-server.com"
- # The public URL of the homeserver.
- #
- #service_url: "https://homeserver.domain.com:8448"
-
# The attribute of the CAS response to use as the display name.
#
# If unset, no displayname will be set.
@@ -2528,9 +2717,19 @@ spam_checker:
# rebuild the user_directory search indexes, see
# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
#
+# 'prefer_local_users' defines whether to prioritise local users in
+# search query results. If True, local users are more likely to appear above
+# remote users when searching the user directory. Defaults to false.
+#
#user_directory:
# enabled: true
# search_all_users: false
+# prefer_local_users: false
+#
+# # If this is set, user search will be delegated to this ID server instead
+# # of synapse performing the search itself.
+# # This is an experimental API.
+# defer_to_id_server: https://id.example.com
# User Consent configuration
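To make the purge-job ranges described in the retention section above concrete, here is a hedged sketch (not Synapse's implementation) of the predicate a job applies to decide whether it covers a room; all lifetimes are in milliseconds:

```python
from typing import Optional

DAY_MS = 24 * 60 * 60 * 1000

def job_handles_room(
    max_lifetime: int,
    shortest_max_lifetime: Optional[int] = None,
    longest_max_lifetime: Optional[int] = None,
) -> bool:
    """A job covers rooms whose 'max_lifetime' is strictly above its
    'shortest_max_lifetime' and at most its 'longest_max_lifetime';
    either bound may be omitted."""
    if shortest_max_lifetime is not None and max_lifetime <= shortest_max_lifetime:
        return False
    if longest_max_lifetime is not None and max_lifetime > longest_max_lifetime:
        return False
    return True

# The job covering (1d, 3d] handles a room with a 3-day policy...
assert job_handles_room(3 * DAY_MS, 1 * DAY_MS, 3 * DAY_MS)
# ...but not one with a 5-day policy.
assert not job_handles_room(5 * DAY_MS, 1 * DAY_MS, 3 * DAY_MS)
```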
diff --git a/docs/workers.md b/docs/workers.md
index 0da805c333..c36549c621 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -40,6 +40,9 @@ which relays replication commands between processes. This can give a significant
cpu saving on the main process and will be a prerequisite for upcoming
performance improvements.
+If Redis support is enabled, Synapse will use it as a shared cache, as well as a
+pub/sub mechanism.
+
See the [Architectural diagram](#architectural-diagram) section at the end for
a visualisation of what this looks like.
diff --git a/mypy.ini b/mypy.ini
index bd99069c81..68a4533973 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -23,39 +23,7 @@ files =
synapse/events/validator.py,
synapse/events/spamcheck.py,
synapse/federation,
- synapse/handlers/_base.py,
- synapse/handlers/account_data.py,
- synapse/handlers/account_validity.py,
- synapse/handlers/admin.py,
- synapse/handlers/appservice.py,
- synapse/handlers/auth.py,
- synapse/handlers/cas_handler.py,
- synapse/handlers/deactivate_account.py,
- synapse/handlers/device.py,
- synapse/handlers/devicemessage.py,
- synapse/handlers/directory.py,
- synapse/handlers/events.py,
- synapse/handlers/federation.py,
- synapse/handlers/identity.py,
- synapse/handlers/initial_sync.py,
- synapse/handlers/message.py,
- synapse/handlers/oidc_handler.py,
- synapse/handlers/pagination.py,
- synapse/handlers/password_policy.py,
- synapse/handlers/presence.py,
- synapse/handlers/profile.py,
- synapse/handlers/read_marker.py,
- synapse/handlers/receipts.py,
- synapse/handlers/register.py,
- synapse/handlers/room.py,
- synapse/handlers/room_list.py,
- synapse/handlers/room_member.py,
- synapse/handlers/room_member_worker.py,
- synapse/handlers/saml_handler.py,
- synapse/handlers/sso.py,
- synapse/handlers/sync.py,
- synapse/handlers/user_directory.py,
- synapse/handlers/ui_auth,
+ synapse/handlers,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
@@ -194,3 +162,9 @@ ignore_missing_imports = True
[mypy-hiredis]
ignore_missing_imports = True
+
+[mypy-josepy.*]
+ignore_missing_imports = True
+
+[mypy-txacme.*]
+ignore_missing_imports = True
diff --git a/res/templates-dinsic/mail-Vector.css b/res/templates-dinsic/mail-Vector.css
new file mode 100644
index 0000000000..6a3e36eda1
--- /dev/null
+++ b/res/templates-dinsic/mail-Vector.css
@@ -0,0 +1,7 @@
+.header {
+ border-bottom: 4px solid #e4f7ed ! important;
+}
+
+.notif_link a, .footer a {
+ color: #76CFA6 ! important;
+}
diff --git a/res/templates-dinsic/mail.css b/res/templates-dinsic/mail.css
new file mode 100644
index 0000000000..5ab3e1b06d
--- /dev/null
+++ b/res/templates-dinsic/mail.css
@@ -0,0 +1,156 @@
+body {
+ margin: 0px;
+}
+
+pre, code {
+ word-break: break-word;
+ white-space: pre-wrap;
+}
+
+#page {
+ font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
+ color: #454545;
+ font-size: 12pt;
+ width: 100%;
+ padding: 20px;
+}
+
+#inner {
+ width: 640px;
+}
+
+.header {
+ width: 100%;
+ height: 87px;
+ color: #454545;
+ border-bottom: 4px solid #e5e5e5;
+}
+
+.logo {
+ text-align: right;
+ margin-left: 20px;
+}
+
+.salutation {
+ padding-top: 10px;
+ font-weight: bold;
+}
+
+.summarytext {
+}
+
+.room {
+ width: 100%;
+ color: #454545;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.room_header td {
+ padding-top: 38px;
+ padding-bottom: 10px;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.room_name {
+ vertical-align: middle;
+ font-size: 18px;
+ font-weight: bold;
+}
+
+.room_header h2 {
+ margin-top: 0px;
+ margin-left: 75px;
+ font-size: 20px;
+}
+
+.room_avatar {
+ width: 56px;
+ line-height: 0px;
+ text-align: center;
+ vertical-align: middle;
+}
+
+.room_avatar img {
+ width: 48px;
+ height: 48px;
+ object-fit: cover;
+ border-radius: 24px;
+}
+
+.notif {
+ border-bottom: 1px solid #e5e5e5;
+ margin-top: 16px;
+ padding-bottom: 16px;
+}
+
+.historical_message .sender_avatar {
+ opacity: 0.3;
+}
+
+/* spell out opacity and historical_message class names for Outlook aka Word */
+.historical_message .sender_name {
+ color: #e3e3e3;
+}
+
+.historical_message .message_time {
+ color: #e3e3e3;
+}
+
+.historical_message .message_body {
+ color: #c7c7c7;
+}
+
+.historical_message td,
+.message td {
+ padding-top: 10px;
+}
+
+.sender_avatar {
+ width: 56px;
+ text-align: center;
+ vertical-align: top;
+}
+
+.sender_avatar img {
+ margin-top: -2px;
+ width: 32px;
+ height: 32px;
+ border-radius: 16px;
+}
+
+.sender_name {
+ display: inline;
+ font-size: 13px;
+ color: #a2a2a2;
+}
+
+.message_time {
+ text-align: right;
+ width: 100px;
+ font-size: 11px;
+ color: #a2a2a2;
+}
+
+.message_body {
+}
+
+.notif_link td {
+ padding-top: 10px;
+ padding-bottom: 10px;
+ font-weight: bold;
+}
+
+.notif_link a, .footer a {
+ color: #454545;
+ text-decoration: none;
+}
+
+.debug {
+ font-size: 10px;
+ color: #888;
+}
+
+.footer {
+ margin-top: 20px;
+ text-align: center;
+}
\ No newline at end of file
diff --git a/res/templates-dinsic/notif.html b/res/templates-dinsic/notif.html
new file mode 100644
index 0000000000..bcdfeea9da
--- /dev/null
+++ b/res/templates-dinsic/notif.html
@@ -0,0 +1,45 @@
+{% for message in notif.messages %}
+ <tr class="{{ "historical_message" if message.is_historical else "message" }}">
+ <td class="sender_avatar">
+ {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ {% if message.sender_avatar_url %}
+ <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
+ {% else %}
+ {% if message.sender_hash % 3 == 0 %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
+ {% elif message.sender_hash % 3 == 1 %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
+ {% else %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
+ {% endif %}
+ {% endif %}
+ {% endif %}
+ </td>
+ <td class="message_contents">
+ {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
+ {% endif %}
+ <div class="message_body">
+ {% if message.msgtype == "m.text" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.emote" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.notice" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.image" %}
+ <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
+ {% elif message.msgtype == "m.file" %}
+ <span class="filename">{{ message.body_text_plain }}</span>
+ {% endif %}
+ </div>
+ </td>
+ <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
+ </tr>
+{% endfor %}
+<tr class="notif_link">
+ <td></td>
+ <td>
+ <a href="{{ notif.link }}">Voir {{ room.title }}</a>
+ </td>
+ <td></td>
+</tr>
diff --git a/res/templates-dinsic/notif.txt b/res/templates-dinsic/notif.txt
new file mode 100644
index 0000000000..3dff1bb570
--- /dev/null
+++ b/res/templates-dinsic/notif.txt
@@ -0,0 +1,16 @@
+{% for message in notif.messages %}
+{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
+{% if message.msgtype == "m.text" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.emote" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.notice" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.image" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.file" %}
+{{ message.body_text_plain }}
+{% endif %}
+{% endfor %}
+
+Voir {{ room.title }} à {{ notif.link }}
diff --git a/res/templates-dinsic/notif_mail.html b/res/templates-dinsic/notif_mail.html
new file mode 100644
index 0000000000..1e1efa74b2
--- /dev/null
+++ b/res/templates-dinsic/notif_mail.html
@@ -0,0 +1,55 @@
+<!doctype html>
+<html lang="en">
+ <head>
+ <style type="text/css">
+ {% include 'mail.css' without context %}
+ {% include "mail-%s.css" % app_name ignore missing without context %}
+ </style>
+ </head>
+ <body>
+ <table id="page">
+ <tr>
+ <td> </td>
+ <td id="inner">
+ <table class="header">
+ <tr>
+ <td>
+ <div class="salutation">Bonjour {{ user_display_name }},</div>
+ <div class="summarytext">{{ summary_text }}</div>
+ </td>
+ <td class="logo">
+ {% if app_name == "Riot" %}
+ <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
+ {% elif app_name == "Vector" %}
+ <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
+ {% else %}
+ <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
+ {% endif %}
+ </td>
+ </tr>
+ </table>
+ {% for room in rooms %}
+ {% include 'room.html' with context %}
+ {% endfor %}
+ <div class="footer">
+ <a href="{{ unsubscribe_link }}">Se désinscrire</a>
+ <br/>
+ <br/>
+ <div class="debug">
+ Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
+ an event was received at {{ reason.received_at|format_ts("%c") }}
+ which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
+ {% if reason.last_sent_ts %}
+ and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
+ which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
+ {% else %}
+ and we don't have a last time we sent a mail for this room.
+ {% endif %}
+ </div>
+ </div>
+ </td>
+ <td> </td>
+ </tr>
+ </table>
+ </body>
+</html>
diff --git a/res/templates-dinsic/notif_mail.txt b/res/templates-dinsic/notif_mail.txt
new file mode 100644
index 0000000000..fae877426f
--- /dev/null
+++ b/res/templates-dinsic/notif_mail.txt
@@ -0,0 +1,10 @@
+Bonjour {{ user_display_name }},
+
+{{ summary_text }}
+
+{% for room in rooms %}
+{% include 'room.txt' with context %}
+{% endfor %}
+
+Vous pouvez désactiver ces notifications en cliquant ici {{ unsubscribe_link }}
+
diff --git a/res/templates-dinsic/room.html b/res/templates-dinsic/room.html
new file mode 100644
index 0000000000..0487b1b11c
--- /dev/null
+++ b/res/templates-dinsic/room.html
@@ -0,0 +1,33 @@
+<table class="room">
+ <tr class="room_header">
+ <td class="room_avatar">
+ {% if room.avatar_url %}
+ <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
+ {% else %}
+ {% if room.hash % 3 == 0 %}
+ <img alt="" src="https://vector.im/beta/img/76cfa6.png" />
+ {% elif room.hash % 3 == 1 %}
+ <img alt="" src="https://vector.im/beta/img/50e2c2.png" />
+ {% else %}
+ <img alt="" src="https://vector.im/beta/img/f4c371.png" />
+ {% endif %}
+ {% endif %}
+ </td>
+ <td class="room_name" colspan="2">
+ {{ room.title }}
+ </td>
+ </tr>
+ {% if room.invite %}
+ <tr>
+ <td></td>
+ <td>
+ <a href="{{ room.link }}">Rejoindre la conversation.</a>
+ </td>
+ <td></td>
+ </tr>
+ {% else %}
+ {% for notif in room.notifs %}
+ {% include 'notif.html' with context %}
+ {% endfor %}
+ {% endif %}
+</table>
diff --git a/res/templates-dinsic/room.txt b/res/templates-dinsic/room.txt
new file mode 100644
index 0000000000..dd36d01d21
--- /dev/null
+++ b/res/templates-dinsic/room.txt
@@ -0,0 +1,9 @@
+{{ room.title }}
+
+{% if room.invite %}
+ Vous avez été invité, rejoignez la conversation en cliquant sur le lien suivant {{ room.link }}
+{% else %}
+ {% for notif in room.notifs %}
+ {% include 'notif.txt' with context %}
+ {% endfor %}
+{% endif %}
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index 448cadb829..d742c522b5 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -7,9 +7,9 @@ echo -e "+++ \033[32mChecking newsfragment\033[m"
set -e
-# make sure that origin/develop is up to date
-git remote set-branches --add origin develop
-git fetch -q origin develop
+# make sure that origin/dinsic is up to date
+git remote set-branches --add origin dinsic
+git fetch -q origin dinsic
pr="$BUILDKITE_PULL_REQUEST"
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index f328ab57d5..fe2965cd36 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -80,7 +80,8 @@ else
# then lint everything!
if [[ -z ${files+x} ]]; then
# Lint all source code files and directories
- files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
+ # Note: this list aims to mirror the one in tox.ini
+ files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
fi
fi
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 69bf9110a6..00d638eb9a 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -47,6 +47,7 @@ from synapse.storage.databases.main.events_bg_updates import (
from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
+from synapse.storage.databases.main.profile import ProfileStore
from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
@@ -168,6 +169,7 @@ class Store(
DeviceBackgroundUpdateStore,
EventsBackgroundUpdatesStore,
MediaRepositoryBackgroundUpdateStore,
+ ProfileStore,
RegistrationBackgroundUpdateStore,
RoomBackgroundUpdateStore,
RoomMemberBackgroundUpdateStore,
diff --git a/setup.py b/setup.py
index ddbe9f511a..99425d52de 100755
--- a/setup.py
+++ b/setup.py
@@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
#
# We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [
- "isort==5.0.3",
+ "isort==5.7.0",
"black==19.10b0",
"flake8-comprehensions",
"flake8",
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index bfac6840e6..618548a305 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -15,13 +15,23 @@
"""Contains *incomplete* type hints for txredisapi.
"""
-
-from typing import List, Optional, Type, Union
+from typing import Any, List, Optional, Type, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
+ async def ping(self) -> None: ...
+ async def set(
+ self,
+ key: str,
+ value: Any,
+ expire: Optional[int] = None,
+ pexpire: Optional[int] = None,
+ only_if_not_exists: bool = False,
+ only_if_exists: bool = False,
+ ) -> None: ...
+ async def get(self, key: str) -> Any: ...
-class SubscriberProtocol:
+class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
@@ -40,14 +50,13 @@ def lazyConnection(
convertNumbers: bool = ...,
) -> RedisProtocol: ...
-class SubscriberFactory:
- def buildProtocol(self, addr): ...
-
class ConnectionHandler: ...
class RedisFactory:
continueTrying: bool
handler: RedisProtocol
+ pool: List[RedisProtocol]
+ replyTimeout: Optional[int]
def __init__(
self,
uuid: str,
@@ -60,3 +69,7 @@ class RedisFactory:
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
): ...
+ def buildProtocol(self, addr) -> RedisProtocol: ...
+
+class SubscriberFactory(RedisFactory):
+ def __init__(self): ...
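The additions to the stub above cover the calls Synapse makes when using Redis as a shared cache. A hedged usage sketch (assumes a local Redis and a running Twisted reactor; connection details are illustrative, not values from this patch):

```python
from txredisapi import lazyConnection

async def cache_roundtrip() -> bytes:
    # host/port are illustrative defaults.
    redis = lazyConnection(host="localhost", port=6379)
    await redis.ping()                                       # liveness check
    await redis.set("synapse:example", b"value", expire=60)  # cache write, 60s TTL
    return await redis.get("synapse:example")                # cache read
```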
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 67ecbd32ff..b575d85976 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -79,7 +79,7 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
- self._account_validity = hs.config.account_validity
+ self._account_validity_enabled = hs.config.account_validity_enabled
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
@@ -192,7 +192,7 @@ class Auth:
access_token = self.get_access_token_from_request(request)
- user_id, app_service = await self._get_appservice_user_id(request)
+ user_id, app_service = self._get_appservice_user_id(request)
if user_id:
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
@@ -222,7 +222,7 @@ class Auth:
shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
- if self._account_validity.enabled and not allow_expired:
+ if self._account_validity_enabled and not allow_expired:
if await self.store.is_account_expired(
user_info.user_id, self.clock.time_msec()
):
@@ -268,10 +268,11 @@ class Auth:
except KeyError:
raise MissingClientTokenError()
- async def _get_appservice_user_id(self, request):
+ def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
+
if app_service is None:
return None, None
@@ -289,8 +290,12 @@ class Auth:
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (await self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
+ # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+ # if not (await self.store.get_user_by_id(user_id)):
+ # raise AuthError(
+ # 403,
+ # "Application service has not registered this user"
+ # )
return user_id, app_service
async def get_user_by_access_token(
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cd6670d0a2..90bb01f414 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index de2cc15d33..139fbf5524 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -57,7 +57,7 @@ class RoomVersion:
state_res = attr.ib(type=int) # one of the StateResolutionVersions
enforce_key_validity = attr.ib(type=bool)
- # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
+ # Before MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool)
# Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@@ -69,6 +69,9 @@
limit_notifications_power_levels = attr.ib(type=bool)
# MSC2174/MSC2176: Apply updated redaction rules algorithm.
msc2176_redaction_rules = attr.ib(type=bool)
+ # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
+ # m.room.membership event with membership 'knock'.
+ allow_knocking = attr.ib(type=bool)
class RoomVersions:
@@ -82,6 +87,7 @@ class RoomVersions:
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
+ allow_knocking=False,
)
V2 = RoomVersion(
"2",
@@ -93,6 +99,7 @@ class RoomVersions:
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
+ allow_knocking=False,
)
V3 = RoomVersion(
"3",
@@ -104,6 +111,7 @@ class RoomVersions:
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
+ allow_knocking=False,
)
V4 = RoomVersion(
"4",
@@ -115,6 +123,7 @@ class RoomVersions:
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
+ allow_knocking=False,
)
V5 = RoomVersion(
"5",
@@ -126,6 +135,7 @@ class RoomVersions:
strict_canonicaljson=False,
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
+ allow_knocking=False,
)
V6 = RoomVersion(
"6",
@@ -137,6 +147,19 @@ class RoomVersions:
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
+ allow_knocking=False,
+ )
+ V7 = RoomVersion(
+ "7",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ allow_knocking=True,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -148,6 +171,7 @@ class RoomVersions:
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=True,
+ allow_knocking=False,
)
@@ -160,6 +184,7 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4,
RoomVersions.V5,
RoomVersions.V6,
+ RoomVersions.V7,
RoomVersions.MSC2176,
)
} # type: Dict[str, RoomVersion]
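Since `allow_knocking` is now part of every `RoomVersion`, code elsewhere can gate knock handling on the negotiated room version. A hypothetical helper, not part of this patch, showing the intended lookup:

```python
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

def supports_knocking(room_version_id: str) -> bool:
    """True when the room version is known and allows m.room.member
    events with membership 'knock' (MSC2403)."""
    version = KNOWN_ROOM_VERSIONS.get(room_version_id)
    return version is not None and version.allow_knocking
```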
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 395e202b89..9840a9d55b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -16,6 +16,7 @@
import gc
import logging
import os
+import platform
import signal
import socket
import sys
@@ -339,7 +340,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
# rest of time. Doing so means less work each GC (hopefully).
#
# This only works on Python 3.7
- if sys.version_info >= (3, 7):
+ if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
gc.collect()
gc.freeze()
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 516f2464b4..e363d681fd 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -165,7 +165,7 @@ class PresenceStatusStubServlet(RestServlet):
async def on_GET(self, request, user_id):
await self.auth.get_user_by_req(request)
- return 200, {"presence": "offline"}
+ return 200, {"presence": "offline", "user_id": user_id}
async def on_PUT(self, request, user_id):
await self.auth.get_user_by_req(request)
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index c38cf8231f..8f86cecb76 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
+ stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
+ stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
+ daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
+ stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
r30_results = await hs.get_datastore().count_r30_users()
for name, count in r30_results.items():
stats["r30_users_" + name] = count
- daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e366a982b8..5dcd9ea290 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events import EventBase
from synapse.events.utils import serialize_event
@@ -249,9 +249,14 @@ class ApplicationServiceApi(SimpleHttpClient):
e,
time_now,
as_client_event=True,
- is_invite=(
+ # If this is an invite or a knock membership event, and we're interested
+ # in this user, then include any stripped state alongside the event.
+ include_stripped_room_state=(
e.type == EventTypes.Member
- and e.membership == "invite"
+ and (
+ e.membership == Membership.INVITE
+ or e.membership == Membership.KNOCK
+ )
and service.is_interested_in_user(e.state_key)
),
)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 35e5594b73..c629990dd4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,6 +20,7 @@ import errno
import os
from collections import OrderedDict
from hashlib import sha256
+from io import open as io_open
from textwrap import dedent
from typing import Any, Iterable, List, MutableMapping, Optional
@@ -200,14 +201,31 @@ class Config:
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
+ with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
+ def read_template(self, filename: str) -> jinja2.Template:
+ """Load a template file from disk.
+
+ This function will attempt to load the given template from the default Synapse
+ template directory.
+
+ Files read are treated as Jinja templates. The template is not rendered yet
+ and has autoescape enabled.
+
+ Args:
+ filename: A template filename to read.
+
+ Raises:
+ ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+ Returns:
+ A jinja2 template.
+ """
+ return self.read_templates([filename])[0]
+
def read_templates(
- self,
- filenames: List[str],
- custom_template_directory: Optional[str] = None,
- autoescape: bool = False,
+ self, filenames: List[str], custom_template_directory: Optional[str] = None,
) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
@@ -215,7 +233,8 @@ class Config:
template directory. If `custom_template_directory` is supplied, that directory
is tried first.
- Files read are treated as Jinja templates. These templates are not rendered yet.
+ Files read are treated as Jinja templates. The templates are not rendered yet
+ and have autoescape enabled.
Args:
filenames: A list of template filenames to read.
@@ -223,16 +242,12 @@ class Config:
custom_template_directory: A directory to try to look for the templates
before using the default Synapse template directory instead.
- autoescape: Whether to autoescape variables before inserting them into the
- template.
-
Raises:
ConfigError: if the file's path is incorrect or otherwise cannot be read.
Returns:
A list of jinja2 templates.
"""
- templates = []
search_directories = [self.default_template_dir]
# The loader will first look in the custom template directory (if specified) for the
@@ -250,7 +265,7 @@ class Config:
# TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(search_directories)
- env = jinja2.Environment(loader=loader, autoescape=autoescape)
+ env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
# Update the environment with our custom filters
env.filters.update(
@@ -260,12 +275,8 @@ class Config:
}
)
- for filename in filenames:
- # Load the template
- template = env.get_template(filename)
- templates.append(template)
-
- return templates
+ # Load the templates
+ return [env.get_template(filename) for filename in filenames]
class RootConfig:
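The switch from a per-call `autoescape` flag to `jinja2.select_autoescape()` means escaping is now decided by template file extension. A standalone illustration of that behaviour (plain Jinja2, not Synapse code):

```python
import jinja2

env = jinja2.Environment(
    loader=jinja2.DictLoader(
        {
            "notice.html": "Hello {{ name }}",  # .html: autoescaped
            "notice.txt": "Hello {{ name }}",   # .txt: left alone
        }
    ),
    autoescape=jinja2.select_autoescape(),
)

assert env.get_template("notice.html").render(name="<b>Bob</b>") == "Hello &lt;b&gt;Bob&lt;/b&gt;"
assert env.get_template("notice.txt").render(name="<b>Bob</b>") == "Hello <b>Bob</b>"
```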
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 3ccea4b02d..0565418e60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,6 +1,7 @@
from typing import Any, Iterable, List, Optional
from synapse.config import (
+ account_validity,
api,
appservice,
auth,
@@ -19,6 +20,7 @@ from synapse.config import (
password_auth_providers,
push,
ratelimiting,
+ redis,
registration,
repository,
room_directory,
@@ -53,11 +55,12 @@ class RootConfig:
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
- ratelimit: ratelimiting.RatelimitConfig
+ ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
registration: registration.RegistrationConfig
+ account_validity: account_validity.AccountValidityConfig
metrics: metrics.MetricsConfig
api: api.ApiConfig
appservice: appservice.AppServiceConfig
@@ -81,6 +84,7 @@ class RootConfig:
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
+ redis: redis.RedisConfig
config_classes: List = ...
def __init__(self) -> None: ...
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
new file mode 100644
index 0000000000..6d107944a3
--- /dev/null
+++ b/synapse/config/account_validity.py
@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.config._base import Config, ConfigError
+
+
+class AccountValidityConfig(Config):
+ section = "account_validity"
+
+ def read_config(self, config, **kwargs):
+ account_validity_config = config.get("account_validity") or {}
+ self.account_validity_enabled = account_validity_config.get("enabled", False)
+ self.account_validity_renew_by_email_enabled = (
+ "renew_at" in account_validity_config
+ )
+
+ if self.account_validity_enabled:
+ if "period" in account_validity_config:
+ self.account_validity_period = self.parse_duration(
+ account_validity_config["period"]
+ )
+ else:
+ raise ConfigError("'period' is required when using account validity")
+
+ if "renew_at" in account_validity_config:
+ self.account_validity_renew_at = self.parse_duration(
+ account_validity_config["renew_at"]
+ )
+
+ if "renew_email_subject" in account_validity_config:
+ self.account_validity_renew_email_subject = account_validity_config[
+ "renew_email_subject"
+ ]
+ else:
+ self.account_validity_renew_email_subject = "Renew your %(app)s account"
+
+ self.account_validity_startup_job_max_delta = (
+ self.account_validity_period * 10.0 / 100.0
+ )
+
+ if self.account_validity_renew_by_email_enabled:
+ if not self.public_baseurl:
+ raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+
+ # Load account validity templates.
+ account_renewed_template_filename = account_validity_config.get(
+ "account_renewed_html_path", "account_renewed.html"
+ )
+ account_previously_renewed_template_filename = account_validity_config.get(
+ "account_previously_renewed_html_path", "account_previously_renewed.html"
+ )
+ invalid_token_template_filename = account_validity_config.get(
+ "invalid_token_html_path", "invalid_token.html"
+ )
+ custom_template_directory = account_validity_config.get("template_dir")
+
+ (
+ self.account_validity_account_renewed_template,
+ self.account_validity_account_previously_renewed_template,
+ self.account_validity_invalid_token_template,
+ ) = self.read_templates(
+ [
+ account_renewed_template_filename,
+ account_previously_renewed_template_filename,
+ invalid_token_template_filename,
+ ],
+ custom_template_directory=custom_template_directory,
+ )
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ ## Account Validity ##
+ #
+ # Optional account validity configuration. This allows for accounts to be denied
+ # any request after a given period.
+ #
+ # Once this feature is enabled, Synapse will look for registered users without an
+ # expiration date at startup, and will add one to every account it finds, using
+ # the current settings at that time.
+ # This means that, if a validity period is set and Synapse is restarted, Synapse
+ # derives an expiration date from the current validity period; if the validity
+ # period is later changed and Synapse is restarted again, existing users'
+ # expiration dates won't be updated unless their accounts are manually renewed.
+ # The expiration date is randomly selected within the range
+ # [now + period - d ; now + period], where d is equal to 10% of the validity period.
+ #
+ account_validity:
+ # The account validity feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # The period after which an account is valid after its registration. When
+ # renewing the account, its validity period will be extended by this amount
+ # of time. This parameter is required when using the account validity
+ # feature.
+ #
+ #period: 6w
+
+ # The amount of time before an account's expiry date at which Synapse will
+ # send an email to the account's email address with a renewal link. By
+ # default, no such emails are sent.
+ #
+ # If you enable this setting, you will also need to fill out the 'email' and
+ # 'public_baseurl' configuration sections.
+ #
+ #renew_at: 1w
+
+ # The subject of the email sent out with the renewal link. '%(app)s' can be
+ # used as a placeholder for the 'app_name' parameter from the 'email'
+ # section.
+ #
+ # Note that the placeholder must be written '%(app)s', including the
+ # trailing 's'.
+ #
+ # If this is not set, a default value is used.
+ #
+ #renew_email_subject: "Renew your %(app)s account"
+
+ # Directory in which Synapse will try to find templates for the HTML files to
+ # serve to the user when trying to renew an account. If not set, default
+ # templates from within the Synapse package will be used.
+ #
+ #template_dir: "res/templates"
+
+ # File within 'template_dir' giving the HTML to be displayed to the user after
+ # they successfully renewed their account. If not set, default text is used.
+ #
+ #account_renewed_html_path: "account_renewed.html"
+
+ # File within 'template_dir' giving the HTML to be displayed when the user
+ # tries to renew an account with an invalid renewal token. If not set,
+ # default text is used.
+ #
+ #invalid_token_html_path: "invalid_token.html"
+ """
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 74cd53a8ed..0638ed8d2e 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,27 +17,31 @@ from synapse.api.constants import EventTypes
from ._base import Config
+# The default types of room state to send to users who are invited to or knock on a room.
+DEFAULT_ROOM_STATE_TYPES = [
+ EventTypes.JoinRules,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomAvatar,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+]
+
class ApiConfig(Config):
section = "api"
def read_config(self, config, **kwargs):
self.room_invite_state_types = config.get(
- "room_invite_state_types",
- [
- EventTypes.JoinRules,
- EventTypes.CanonicalAlias,
- EventTypes.RoomAvatar,
- EventTypes.RoomEncryption,
- EventTypes.Name,
- ],
+ "room_invite_state_types", DEFAULT_ROOM_STATE_TYPES
)
def generate_config_section(cls, **kwargs):
return """\
## API Configuration ##
- # A list of event types that will be included in the room_invite_state
+ # A list of event types from a room that will be given to users when they
+ # are invited to a room. This allows clients to display information about the
+ # room that they've been invited to, without actually being in the room yet.
#
#room_invite_state_types:
# - "{JoinRules}"
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index cb00958165..9e48f865cc 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -28,9 +28,7 @@ class CaptchaConfig(Config):
"recaptcha_siteverify_api",
"https://www.recaptcha.net/recaptcha/api/siteverify",
)
- self.recaptcha_template = self.read_templates(
- ["recaptcha.html"], autoescape=True
- )[0]
+ self.recaptcha_template = self.read_template("recaptcha.html")
def generate_config_section(self, **kwargs):
return """\
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index c7877b4095..b226890c2a 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -30,7 +30,13 @@ class CasConfig(Config):
if self.cas_enabled:
self.cas_server_url = cas_config["server_url"]
- self.cas_service_url = cas_config["service_url"]
+ public_base_url = cas_config.get("service_url") or self.public_baseurl
+ if public_base_url[-1] != "/":
+ public_base_url += "/"
+ # TODO Update this to a _synapse URL.
+ self.cas_service_url = (
+ public_base_url + "_matrix/client/r0/login/cas/ticket"
+ )
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes") or {}
else:
@@ -53,10 +59,6 @@ class CasConfig(Config):
#
#server_url: "https://cas-server.com"
- # The public URL of the homeserver.
- #
- #service_url: "https://homeserver.domain.com:8448"
-
# The attribute of the CAS response to use as the display name.
#
# If unset, no displayname will be set.
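The hunk above stops requiring an explicit `service_url` and derives the CAS callback from `public_baseurl` instead. A sketch of the derivation it performs:

```python
from typing import Optional

def cas_service_url(public_baseurl: str, service_url: Optional[str] = None) -> str:
    # service_url, if still configured, takes precedence for backwards
    # compatibility; otherwise fall back to public_baseurl.
    base = service_url or public_baseurl
    if not base.endswith("/"):
        base += "/"
    return base + "_matrix/client/r0/login/cas/ticket"

assert (
    cas_service_url("https://hs.example.com")
    == "https://hs.example.com/_matrix/client/r0/login/cas/ticket"
)
```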
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 6efa59b110..c47f364b14 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -89,7 +89,7 @@ class ConsentConfig(Config):
def read_config(self, config, **kwargs):
consent_config = config.get("user_consent")
- self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0]
+ self.terms_template = self.read_template("terms.html")
if consent_config is None:
return
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 6a487afd34..458f5eb0da 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -291,7 +291,7 @@ class EmailConfig(Config):
"client_base_url", email_config.get("riot_base_url", None)
)
- if self.account_validity.renew_by_email_enabled:
+ if self.account_validity_renew_by_email_enabled:
expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html"
)
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b1c1c51e4d..ba9d37553b 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.config._base import Config
from synapse.types import JsonDict
@@ -25,5 +26,11 @@ class ExperimentalConfig(Config):
def read_config(self, config: JsonDict, **kwargs):
experimental = config.get("experimental_features") or {}
+ # MSC2403 (room knocking)
+ self.msc2403_enabled = experimental.get("msc2403_enabled", False) # type: bool
+ if self.msc2403_enabled:
+ # Enable the MSC2403 unstable room version
+ KNOWN_ROOM_VERSIONS.update({RoomVersions.V7.identifier: RoomVersions.V7})
+
# MSC2858 (multiple SSO identity providers)
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 64a2429f77..58961679ff 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -13,8 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from ._base import RootConfig
+from .account_validity import AccountValidityConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .auth import AuthConfig
@@ -69,6 +69,7 @@ class HomeServerConfig(RootConfig):
CaptchaConfig,
VoipConfig,
RegistrationConfig,
+ AccountValidityConfig,
MetricsConfig,
ApiConfig,
AppServiceConfig,
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 784b416f95..bb122ef182 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -53,8 +53,7 @@ class OIDCConfig(Config):
"Multiple OIDC providers have the idp_id %r." % idp_id
)
- public_baseurl = self.public_baseurl
- self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
+ self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback"
@property
def oidc_enabled(self) -> bool:
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 14b8836197..070eb1b761 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -24,7 +24,7 @@ class RateLimitConfig:
defaults={"per_second": 0.17, "burst_count": 3.0},
):
self.per_second = config.get("per_second", defaults["per_second"])
- self.burst_count = config.get("burst_count", defaults["burst_count"])
+ self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
class FederationRateLimitConfig:
@@ -76,6 +76,9 @@ class RatelimitConfig(Config):
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {})
+ )
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -102,6 +105,20 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 3},
)
+ self.rc_3pid_validation = RateLimitConfig(
+ config.get("rc_3pid_validation") or {},
+ defaults={"per_second": 0.003, "burst_count": 5},
+ )
+
+ self.rc_invites_per_room = RateLimitConfig(
+ config.get("rc_invites", {}).get("per_room", {}),
+ defaults={"per_second": 0.3, "burst_count": 10},
+ )
+ self.rc_invites_per_user = RateLimitConfig(
+ config.get("rc_invites", {}).get("per_user", {}),
+ defaults={"per_second": 0.003, "burst_count": 5},
+ )
+
def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
@@ -124,6 +141,8 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+ # - one that ratelimits third-party invite requests based on the account
+ # that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -131,6 +150,9 @@ class RatelimitConfig(Config):
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
+ # - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+ # - two for ratelimiting how often invites can be sent in a room or to a
+ # specific user.
#
# The defaults are as shown below.
#
@@ -153,6 +175,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
+ #
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
@@ -164,7 +190,18 @@ class RatelimitConfig(Config):
# remote:
# per_second: 0.01
# burst_count: 3
-
+ #
+ #rc_3pid_validation:
+ # per_second: 0.003
+ # burst_count: 5
+ #
+ #rc_invites:
+ # per_room:
+ # per_second: 0.3
+ # burst_count: 10
+ # per_user:
+ # per_second: 0.003
+ # burst_count: 5
# Ratelimiting settings for incoming federation
#
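As background for the `per_second`/`burst_count` pairs used throughout this file: the two numbers describe a token-bucket style limiter, where up to `burst_count` actions may be performed back to back and the allowance refills at `per_second`. A minimal, self-contained sketch of that idea (an illustration only, not Synapse's actual `Ratelimiter`):

    import time

    class SimpleRateLimiter:
        """Illustrative token bucket: allow `burst_count` actions at once,
        refilling at `per_second` tokens per second."""

        def __init__(self, per_second: float, burst_count: int):
            self.per_second = per_second
            self.burst_count = burst_count
            self.tokens = float(burst_count)
            self.last = time.monotonic()

        def can_do_action(self) -> bool:
            now = time.monotonic()
            # Refill tokens based on elapsed time, capped at burst_count.
            self.tokens = min(
                self.burst_count, self.tokens + (now - self.last) * self.per_second
            )
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False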
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 4bfc69cb7a..c96530d4e3 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,70 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-
-import pkg_resources
-
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols, strtobool
-class AccountValidityConfig(Config):
- section = "accountvalidity"
-
- def __init__(self, config, synapse_config):
- if config is None:
- return
- super().__init__()
- self.enabled = config.get("enabled", False)
- self.renew_by_email_enabled = "renew_at" in config
-
- if self.enabled:
- if "period" in config:
- self.period = self.parse_duration(config["period"])
- else:
- raise ConfigError("'period' is required when using account validity")
-
- if "renew_at" in config:
- self.renew_at = self.parse_duration(config["renew_at"])
-
- if "renew_email_subject" in config:
- self.renew_email_subject = config["renew_email_subject"]
- else:
- self.renew_email_subject = "Renew your %(app)s account"
-
- self.startup_job_max_delta = self.period * 10.0 / 100.0
-
- template_dir = config.get("template_dir")
-
- if not template_dir:
- template_dir = pkg_resources.resource_filename("synapse", "res/templates")
-
- if "account_renewed_html_path" in config:
- file_path = os.path.join(template_dir, config["account_renewed_html_path"])
-
- self.account_renewed_html_content = self.read_file(
- file_path, "account_validity.account_renewed_html_path"
- )
- else:
- self.account_renewed_html_content = (
- "<html><body>Your account has been successfully renewed.</body><html>"
- )
-
- if "invalid_token_html_path" in config:
- file_path = os.path.join(template_dir, config["invalid_token_html_path"])
-
- self.invalid_token_html_content = self.read_file(
- file_path, "account_validity.invalid_token_html_path"
- )
- else:
- self.invalid_token_html_content = (
- "<html><body>Invalid renewal token.</body><html>"
- )
-
-
class RegistrationConfig(Config):
section = "registration"
@@ -89,14 +31,21 @@ class RegistrationConfig(Config):
str(config["disable_registration"])
)
- self.account_validity = AccountValidityConfig(
- config.get("account_validity") or {}, config
- )
-
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.check_is_for_allowed_local_3pids = config.get(
+ "check_is_for_allowed_local_3pids", None
+ )
+ self.allow_invited_3pids = config.get("allow_invited_3pids", False)
+
+ self.disable_3pid_changes = config.get("disable_3pid_changes", False)
+
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
+ self.register_just_use_email_for_display_name = config.get(
+ "register_just_use_email_for_display_name", False
+ )
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
@@ -104,7 +53,21 @@ class RegistrationConfig(Config):
)
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+ if (
+ self.account_threepid_delegate_email
+ and not self.account_threepid_delegate_email.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.email must begin with http:// or https://"
+ )
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+ if (
+ self.account_threepid_delegate_msisdn
+ and not self.account_threepid_delegate_msisdn.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.msisdn must begin with http:// or https://"
+ )
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
@@ -166,6 +129,15 @@ class RegistrationConfig(Config):
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+ self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
+ if not isinstance(self.replicate_user_profiles_to, list):
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
+
+ self.shadow_server = config.get("shadow_server", None)
+ self.rewrite_identity_server_urls = (
+ config.get("rewrite_identity_server_urls") or {}
+ )
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -176,9 +148,24 @@ class RegistrationConfig(Config):
self.session_lifetime = session_lifetime
# The success template used during fallback auth.
- self.fallback_success_template = self.read_templates(
- ["auth_success.html"], autoescape=True
- )[0]
+ self.fallback_success_template = self.read_template("auth_success.html")
+
+ self.bind_new_user_emails_to_sydent = config.get(
+ "bind_new_user_emails_to_sydent"
+ )
+
+ if self.bind_new_user_emails_to_sydent:
+ if not isinstance(
+ self.bind_new_user_emails_to_sydent, str
+ ) or not self.bind_new_user_emails_to_sydent.startswith("http"):
+ raise ConfigError(
+ "Option bind_new_user_emails_to_sydent has invalid value"
+ )
+
+ # Remove trailing slashes
+ self.bind_new_user_emails_to_sydent = self.bind_new_user_emails_to_sydent.strip(
+ "/"
+ )
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
@@ -199,70 +186,6 @@ class RegistrationConfig(Config):
#
#enable_registration: false
- # Optional account validity configuration. This allows for accounts to be denied
- # any request after a given period.
- #
- # Once this feature is enabled, Synapse will look for registered users without an
- # expiration date at startup and will add one to every account it found using the
- # current settings at that time.
- # This means that, if a validity period is set, and Synapse is restarted (it will
- # then derive an expiration date from the current validity period), and some time
- # after that the validity period changes and Synapse is restarted, the users'
- # expiration dates won't be updated unless their account is manually renewed. This
- # date will be randomly selected within a range [now + period - d ; now + period],
- # where d is equal to 10%% of the validity period.
- #
- account_validity:
- # The account validity feature is disabled by default. Uncomment the
- # following line to enable it.
- #
- #enabled: true
-
- # The period after which an account is valid after its registration. When
- # renewing the account, its validity period will be extended by this amount
- # of time. This parameter is required when using the account validity
- # feature.
- #
- #period: 6w
-
- # The amount of time before an account's expiry date at which Synapse will
- # send an email to the account's email address with a renewal link. By
- # default, no such emails are sent.
- #
- # If you enable this setting, you will also need to fill out the 'email'
- # configuration section. You should also check that 'public_baseurl' is set
- # correctly.
- #
- #renew_at: 1w
-
- # The subject of the email sent out with the renewal link. '%%(app)s' can be
- # used as a placeholder for the 'app_name' parameter from the 'email'
- # section.
- #
- # Note that the placeholder must be written '%%(app)s', including the
- # trailing 's'.
- #
- # If this is not set, a default value is used.
- #
- #renew_email_subject: "Renew your %%(app)s account"
-
- # Directory in which Synapse will try to find templates for the HTML files to
- # serve to the user when trying to renew an account. If not set, default
- # templates from within the Synapse package will be used.
- #
- #template_dir: "res/templates"
-
- # File within 'template_dir' giving the HTML to be displayed to the user after
- # they successfully renewed their account. If not set, default text is used.
- #
- #account_renewed_html_path: "account_renewed.html"
-
- # File within 'template_dir' giving the HTML to be displayed when the user
- # tries to renew an account with an invalid renewal token. If not set,
- # default text is used.
- #
- #invalid_token_html_path: "invalid_token.html"
-
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.
@@ -285,9 +208,32 @@ class RegistrationConfig(Config):
#
#disable_msisdn_registration: true
+ # Derive the user's matrix ID from a type of 3PID used when registering.
+ # This overrides any matrix ID the user proposes when calling /register.
+ # The 3PID type should be present in registrations_require_3pid to avoid
+ # users failing to register if they don't specify the right kind of 3pid.
+ #
+ #register_mxid_from_3pid: email
+
+ # Uncomment to set the display name of new users to their email address,
+ # rather than using the default heuristic.
+ #
+ #register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+ # Use an identity server to establish which 3PIDs are allowed to register.
+ # Overrides allowed_local_3pids below.
+ #
+ #check_is_for_allowed_local_3pids: matrix.org
+ #
+ # If you are using an IS you can also check whether that IS registers
+ # pending invites for the given 3PID (and then allow it to sign up on
+ # the platform):
+ #
+ #allow_invited_3pids: false
+ #
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
@@ -296,6 +242,11 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+ # If true, stop users from trying to change the 3PIDs associated with
+ # their accounts.
+ #
+ #disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -326,6 +277,30 @@ class RegistrationConfig(Config):
#
#default_identity_server: https://matrix.org
+ # If enabled, user IDs, display names and avatar URLs will be replicated
+ # to this server whenever they change.
+ # This is an experimental API currently implemented by sydent to support
+ # cross-homeserver user directories.
+ #
+ #replicate_user_profiles_to: example.com
+
+ # If specified, attempt to replay registrations, profile changes & 3pid
+ # bindings on the given target homeserver via the AS API. The HS is authed
+ # via a given AS token.
+ #
+ #shadow_server:
+ # hs_url: https://shadow.example.com
+ # hs: shadow.example.com
+ # as_token: 12u394refgbdhivsia
+
+ # If enabled, don't let users set their own display names/avatars
+ # other than for the very first time (unless they are a server admin).
+ # Useful when provisioning users based on the contents of a 3rd party
+ # directory and to avoid ambiguities.
+ #
+ #disable_set_displayname: false
+ #disable_set_avatar_url: false
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -450,6 +425,31 @@ class RegistrationConfig(Config):
# Defaults to true.
#
#auto_join_rooms_for_guests: false
+
+ # Rewrite identity server URLs with a map from one URL to another. Applies to URLs
+ # provided by clients (which have https:// prepended) and those specified
+ # in `account_threepid_delegates`. URLs should not feature a trailing slash.
+ #
+ #rewrite_identity_server_urls:
+ # "https://somewhere.example.com": "https://somewhereelse.example.com"
+
+ # When a user registers an account with an email address, it can be useful to
+ # bind that email address to their mxid on an identity server. Typically, this
+ # requires the user to validate their email address with the identity server.
+ # However if Synapse itself is handling email validation on registration, the
+ # user ends up needing to validate their email twice, which leads to poor UX.
+ #
+ # It is possible to force Sydent, one identity server implementation, to bind
+ # threepids using its internal, unauthenticated bind API:
+ # https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api
+ #
+ # Configure the address of a Sydent server here to have Synapse attempt
+ # to automatically bind users' emails following registration. The
+ # internal bind API must be reachable from Synapse, but should NOT be
+ # exposed to any third party, as it allows the creation of bindings
+ # without validation.
+ #
+ #bind_new_user_emails_to_sydent: https://example.com:8091
"""
% locals()
)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 850ac3ebd6..31e3f7148b 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -107,6 +107,12 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
+
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
@@ -250,6 +256,30 @@ class ContentRepositoryConfig(Config):
#
#max_upload_size: 50M
+ # The largest allowed size for a user avatar. If not defined, no
+ # restriction will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #max_avatar_size: 10M
+
+ # The mimetypes allowed for a user avatar. If not defined, no restriction
+ # will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 47a0370173..b76afce5e5 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -338,6 +338,12 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # True.
+ self.show_users_in_user_directory = config.get(
+ "show_users_in_user_directory", True
+ )
+
retention_config = config.get("retention")
if retention_config is None:
retention_config = {}
@@ -1047,6 +1053,74 @@ class ServerConfig(Config):
#
#allow_per_room_profiles: false
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # 'true'.
+ #
+ #show_users_in_user_directory: false
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # whose 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), without
+ # that purge being performed by a job that iterates over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c8d19c5d6b..306e0cc8a4 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -26,6 +26,8 @@ class UserDirectoryConfig(Config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
+ self.user_directory_defer_to_id_server = None
+ self.user_directory_search_prefer_local_users = False
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
@@ -34,6 +36,12 @@ class UserDirectoryConfig(Config):
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
+ self.user_directory_search_prefer_local_users = user_directory_config.get(
+ "prefer_local_users", False
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
@@ -49,7 +57,17 @@ class UserDirectoryConfig(Config):
# rebuild the user_directory search indexes, see
# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
#
+ # 'prefer_local_users' defines whether to prioritise local users in
+ # search query results. If True, local users are more likely to appear above
+ # remote users when searching the user directory. Defaults to false.
+ #
#user_directory:
# enabled: true
# search_all_users: false
+ # prefer_local_users: false
+ #
+ # # If this is set, user search will be delegated to this ID server instead
+ # # of synapse performing the search itself.
+ # # This is an experimental API.
+ # defer_to_id_server: https://id.example.com
"""
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 74b67b230a..14b21796d9 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -125,19 +125,24 @@ class FederationPolicyForHTTPS:
self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
- def get_options(self, host: bytes):
+ self._should_verify = self._config.federation_verify_certificates
+
+ self._federation_certificate_verification_whitelist = (
+ self._config.federation_certificate_verification_whitelist
+ )
+ def get_options(self, host: bytes):
# IPolicyForHTTPS.get_options takes bytes, but we want to compare
# against the str whitelist. The hostnames in the whitelist are already
# IDNA-encoded like the hosts will be here.
ascii_host = host.decode("ascii")
# Check if certificate verification has been enabled
- should_verify = self._config.federation_verify_certificates
+ should_verify = self._should_verify
# Check if we've disabled certificate verification for this host
- if should_verify:
- for regex in self._config.federation_certificate_verification_whitelist:
+ if self._should_verify:
+ for regex in self._federation_certificate_verification_whitelist:
if regex.match(ascii_host):
should_verify = False
break
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 56f8dc9caf..498a699290 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -161,6 +161,7 @@ def check(
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
+ # 5. If type is m.room.membership
if event.type == EventTypes.Member:
_is_membership_change_allowed(event, auth_events)
logger.debug("Allowing! %s", event)
@@ -247,6 +248,7 @@ def _is_membership_change_allowed(
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
+ caller_knocked = caller and caller.membership == Membership.KNOCK
# get info about the target
key = (EventTypes.Member, target_user_id)
@@ -289,9 +291,12 @@ def _is_membership_change_allowed(
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
return
- if Membership.JOIN != membership:
+ # Require the user to be in the room for membership changes other than join/knock.
+ if Membership.JOIN != membership and Membership.KNOCK != membership:
+ # If the user has been invited or has knocked, they are allowed to change their
+ # membership event to leave
if (
- caller_invited
+ (caller_invited or caller_knocked)
and Membership.LEAVE == membership
and target_user_id == event.user_id
):
@@ -324,7 +329,7 @@ def _is_membership_change_allowed(
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
- elif join_rule == JoinRules.INVITE:
+ elif join_rule in (JoinRules.INVITE, JoinRules.KNOCK):
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
@@ -343,6 +348,17 @@ def _is_membership_change_allowed(
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
+ elif Membership.KNOCK == membership:
+ if join_rule != JoinRules.KNOCK:
+ raise AuthError(403, "You don't have permission to knock")
+ elif target_user_id != event.user_id:
+ raise AuthError(403, "You cannot knock for other users")
+ elif target_in_room:
+ raise AuthError(403, "You cannot knock on a room you are already in")
+ elif caller_invited:
+ raise AuthError(403, "You are already invited to this room")
+ elif target_banned:
+ raise AuthError(403, "You are banned from this room")
else:
raise AuthError(500, "Unknown membership %s" % membership)
@@ -699,7 +715,7 @@ def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
if event.type == EventTypes.Member:
membership = event.content["membership"]
- if membership in [Membership.JOIN, Membership.INVITE]:
+ if membership in [Membership.JOIN, Membership.INVITE, Membership.KNOCK]:
auth_types.add((EventTypes.JoinRules, ""))
auth_types.add((EventTypes.Member, event.state_key))
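The knock branch above slots into the existing membership checks. A standalone sketch of the same rule ordering (a hypothetical helper, simplified from the diff; it is not Synapse's `event_auth.check`):

    def may_knock(
        join_rule: str,
        sender: str,
        target: str,
        target_in_room: bool,
        caller_invited: bool,
        target_banned: bool,
    ) -> None:
        """Raise if a knock membership event should be rejected (sketch)."""
        if join_rule != "knock":
            raise PermissionError("You don't have permission to knock")
        if target != sender:
            raise PermissionError("You cannot knock for other users")
        if target_in_room:
            raise PermissionError("You cannot knock on a room you are already in")
        if caller_invited:
            raise PermissionError("You are already invited to this room")
        if target_banned:
            raise PermissionError("You are banned from this room")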
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index e7e3a7b9a4..f3322af499 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -63,16 +63,32 @@ class SpamChecker:
return False
async def user_may_invite(
- self, inviter_userid: str, invitee_userid: str, room_id: str
+ self,
+ inviter_userid: str,
+ invitee_userid: Optional[str],
+ third_party_invite: Optional[Dict],
+ room_id: str,
+ new_room: bool,
+ published_room: bool,
) -> bool:
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- inviter_userid: The user ID of the sender of the invitation
- invitee_userid: The user ID targeted in the invitation
- room_id: The room ID
+ inviter_userid: The user ID of the sender of the invitation.
+ invitee_userid: The user ID of the invitee. None if this is a third
+ party invite and the 3PID is not bound to a user ID.
+ third_party_invite: If this is a third party invite, a dict
+ containing the medium and address of the invitee.
+ room_id: The ID of the room the invite is for.
+ new_room: Whether the user is being invited to the room as
+ part of a room creation; if so, the invitee would have been
+ included in the call to `user_may_create_room`.
+ published_room: Whether the room the user is being invited
+ to has been published in the local homeserver's public room
+ directory.
Returns:
True if the user may send an invite, otherwise False
@@ -81,7 +97,12 @@ class SpamChecker:
if (
await maybe_awaitable(
spam_checker.user_may_invite(
- inviter_userid, invitee_userid, room_id
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
)
)
is False
@@ -90,20 +111,36 @@ class SpamChecker:
return True
- async def user_may_create_room(self, userid: str) -> bool:
+ async def user_may_create_room(
+ self,
+ userid: str,
+ invite_list: List[str],
+ third_party_invite_list: List[Dict],
+ cloning: bool,
+ ) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid: The ID of the user attempting to create a room
+ invite_list: List of user IDs that would be invited to
+ the new room.
+ third_party_invite_list: List of third party invites
+ for the new room.
+ cloning: Whether the user is cloning an existing room, e.g.
+ upgrading a room.
Returns:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
if (
- await maybe_awaitable(spam_checker.user_may_create_room(userid))
+ await maybe_awaitable(
+ spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
+ )
is False
):
return False
@@ -156,6 +193,25 @@ class SpamChecker:
return True
+ def user_may_join_room(self, userid: str, room_id: str, is_invited: bool):
+ """Checks if a given users is allowed to join a room.
+
+ Not called when a user creates a room.
+
+ Args:
+ userid: The ID of the user wanting to join the room
+ room_id: The ID of the room the user wants to join
+ is_invited: Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ for spam_checker in self.spam_checkers:
+ if spam_checker.user_may_join_room(userid, room_id, is_invited) is False:
+ return False
+
+ return True
+
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9c22e33813..4eadda4b40 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -240,6 +240,7 @@ def format_event_for_client_v1(d):
"replaces_state",
"prev_content",
"invite_room_state",
+ "knock_room_state",
)
for key in copy_keys:
if key in d["unsigned"]:
@@ -276,7 +277,7 @@ def serialize_event(
event_format=format_event_for_client_v1,
token_id=None,
only_event_fields=None,
- is_invite=False,
+ include_stripped_room_state=False,
):
"""Serialize event for clients
@@ -287,8 +288,10 @@ def serialize_event(
event_format
token_id
only_event_fields
- is_invite (bool): Whether this is an invite that is being sent to the
- invitee
+ include_stripped_room_state (bool): Some events can have stripped room state
+ stored in the `unsigned` field. This is required for invite and knock
+ functionality. If this option is False, that state will be removed from the
+ event before it is returned. Otherwise, it will be kept.
Returns:
dict
@@ -320,11 +323,13 @@ def serialize_event(
if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id
- # If this is an invite for somebody else, then we don't care about the
- # invite_room_state as that's meant solely for the invitee. Other clients
- # will already have the state since they're in the room.
- if not is_invite:
+ # invite_room_state and knock_room_state are a list of stripped room state events
+ # that are meant to provide metadata about a room to an invitee/knocker. They are
+ # intended to only be included in specific circumstances, such as down sync, and
+ # should not be included in any other case.
+ if not include_stripped_room_state:
d["unsigned"].pop("invite_room_state", None)
+ d["unsigned"].pop("knock_room_state", None)
if as_client_event:
d = event_format(d)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d330ae5dbc..22bb1afd68 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 Sorunome
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -553,7 +555,7 @@ class FederationClient(FederationBase):
RuntimeError: if no servers were reachable.
"""
- valid_memberships = {Membership.JOIN, Membership.LEAVE}
+ valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
@@ -810,7 +812,7 @@ class FederationClient(FederationBase):
"User's homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
- elif e.code == 403:
+ elif e.code in (403, 429):
raise e.to_synapse_error()
else:
raise
@@ -888,6 +890,62 @@ class FederationClient(FederationBase):
# content.
return resp[1]
+ async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
+ """Attempts to send a knock event to given a list of servers. Iterates
+ through the list until one attempt succeeds.
+
+ Doing so will cause the remote server to add the event to the graph,
+ and send the event out to the rest of the federation.
+
+ Args:
+ destinations: A list of candidate homeservers which are likely to be
+ participating in the room.
+ pdu: The event to be sent.
+
+ Returns:
+ The remote homeserver may return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+
+ Raises:
+ SynapseError: If the chosen remote server returns a 3xx/4xx code.
+ RuntimeError: If no servers were reachable.
+ """
+
+ async def send_request(destination: str) -> JsonDict:
+ return await self._do_send_knock(destination, pdu)
+
+ return await self._try_destination_list(
+ "xyz.amorgan.knock/send_knock", destinations, send_request
+ )
+
+ async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict:
+ """Send a knock event to a remote homeserver.
+
+ Args:
+ destination: The homeserver to send to.
+ pdu: The event to send.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+ """
+ time_now = self._clock.time_msec()
+
+ return await self.transport_layer.send_knock_v2(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
async def get_public_rooms(
self,
remote_server: str,
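Putting the two calls together, the knocking server's side of the handshake looks roughly like this (a sketch inside a hypothetical handler; it assumes `client` is a `FederationClient`, and glosses over the event-building and signing step between the two calls):

    # 1. Ask a resident server to template a knock event for us.
    origin, event, room_version = await client.make_membership_event(
        destinations, room_id, user_id, membership="knock", content={}, params=None
    )
    # 2. After filling in and signing the event locally, send it back.
    response = await client.send_knock(destinations, event)
    stripped_state = response.get("knock_state_events", [])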
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 171d25c945..c5e57b9d11 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.config.api import DEFAULT_ROOM_STATE_TYPES
from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
@@ -567,6 +568,76 @@ class FederationServer(FederationBase):
await self.handler.on_send_leave_request(origin, pdu)
return {}
+ async def on_make_knock_request(
+ self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
+ ) -> Dict[str, Union[EventBase, str]]:
+ """We've received a /make_knock/ request, so we create a partial knock
+ event for the room and hand that back, along with the room version, to the knocking
+ homeserver. We do *not* persist or process this event until the other server has
+ signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: The room to create the knock event in.
+ user_id: The user to create the knock for.
+ supported_versions: The room versions supported by the requesting server.
+
+ Returns:
+ The partial knock event.
+ """
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, room_id)
+
+ room_version = await self.store.get_room_version_id(room_id)
+ if room_version not in supported_versions:
+ logger.warning(
+ "Room version %s not in %s", room_version, supported_versions
+ )
+ raise IncompatibleRoomVersionError(room_version=room_version)
+
+ pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
+ time_now = self._clock.time_msec()
+ return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+
+ async def on_send_knock_request(
+ self, origin: str, content: JsonDict, room_id: str,
+ ) -> Dict[str, List[JsonDict]]:
+ """
+ We have received a knock event for a room. Verify and send the event into the room
+ on the knocking homeserver's behalf. Then reply with some stripped state from the
+ room for the knockee.
+
+ Args:
+ origin: The remote homeserver of the knocking user.
+ content: The content of the request.
+ room_id: The ID of the room to knock on.
+
+ Returns:
+ The stripped room state.
+ """
+ logger.debug("on_send_knock_request: content: %s", content)
+
+ room_version = await self.store.get_room_version(room_id)
+ pdu = event_from_pdu_json(content, room_version)
+
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
+
+ logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+
+ # Handle the event, and retrieve the EventContext
+ event_context = await self.handler.on_send_knock_request(origin, pdu)
+
+ # Retrieve stripped state events from the room and send them back to the remote
+ # server. This will allow the remote server's clients to display information
+ # related to the room while the knock request is pending.
+ stripped_room_state = await self.store.get_stripped_room_state_from_event_context(
+ event_context, DEFAULT_ROOM_STATE_TYPES
+ )
+ return {"knock_state_events": stripped_room_state}
+
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
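For reference, a `knock_state_events` response carrying stripped state might look like the following (an illustrative payload; stripped state events carry only `type`, `state_key`, `sender` and `content`):

    {
      "knock_state_events": [
        {
          "type": "m.room.name",
          "state_key": "",
          "sender": "@alice:example.org",
          "content": {"name": "Example room"}
        },
        {
          "type": "m.room.join_rules",
          "state_key": "",
          "sender": "@alice:example.org",
          "content": {"join_rule": "knock"}
        }
      ]
    }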
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..643b26ae6d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender:
self._wake_destinations_needing_catchup,
)
+ self._external_cache = hs.get_external_cache()
+
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -197,22 +199,40 @@ class FederationSender:
if not event.internal_metadata.should_proactively_send():
return
- try:
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- destinations = await self.state.get_hosts_in_room_at_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
- except Exception:
- logger.exception(
- "Failed to calculate hosts in room for event: %s",
- event.event_id,
+ destinations = None # type: Optional[Set[str]]
+ if not event.prev_event_ids():
+ # If there are no prev event IDs then the state is empty
+ # and so no remote servers in the room
+ destinations = set()
+ else:
+ # We check the external cache for the destinations, which is
+ # stored per state group.
+
+ sg = await self._external_cache.get(
+ "event_to_prev_state_group", event.event_id
)
- return
+ if sg:
+ destinations = await self._external_cache.get(
+ "get_joined_hosts", str(sg)
+ )
+
+ if destinations is None:
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = await self.state.get_hosts_in_room_at_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
destinations = {
d
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index abe9168c78..9c454e5885 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2020 Sorunome
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +18,7 @@
import logging
import urllib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
@@ -26,6 +28,7 @@ from synapse.api.urls import (
FEDERATION_V2_PREFIX,
)
from synapse.logging.utils import log_function
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -209,13 +212,24 @@ class TransportLayerClient:
Fails with ``FederationDeniedError`` if the remote destination
is not in our federation whitelist
"""
- valid_memberships = {Membership.JOIN, Membership.LEAVE}
+ valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
% (membership, ",".join(valid_memberships))
)
- path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
+
+ # Knock currently uses an unstable prefix
+ if membership == Membership.KNOCK:
+ # Create a path in the form of /unstable/xyz.amorgan.knock/make_knock/...
+ path = _create_path(
+ FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock",
+ "/make_knock/%s/%s",
+ room_id,
+ user_id,
+ )
+ else:
+ path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False
retry_on_dns_fail = False
@@ -294,6 +308,41 @@ class TransportLayerClient:
return response
@log_function
+ async def send_knock_v2(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict,
+ ) -> JsonDict:
+ """
+ Sends a signed knock membership event to a remote server. This is the second
+ step for knocking after make_knock.
+
+ Args:
+ destination: The remote homeserver.
+ room_id: The ID of the room to knock on.
+ event_id: The ID of the knock membership event that we're sending.
+ content: The knock membership event that we're sending. Note that this is not the
+ `content` field of the membership event, but the entire signed membership event
+ itself represented as a JSON dict.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+ """
+ path = _create_path(
+ FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock",
+ "/send_knock/%s/%s",
+ room_id,
+ event_id,
+ )
+
+ return await self.client.put_json(
+ destination=destination, path=path, data=content
+ )
+
+ @log_function
async def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
@@ -1004,6 +1053,20 @@ class TransportLayerClient:
return self.client.get_json(destination=destination, path=path)
+ def get_info_of_users(self, destination: str, user_ids: List[str]):
+ """
+ Args:
+ destination: The remote server
+ user_ids: A list of user IDs to query info about
+
+ Returns:
+ Deferred[Dict]: A dictionary mapping user ID to information about that user.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
+ data = {"user_ids": user_ids}
+
+ return self.client.post_json(destination=destination, path=path, data=data)
+
def _create_path(federation_prefix, path, *args):
"""
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 95c64510a9..e9fb8d4079 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +15,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import functools
import logging
import re
@@ -30,9 +30,11 @@ from synapse.api.urls import (
)
from synapse.http.server import JsonResource
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
+ parse_list_from_args,
parse_string_from_args,
)
from synapse.logging.context import run_in_background
@@ -542,6 +544,34 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
return 200, content
+class FederationMakeKnockServlet(BaseFederationServlet):
+ PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+ PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock"
+
+ async def on_GET(self, origin, content, query, room_id, user_id):
+ try:
+ # Retrieve the room versions the remote homeserver claims to support
+ supported_versions = parse_list_from_args(query, "ver", encoding="utf-8")
+ except KeyError:
+ raise SynapseError(400, "Missing required query parameter 'ver'")
+
+ content = await self.handler.on_make_knock_request(
+ origin, room_id, user_id, supported_versions=supported_versions
+ )
+ return 200, content
+
+
+class FederationV2SendKnockServlet(BaseFederationServlet):
+ PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+ PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock"
+
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_knock_request(origin, content, room_id)
+ return 200, content
+
+
class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
@@ -842,6 +872,57 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
+class FederationUserInfoServlet(BaseFederationServlet):
+ """
+ Return information about a set of users.
+
+ This API returns expiration and deactivation information about a set of
+ users. Requested users not local to this homeserver will be ignored.
+
+ Example request:
+ POST /users/info
+
+ {
+ "user_ids": [
+ "@alice:example.com",
+ "@bob:example.com"
+ ]
+ }
+
+ Example response
+ {
+ "@alice:example.com": {
+ "expired": false,
+ "deactivated": true
+ }
+ }
+ """
+
+ PATH = "/users/info"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
+ super(FederationUserInfoServlet, self).__init__(
+ handler, authenticator, ratelimiter, server_name
+ )
+ self.handler = handler
+
+ async def on_POST(self, origin, content, query):
+ assert_params_in_dict(content, required=["user_ids"])
+
+ user_ids = content.get("user_ids", [])
+
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ data = await self.handler.store.get_info_for_users(user_ids)
+ return 200, data
+
+
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
@@ -1387,11 +1468,13 @@ FEDERATION_SERVLET_CLASSES = (
FederationQueryServlet,
FederationMakeJoinServlet,
FederationMakeLeaveServlet,
+ FederationMakeKnockServlet,
FederationEventServlet,
FederationV1SendJoinServlet,
FederationV2SendJoinServlet,
FederationV1SendLeaveServlet,
FederationV2SendLeaveServlet,
+ FederationV2SendKnockServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationGetMissingEventsServlet,
@@ -1403,6 +1486,7 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 664d09da1c..ce97fa70d7 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,11 +18,14 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import StoreError, SynapseError
from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import UserID
from synapse.util import stringutils
@@ -40,27 +43,37 @@ class AccountValidityHandler:
self.sendmail = self.hs.get_sendmail()
self.clock = self.hs.get_clock()
- self._account_validity = self.hs.config.account_validity
+ self._account_validity_enabled = self.hs.config.account_validity_enabled
+ self._account_validity_renew_by_email_enabled = (
+ self.hs.config.account_validity_renew_by_email_enabled
+ )
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
+
+ self._account_validity_period = None
+ if self._account_validity_enabled:
+ self._account_validity_period = self.hs.config.account_validity_period
if (
- self._account_validity.enabled
- and self._account_validity.renew_by_email_enabled
+ self._account_validity_enabled
+ and self._account_validity_renew_by_email_enabled
):
# Don't do email-specific configuration if renewal by email is disabled.
self._template_html = self.config.account_validity_template_html
self._template_text = self.config.account_validity_template_text
+ account_validity_renew_email_subject = (
+ self.hs.config.account_validity_renew_email_subject
+ )
try:
app_name = self.hs.config.email_app_name
- self._subject = self._account_validity.renew_email_subject % {
- "app": app_name
- }
+ self._subject = account_validity_renew_email_subject % {"app": app_name}
self._from_string = self.hs.config.email_notif_from % {"app": app_name}
except Exception:
# If substitution failed, fall back to the bare strings.
- self._subject = self._account_validity.renew_email_subject
+ self._subject = account_validity_renew_email_subject
self._from_string = self.hs.config.email_notif_from
self._raw_from = email.utils.parseaddr(self._from_string)[1]
@@ -69,6 +82,18 @@ class AccountValidityHandler:
if hs.config.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
+ # Mark users as inactive when they expired. Check once every hour
+ if self._account_validity_enabled:
+
+ def mark_expired_users_as_inactive():
+ # run as a background process to allow async functions to work
+ return run_as_background_process(
+ "_mark_expired_users_as_inactive",
+ self._mark_expired_users_as_inactive,
+ )
+
+ self.clock.looping_call(mark_expired_users_as_inactive, 60 * 60 * 1000)
+
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time
@@ -221,47 +246,107 @@ class AccountValidityHandler:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- async def renew_account(self, renewal_token: str) -> bool:
+ async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
+ If it turns out that the token is valid but has already been used, then the
+ token is considered stale. A token is stale if the 'token_used_ts_ms' db column
+ is non-null.
+
Args:
renewal_token: Token sent with the renewal request.
Returns:
- Whether the provided token is valid.
+ A tuple containing:
+ * A bool representing whether the token is valid and unused.
+ * A bool representing whether the token is stale.
+ * An int representing the user's expiry timestamp as milliseconds since the
+ epoch, or 0 if the token was invalid.
"""
try:
- user_id = await self.store.get_user_from_renewal_token(renewal_token)
+ (
+ user_id,
+ current_expiration_ts,
+ token_used_ts,
+ ) = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- return False
+ return False, False, 0
+
+ # Check whether this token has already been used.
+ if token_used_ts:
+ logger.info(
+ "User '%s' attempted to use previously used token '%s' to renew account",
+ user_id,
+ renewal_token,
+ )
+ return False, True, current_expiration_ts
logger.debug("Renewing an account for user %s", user_id)
- await self.renew_account_for_user(user_id)
- return True
+ # Renew the account. Pass the renewal_token here so that it is not cleared.
+ # We want to keep the token around in case the user attempts to renew their
+ # account with the same token twice (clicking the email link twice).
+ #
+ # In that case, the token will be accepted, but the account's expiration ts
+ # will remain unchanged.
+ new_expiration_ts = await self.renew_account_for_user(
+ user_id, renewal_token=renewal_token
+ )
+
+ return True, False, new_expiration_ts
async def renew_account_for_user(
- self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+ self,
+ user_id: str,
+ expiration_ts: Optional[int] = None,
+ email_sent: bool = False,
+ renewal_token: Optional[str] = None,
) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token: Token sent with the renewal request.
+ user_id: The ID of the user to renew.
expiration_ts: New expiration date. Defaults to now + validity period.
- email_sen: Whether an email has been sent for this validity period.
- Defaults to False.
+ email_sent: Whether an email has been sent for this validity period.
+ renewal_token: Token sent with the renewal request. The user's token
+ will be cleared if this is None.
Returns:
New expiration date for this account, as a timestamp in
milliseconds since epoch.
"""
+ now = self.clock.time_msec()
if expiration_ts is None:
- expiration_ts = self.clock.time_msec() + self._account_validity.period
+ expiration_ts = now + self._account_validity_period
await self.store.set_account_validity_for_user(
- user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
+ user_id=user_id,
+ expiration_ts=expiration_ts,
+ email_sent=email_sent,
+ renewal_token=renewal_token,
+ token_used_ts=now,
)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ await self.profile_handler.set_active(
+ [UserID.from_string(user_id)], True, True
+ )
+
return expiration_ts
+
+ async def _mark_expired_users_as_inactive(self):
+ """Iterate over active, expired users. Mark them as inactive in order to hide them
+ from the user directory.
+ """
+ # Get active, expired users
+ active_expired_users = await self.store.get_expired_users()
+
+ # Mark each as non-active
+ await self.profile_handler.set_active(active_expired_users, False, True)
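Callers of `renew_account` need updating for the new three-part return value; a sketch of the expected call-site shape (illustrative, not the actual servlet code):

    # Consume the (valid, stale, expiration_ts) triple returned above.
    token_valid, token_stale, expiration_ts = await handler.renew_account(renewal_token)

    if token_valid or token_stale:
        # A stale token means the account was already renewed with it; treat
        # it as a success, but the expiration date is not moved again.
        status_code = 200
    else:
        status_code = 404  # unknown or invalid token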
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
import twisted
import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
from synapse.app import check_bind_error
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
class AcmeHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
- async def start_listening(self):
+ async def start_listening(self) -> None:
from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
logger.error(ACME_REGISTER_FAIL_ERROR)
raise
- async def provision_certificate(self):
+ async def provision_certificate(self) -> None:
logger.warning("Reprovisioning %s", self._acme_domain)
@@ -110,5 +114,3 @@ class AcmeHandler:
except Exception:
logger.exception("Failed saving!")
raise
-
- return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
imported conditionally.
"""
import logging
+from typing import Dict, Iterable, List
import attr
+import pem
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
from zope.interface import implementer
from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
from twisted.python.filepath import FilePath
from twisted.python.url import URL
+from twisted.web.resource import IResource
logger = logging.getLogger(__name__)
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+ reactor: IReactorTCP,
+ acme_url: str,
+ account_key_file: str,
+ well_known_resource: IResource,
+) -> AcmeIssuingService:
"""Create an ACME issuing service, and attach it to a web Resource
Args:
reactor: twisted reactor
- acme_url (str): URL to use to request certificates
- account_key_file (str): where to store the account key
- well_known_resource (twisted.web.IResource): web resource for .well-known.
+ acme_url: URL to use to request certificates
+ account_key_file: where to store the account key
+ well_known_resource: web resource for .well-known.
we will attach a child resource for "acme-challenge".
Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
A store that only stores in memory.
"""
- certs = attr.ib(default=attr.Factory(dict))
+ certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
- def store(self, server_name, pem_objects):
+ def store(
+ self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+ ) -> defer.Deferred:
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
"""Load the ACME account key from a file, creating it if it does not exist.
Args:
- key_file (str): name of the file to use as the account key
+ key_file: name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename
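The attr.ib(type=...) form used for ErsatzStore.certs attaches a type to an attrs attribute without switching the class to auto_attribs. A standalone sketch of the same pattern (MemoryCertStore is illustrative):

    from typing import Dict, List

    import attr

    @attr.s
    class MemoryCertStore:
        # The type lives in attr.ib(type=...) rather than an annotation.
        certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))

    store = MemoryCertStore()
    store.certs[b"example.com"] = [b"-----BEGIN CERTIFICATE-----..."]
    print(store.certs)
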
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6f746711ca..a19c556437 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -568,16 +568,6 @@ class AuthHandler(BaseHandler):
session.session_id, login_type, result
)
except LoginError as e:
- if login_type == LoginType.EMAIL_IDENTITY:
- # riot used to have a bug where it would request a new
- # validation token (thus sending a new email) each time it
- # got a 401 with a 'flows' field.
- # (https://github.com/vector-im/vector-web/issues/2447).
- #
- # Grandfather in the old behaviour for now to avoid
- # breaking old riot deployments.
- raise
-
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
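With the grandfathered re-raise removed, email-identity failures now take the same path as every other login type: the step's error dict is merged into the flows response so the client can retry. A rough sketch, with FakeLoginError standing in for synapse.api.errors.LoginError:

    class FakeLoginError(Exception):
        def __init__(self, code: int, msg: str, errcode: str):
            super().__init__(msg)
            self.code, self.msg, self.errcode = code, msg, errcode

        def error_dict(self) -> dict:
            return {"errcode": self.errcode, "error": self.msg}

    response = {"flows": [{"stages": ["m.login.email.identity"]}]}
    try:
        raise FakeLoginError(401, "Validation failed", "M_UNAUTHORIZED")
    except FakeLoginError as e:
        # Merge the error into the response so the client can have another go.
        response.update(e.error_dict())
    print(response)
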
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048523ec94..bd35d1fb87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -100,11 +100,7 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
- return "%s%s?%s" % (
- self._cas_service_url,
- "/_matrix/client/r0/login/cas/ticket",
- urllib.parse.urlencode(args),
- )
+ return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
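Note the behavioural consequence above: the handler no longer appends /_matrix/client/r0/login/cas/ticket, so the configured cas_service_url must itself point at the ticket endpoint. A quick sketch of the resulting URL (values illustrative):

    import urllib.parse

    cas_service_url = "https://synapse.example.com/_matrix/client/r0/login/cas/ticket"
    args = {"redirectUrl": "https://client.example.com/", "session": "abc123"}
    print("%s?%s" % (cas_service_url, urllib.parse.urlencode(args)))
    # https://synapse.example.com/_matrix/client/r0/login/cas/ticket?redirectUrl=...&session=abc123
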
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index c4a3b26a84..ac25e3e94f 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -50,7 +50,7 @@ class DeactivateAccountHandler(BaseHandler):
if hs.config.run_background_tasks:
hs.get_reactor().callWhenRunning(self._start_user_parting)
- self._account_validity_enabled = hs.config.account_validity.enabled
+ self._account_validity_enabled = hs.config.account_validity_enabled
async def deactivate_account(
self,
@@ -120,6 +120,9 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ await self._profile_handler.set_active([user], False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
await self.store.add_user_pending_deactivation(user_id)
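Deactivation now also marks the user inactive via set_active, which is what removes them from the user directory. A stubbed sketch (the exact meaning of the two booleans, the active flag and whether to propagate the change, is an assumption here):

    import asyncio

    class StubProfileHandler:
        async def set_active(self, users, active: bool, hide: bool) -> None:
            for user in users:
                print("set_active(%s, active=%r, hide=%r)" % (user, active, hide))

    asyncio.run(StubProfileHandler().set_active(["@alice:example.com"], False, False))
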
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..0863154f7a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
- async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+ async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
Retrieve the given user's devices
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
- async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
""" Retrieve the given device
Args:
@@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(
- device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+ device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -946,8 +946,8 @@ class DeviceListUpdater:
async def process_cross_signing_key_update(
self,
user_id: str,
- master_key: Optional[Dict[str, Any]],
- self_signing_key: Optional[Dict[str, Any]],
+ master_key: Optional[JsonDict],
+ self_signing_key: Optional[JsonDict],
) -> List[str]:
"""Process the given new master and self-signing key for the given remote user.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..8f3a6b35a4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
+ JsonDict,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class E2eKeysHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
@@ -78,7 +82,9 @@ class E2eKeysHandler:
)
@trace
- async def query_devices(self, query_body, timeout, from_user_id):
+ async def query_devices(
+ self, query_body: JsonDict, timeout: int, from_user_id: str
+ ) -> JsonDict:
""" Handle a device key query from a client
{
@@ -98,12 +104,14 @@ class E2eKeysHandler:
}
Args:
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Iterable[str]]
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
set_tag("remote_key_query", remote_queries)
# First get local devices.
- failures = {}
+ # A map of destination -> failure response.
+ failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
)
# Now attempt to get any remote devices from our local cache.
- remote_queries_not_in_cache = {}
+ # A map of destination -> user ID -> device IDs.
+ remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
- query_list = []
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
return ret
async def get_cross_signing_keys_from_cache(
- self, query, from_user_id
+ self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
- query (Iterable[string]) an iterable of user IDs. A dict whose keys
+ query: an iterable of user IDs. A dict whose keys
are user IDs satisfies this, so the query format used for
query_devices can be used here.
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
@@ -315,14 +325,12 @@ class E2eKeysHandler:
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]
- if (
- from_user_id in keys
- and keys[from_user_id] is not None
- and "user_signing" in keys[from_user_id]
- ):
- # users can see other users' master and self-signing keys, but can
- # only see their own user-signing keys
- user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ if from_user_id:
+ from_user_key = keys.get(from_user_id)
+ if from_user_key and "user_signing" in from_user_key:
+ user_signing_keys[from_user_id] = from_user_key["user_signing"]
return {
"master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
- local_query = []
+ local_query = [] # type: List[Tuple[str, Optional[str]]]
- result_dict = {}
+ result_dict = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,14 @@ class E2eKeysHandler:
log_kv(results)
return result_dict
- async def on_federation_query_client_keys(self, query_body):
+ async def on_federation_query_client_keys(
+ self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+ ) -> JsonDict:
""" Handle a device key query from a federated server
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Optional[List[str]]]
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@@ -397,31 +409,34 @@ class E2eKeysHandler:
return ret
@trace
- async def claim_one_time_keys(self, query, timeout):
- local_query = []
- remote_queries = {}
+ async def claim_one_time_keys(
+ self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+ ) -> JsonDict:
+ local_query = [] # type: List[Tuple[str, str, str]]
+ remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
- for user_id, device_keys in query.get("one_time_keys", {}).items():
+ for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
- for device_id, algorithm in device_keys.items():
+ for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_keys
+ remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
results = await self.store.claim_e2e_one_time_keys(local_query)
- json_result = {}
- failures = {}
+ # A map of user ID -> device ID -> key ID -> key.
+ json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+ failures = {} # type: Dict[str, JsonDict]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
+ for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json_decoder.decode(json_bytes)
+ key_id: json_decoder.decode(json_str)
}
@trace
@@ -468,7 +483,9 @@ class E2eKeysHandler:
return {"one_time_keys": json_result, "failures": failures}
@tag_args
- async def upload_keys_for_user(self, user_id, device_id, keys):
+ async def upload_keys_for_user(
+ self, user_id: str, device_id: str, keys: JsonDict
+ ) -> JsonDict:
time_now = self.clock.time_msec()
@@ -543,8 +560,8 @@ class E2eKeysHandler:
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
- self, user_id, device_id, time_now, one_time_keys
- ):
+ self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ ) -> None:
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
@@ -585,12 +602,14 @@ class E2eKeysHandler:
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
- async def upload_signing_keys_for_user(self, user_id, keys):
+ async def upload_signing_keys_for_user(
+ self, user_id: str, keys: JsonDict
+ ) -> JsonDict:
"""Upload signing keys for cross-signing
Args:
- user_id (string): the user uploading the keys
- keys (dict[string, dict]): the signing keys
+ user_id: the user uploading the keys
+ keys: the signing keys
"""
# if a master key is uploaded, then check it. Otherwise, load the
@@ -667,16 +686,17 @@ class E2eKeysHandler:
return {}
- async def upload_signatures_for_device_keys(self, user_id, signatures):
+ async def upload_signatures_for_device_keys(
+ self, user_id: str, signatures: JsonDict
+ ) -> JsonDict:
"""Upload device signatures for cross-signing
Args:
- user_id (string): the user uploading the signatures
- signatures (dict[string, dict[string, dict]]): map of users to
- devices to signed keys. This is the submission from the user; an
- exception will be raised if it is malformed.
+ user_id: the user uploading the signatures
+ signatures: map of users to devices to signed keys. This is the submission
+ from the user; an exception will be raised if it is malformed.
Returns:
- dict: response to be sent back to the client. The response will have
+ The response to be sent back to the client. The response will have
a "failures" key, which will be a dict mapping users to devices
to errors for the signatures that failed.
Raises:
@@ -719,7 +739,9 @@ class E2eKeysHandler:
return {"failures": failures}
- async def _process_self_signatures(self, user_id, signatures):
+ async def _process_self_signatures(
+ self, user_id: str, signatures: JsonDict
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +753,14 @@ class E2eKeysHandler:
signatures (dict[string, dict]): map of devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
- reasons
+ A tuple of a list of signatures to store, and a map of users to
+ devices to failure reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -834,19 +855,24 @@ class E2eKeysHandler:
return signature_list, failures
def _check_master_key_signature(
- self, user_id, master_key_id, signed_master_key, stored_master_key, devices
- ):
+ self,
+ user_id: str,
+ master_key_id: str,
+ signed_master_key: JsonDict,
+ stored_master_key: JsonDict,
+ devices: Dict[str, Dict[str, JsonDict]],
+ ) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
Args:
- user_id (string): the user whose master key is being checked
- master_key_id (string): the ID of the user's master key
- signed_master_key (dict): the user's signed master key that was uploaded
- stored_master_key (dict): our previously-stored copy of the user's master key
- devices (iterable(dict)): the user's devices
+ user_id: the user whose master key is being checked
+ master_key_id: the ID of the user's master key
+ signed_master_key: the user's signed master key that was uploaded
+ stored_master_key: our previously-stored copy of the user's master key
+ devices: the user's devices
Returns:
- list[SignatureListItem]: a list of signatures to store
+ A list of signatures to store
Raises:
SynapseError: if a signature is invalid
@@ -877,25 +903,26 @@ class E2eKeysHandler:
return master_key_signature_list
- async def _process_other_signatures(self, user_id, signatures):
+ async def _process_other_signatures(
+ self, user_id: str, signatures: Dict[str, dict]
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
Args:
- user_id (string): the user uploading the keys
- signatures (dict[string, dict]): map of users to devices to signed keys
+ user_id: the user uploading the keys
+ signatures: map of users to devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
+ A list of signatures to store, and a map of users to devices to failure
reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -983,7 +1010,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
- ):
+ ) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1024,7 @@ class E2eKeysHandler:
This affects what signatures are fetched.
Returns:
- dict, str, VerifyKey: the raw key data, the key ID, and the
- signedjson verify key
+ The raw key data, the key ID, and the signedjson verify key
Raises:
NotFoundError: if the key is not found
@@ -1135,16 +1161,18 @@ class E2eKeysHandler:
return desired_key, desired_key_id, desired_verify_key
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+ key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
"""Check a cross-signing key uploaded by a user. Performs some basic sanity
checking, and ensures that it is signed, if a signature is required.
Args:
- key (dict): the key data to verify
- user_id (str): the user whose key is being checked
- key_type (str): the type of key that the key should be
- signing_key (VerifyKey): (optional) the signing key that the key should
- be signed with. If omitted, signatures will not be checked.
+ key: the key data to verify
+ user_id: the user whose key is being checked
+ key_type: the type of key that the key should be
+ signing_key: the signing key that the key should be signed with. If
+ omitted, signatures will not be checked.
"""
if (
key.get("user_id") != user_id
@@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
)
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+ user_id: str,
+ verify_key: VerifyKey,
+ signed_device: JsonDict,
+ stored_device: JsonDict,
+) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
exception if an error is detected.
Args:
- user_id (str): the user ID whose signature is being checked
- verify_key (VerifyKey): the key to verify the device with
- signed_device (dict): the uploaded signed device data
- stored_device (dict): our previously stored copy of the device
+ user_id: the user ID whose signature is being checked
+ verify_key: the key to verify the device with
+ signed_device: the uploaded signed device data
+ stored_device: our previously stored copy of the device
Raises:
SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
if isinstance(e, SynapseError):
return {"status": e.code, "errcode": e.errcode, "message": str(e)}
@@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
return {"status": 503, "message": str(e)}
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
@@ -1239,16 +1272,16 @@ class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys.
"""
- signing_key_id = attr.ib()
- target_user_id = attr.ib()
- target_device_id = attr.ib()
- signature = attr.ib()
+ signing_key_id = attr.ib(type=str)
+ target_user_id = attr.ib(type=str)
+ target_device_id = attr.ib(type=str)
+ signature = attr.ib(type=JsonDict)
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs, e2e_keys_handler):
+ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {}
+ self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have around to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
iterable=True,
)
- async def incoming_signing_key_update(self, origin, edu_content):
+ async def incoming_signing_key_update(
+ self, origin: str, edu_content: JsonDict
+ ) -> None:
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
- origin (string): the server that sent the EDU
- edu_content (dict): the contents of the EDU
+ origin: the server that sent the EDU
+ edu_content: the contents of the EDU
"""
user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
await self._handle_signing_key_updates(user_id)
- async def _handle_signing_key_updates(self, user_id):
+ async def _handle_signing_key_updates(self, user_id: str) -> None:
"""Actually handle pending updates.
Args:
- user_id (string): the user whose updates we are processing
+ user_id: the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
@@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
- device_ids = []
+ device_ids = [] # type: List[str]
logger.info("pending updates: %r", pending_updates)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..622cae23be 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, List, Optional
from synapse.api.errors import (
Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
The actual payload of the encrypted keys is completely opaque to the handler.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
# Used to lock whenever a client is uploading key data. This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
- async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> List[JsonDict]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
Args:
- user_id(str): the user whose keys we're getting
- version(str): the version ID of the backup we're getting keys from
- room_id(string): room ID to get keys for, for None to get keys for all rooms
- session_id(string): session ID to get keys for, for None to get keys for all
+ user_id: the user whose keys we're getting
+ version: the version ID of the backup we're getting keys from
+ room_id: room ID to get keys for, or None to get keys for all rooms
+ session_id: session ID to get keys for, or None to get keys for all
sessions
Raises:
NotFoundError: if the backup version does not exist
Returns:
- A deferred list of dicts giving the session_data and message metadata for
+ A list of dicts giving the session_data and message metadata for
these room keys.
"""
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
return results
@trace
- async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> JsonDict:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
Args:
- user_id(str): the user whose backup we're deleting
- version(str): the version ID of the backup we're deleting
- room_id(string): room ID to delete keys for, for None to delete keys for all
+ user_id: the user whose backup we're deleting
+ version: the version ID of the backup we're deleting
+ room_id: room ID to delete keys for, or None to delete keys for all
rooms
- session_id(string): session ID to delete keys for, for None to delete keys
+ session_id: session ID to delete keys for, or None to delete keys
for all sessions
Raises:
NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@trace
- async def upload_room_keys(self, user_id, version, room_keys):
+ async def upload_room_keys(
+ self, user_id: str, version: str, room_keys: JsonDict
+ ) -> JsonDict:
"""Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT().
Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_keys(dict): a nested dict describing the room_keys we're setting:
+ user_id: the user whose backup we're setting
+ version: the version ID of the backup we're updating
+ room_keys: a nested dict describing the room_keys we're setting:
{
"rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@staticmethod
- def _should_replace_room_key(current_room_key, room_key):
+ def _should_replace_room_key(
+ current_room_key: Optional[JsonDict], room_key: JsonDict
+ ) -> bool:
"""
Determine whether to replace a given current_room_key (if any)
with a newly uploaded room_key backup
Args:
- current_room_key (dict): Optional, the current room_key dict if any
- room_key (dict): The new room_key dict which may or may not be fit to
+ current_room_key: Optional, the current room_key dict if any
+ room_key: The new room_key dict which may or may not be fit to
replace the current_room_key
Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
return True
@trace
- async def create_version(self, user_id, version_info):
+ async def create_version(self, user_id: str, version_info: JsonDict) -> str:
"""Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be
writeable.
Args:
- user_id(str): the user whose backup version we're creating
- version_info(dict): metadata about the new version being created
+ user_id: the user whose backup version we're creating
+ version_info: metadata about the new version being created
{
"algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
}
Returns:
- A deferred of a string that gives the new version number.
+ The new version number.
"""
# TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
)
return new_version
- async def get_version_info(self, user_id, version=None):
+ async def get_version_info(
+ self, user_id: str, version: Optional[str] = None
+ ) -> JsonDict:
"""Get the info about a given version of the user's backup
Args:
- user_id(str): the user whose current backup version we're querying
- version(str): Optional; if None gives the most recent version
+ user_id: the user whose current backup version we're querying
+ version: Optional; if None gives the most recent version
otherwise a historical one.
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
- A deferred of a info dict that gives the info about the new version.
+ An info dict that gives the info about the new version.
{
"version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
return res
@trace
- async def delete_version(self, user_id, version=None):
+ async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
"""Deletes a given version of the user's e2e_room_keys backup
Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
raise
@trace
- async def update_version(self, user_id, version, version_info):
+ async def update_version(
+ self, user_id: str, version: str, version_info: JsonDict
+ ) -> JsonDict:
"""Update the info about a given version of the user's backup
Args:
- user_id(str): the user whose current backup version we're updating
- version(str): the backup version we're updating
- version_info(dict): the new information about the backup
+ user_id: the user whose current backup version we're updating
+ version: the backup version we're updating
+ version_info: the new information about the backup
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
- A deferred of an empty dict.
+ An empty dict.
"""
if "version" not in version_info:
version_info["version"] = version
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..fa56b31438 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -186,7 +187,7 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("handling received PDU: %s", pdu)
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = await self.store.get_event(
@@ -301,6 +302,14 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
+ elif missing_prevs:
+ logger.info(
+ "[%s %s] Not recursively fetching %d missing prev_events: %s",
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -345,12 +354,6 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- logger.info(
- "Event %s is missing prev_events: calculating state for a "
- "backwards extremity",
- event_id,
- )
-
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
@@ -368,7 +371,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "Requesting state at missing prev_event %s", event_id,
+ "[%s %s] Requesting state at missing prev_event %s",
+ room_id,
+ event_id,
+ p,
)
with nested_logging_context(p):
@@ -402,9 +408,7 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
+ list(state_map.values()), get_prev_content=False,
)
event_map.update(evs)
@@ -1439,6 +1443,73 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
+ @log_function
+ async def do_knock(
+ self, target_hosts: List[str], room_id: str, knockee: str, content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Sends the knock to the remote server.
+
+ This first triggers a make_knock request that returns a partial
+ event that we can fill out and sign. This is then sent to the
+ remote server via send_knock.
+
+ Knock events must be signed by the knockee's server before distributing.
+
+ Args:
+ target_hosts: A list of hosts that we want to try knocking through.
+ room_id: The ID of the room to knock on.
+ knockee: The ID of the user who is knocking.
+ content: The content of the knock event.
+
+ Returns:
+ A tuple of (event ID, stream ID).
+
+ Raises:
+ SynapseError: If the chosen remote server returns a 3xx/4xx code.
+ RuntimeError: If no servers were reachable.
+ """
+ logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee)
+
+ # Inform the remote server of the room versions we support
+ supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys())
+
+ # Ask the remote server to create a valid knock event for us. Once received,
+ # we sign the event
+ params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
+ origin, event, event_format_version = await self._make_and_verify_event(
+ target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
+ )
+
+ # Record the room ID and its version so that we have a record of the room
+ await self._maybe_store_room_on_outlier_membership(
+ room_id=event.room_id, room_version=event_format_version
+ )
+
+ # Initially try the host that we successfully called /make_knock on
+ try:
+ target_hosts.remove(origin)
+ target_hosts.insert(0, origin)
+ except ValueError:
+ pass
+
+ # Send the signed event back to the room, and potentially receive some
+ # further information about the room in the form of partial state events
+ stripped_room_state = await self.federation_client.send_knock(
+ target_hosts, event
+ )
+
+ # Store any stripped room state events in the "unsigned" key of the event.
+ # This is a bit of a hack and is cribbing off of invites. Basically we
+ # store the room state here and retrieve it again when this event appears
+ # in the invitee's sync stream. It is stripped out for all other local users.
+ event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+
+ context = await self.state_handler.compute_event_context(event)
+ stream_id = await self.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
+ return event.event_id, stream_id
+
async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
@@ -1593,8 +1664,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = await self.store.is_room_published(event.room_id)
+
if not await self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1617,6 +1695,10 @@ class FederationHandler(BaseHandler):
if event.state_key == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
+ # We retrieve the room member handler here so as not to cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ member_handler.ratelimit_invite(event.room_id, event.state_key)
+
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
@@ -1775,6 +1857,120 @@ class FederationHandler(BaseHandler):
return None
+ @log_function
+ async def on_make_knock_request(
+ self, origin: str, room_id: str, user_id: str
+ ) -> EventBase:
+ """We've received a make_knock request, so we create a partial
+ knock event for the room and return that. We do *not* persist or
+ process it until the other server has signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: The room to create the knock event in.
+ user_id: The user to create the knock for.
+
+ Returns:
+ The partial knock event.
+ """
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Get /xyz.amorgan.knock/make_knock request for user %r"
+ "from different origin %s, ignoring",
+ user_id,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ room_version = await self.store.get_room_version_id(room_id)
+
+ builder = self.event_builder_factory.new(
+ room_version,
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.KNOCK},
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": user_id,
+ },
+ )
+
+ event, context = await self.event_creation_handler.create_new_client_event(
+ builder=builder
+ )
+
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.warning("Creation of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ try:
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_knock_request`
+ await self.auth.check_from_context(
+ room_version, event, context, do_sig_check=False
+ )
+ except AuthError as e:
+ logger.warning("Failed to create new knock %r because %s", event, e)
+ raise e
+
+ return event
+
+ @log_function
+ async def on_send_knock_request(
+ self, origin: str, event: EventBase
+ ) -> EventContext:
+ """
+ We have received a knock event for a room. Verify that event and send it into the room
+ on the knocking homeserver's behalf.
+
+ Args:
+ origin: The remote homeserver of the knocking user.
+ event: The knocking member event that has been signed by the remote homeserver.
+
+ Returns:
+ The context of the event after inserting it into the room graph.
+ """
+ logger.debug(
+ "on_send_knock_request: Got event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got /xyz.amorgan.knock/send_knock request for user %r "
+ "from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ event.internal_metadata.outlier = False
+
+ context = await self._handle_new_event(origin, event)
+
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ logger.debug(
+ "on_send_knock_request: After _handle_new_event: %s, sigs: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ return context
+
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event.
"""
@@ -2093,6 +2289,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
+ # If we are going to send this event over federation we precalculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
return context
async def _check_for_soft_fail(
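One small but easily missed detail in do_knock: the candidate host list is reordered so that the server which answered /make_knock is tried first for send_knock. Isolated:

    target_hosts = ["hs1.example.com", "hs2.example.com", "hs3.example.com"]
    origin = "hs2.example.com"  # the host that successfully served /make_knock

    try:
        target_hosts.remove(origin)
        target_hosts.insert(0, origin)
    except ValueError:
        pass  # origin was not in the candidate list; keep the original order

    print(target_hosts)  # ['hs2.example.com', 'hs1.example.com', 'hs3.example.com']
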
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index df29edeb83..71f11ef94a 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
class GroupsLocalWorkerHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
- async def get_group_summary(self, group_id, requester_user_id):
+ async def get_group_summary(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
return res
- async def get_users_in_group(self, group_id, requester_user_id):
+ async def get_users_in_group(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get users in a group
"""
if self.is_mine_id(group_id):
- res = await self.groups_server_handler.get_users_in_group(
+ return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
- return res
group_server_name = get_domain_from_id(group_id)
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
return res
- async def get_joined_groups(self, user_id):
+ async def get_joined_groups(self, user_id: str) -> JsonDict:
group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}
- async def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
if self.hs.is_mine_id(user_id):
result = await self.store.get_publicised_groups_for_user(user_id)
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
# TODO: Verify attestations
return {"groups": result}
- async def bulk_get_publicised_groups(self, user_ids, proxy=True):
- destinations = {}
+ async def bulk_get_publicised_groups(
+ self, user_ids: Iterable[str], proxy: bool = True
+ ) -> JsonDict:
+ destinations = {} # type: Dict[str, Set[str]]
local_users = set()
for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
- failed_results = []
+ failed_results = [] # type: List[str]
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
class GroupsLocalHandler(GroupsLocalWorkerHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
# Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
- async def create_group(self, group_id, user_id, content):
+ async def create_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Create a group
"""
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = None
remote_attestation = None
else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content["attestation"] = local_attestation
-
- content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
- try:
- res = await self.transport_client.create_group(
- get_domain_from_id(group_id), group_id, user_id, content
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- remote_attestation = res["attestation"]
- await self.attestations.verify_attestation(
- remote_attestation,
- group_id=group_id,
- user_id=user_id,
- server_name=get_domain_from_id(group_id),
- )
+ raise SynapseError(400, "Unable to create remote groups")
is_publicised = content.get("publicise", False)
token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def join_group(self, group_id, user_id, content):
+ async def join_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Request to join a group
"""
if self.is_mine_id(group_id):
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- async def accept_invite(self, group_id, user_id, content):
+ async def accept_invite(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- async def invite(self, group_id, user_id, requester_user_id, config):
+ async def invite(
+ self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+ ) -> JsonDict:
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def on_invite(self, group_id, user_id, content):
+ async def on_invite(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
@@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
async def remove_user_from_group(
- self, group_id, user_id, requester_user_id, content
- ):
+ self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Remove a user from a group
"""
if user_id == requester_user_id:
@@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def user_removed_from_group(self, group_id, user_id, content):
+ async def user_removed_from_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> None:
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
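These handlers split on whether the group ID's domain is local; with this change, create_group simply rejects remote domains instead of proxying. A simplified sketch of the routing test (the real get_domain_from_id also validates the ID):

    def get_domain_from_id(group_id: str) -> str:
        # Group IDs look like "+name:server"; the domain follows the colon.
        return group_id.split(":", 1)[1]

    def is_mine_id(group_id: str, server_name: str = "example.com") -> bool:
        return get_domain_from_id(group_id) == server_name

    print(is_mine_id("+cats:example.com"))  # True  -> handled locally
    print(is_mine_id("+dogs:other.org"))    # False -> create_group now raises 400
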
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index f61844d688..8dbf9bef3f 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,14 +22,18 @@ import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from synapse.api.errors import (
+ AuthError,
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
@@ -39,8 +43,6 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-id_server_scheme = "https://"
-
class IdentityHandler(BaseHandler):
def __init__(self, hs):
@@ -55,17 +57,46 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
+
self._web_client_location = hs.config.invite_client_location
+ # Ratelimiters for `/requestToken` endpoints.
+ self._3pid_validation_ratelimiter_ip = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+ burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+ )
+ self._3pid_validation_ratelimiter_address = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+ burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+ )
+
+ def ratelimit_request_token_requests(
+ self, request: SynapseRequest, medium: str, address: str,
+ ):
+ """Used to ratelimit requests to `/requestToken` by IP and address.
+
+ Args:
+ request: The associated request
+ medium: The type of threepid, e.g. "msisdn" or "email"
+ address: The actual threepid ID, e.g. the phone number or email address
+ """
+
+ self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+ self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
async def threepid_from_creds(
- self, id_server: str, creds: Dict[str, str]
+ self, id_server_url: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
- id_server: The identity server to validate 3PIDs against. Must be a
+ id_server_url: The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
creds: Dictionary containing the following keys:
* client_secret|clientSecret: A unique secret str provided by the client
@@ -90,7 +121,14 @@ class IdentityHandler(BaseHandler):
query_params = {"sid": session_id, "client_secret": client_secret}
- url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
+ url = "%s%s" % (
+ id_server_url,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
try:
data = await self.http_client.get_json(url, query_params)
@@ -99,7 +137,7 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info(
"%s returned %i for threepid validation for: %s",
- id_server,
+ id_server_url,
e.code,
creds,
)
@@ -113,7 +151,7 @@ class IdentityHandler(BaseHandler):
if "medium" in data:
return data
- logger.info("%s reported non-validated threepid: %s", id_server, creds)
+ logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
return None
async def bind_threepid(
@@ -145,14 +183,19 @@ class IdentityHandler(BaseHandler):
if id_access_token is None:
use_v2 = False
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,)
headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
else:
- bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,)
try:
# Use the blacklisting http client as this call is only to identity servers
@@ -239,9 +282,6 @@ class IdentityHandler(BaseHandler):
True on success, otherwise False if the identity
server doesn't support unbinding
"""
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
-
content = {
"mxid": mxid,
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
@@ -250,6 +290,7 @@ class IdentityHandler(BaseHandler):
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
+ url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method=b"POST",
@@ -259,6 +300,15 @@ class IdentityHandler(BaseHandler):
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,)
+
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
@@ -373,9 +423,28 @@ class IdentityHandler(BaseHandler):
return session_id
+ def rewrite_id_server_url(self, url: str, add_https=False) -> str:
+ """Given an identity server URL, optionally add a protocol scheme
+ before rewriting it according to the rewrite_identity_server_urls
+ config option
+
+ Adds https:// to the URL if specified, then tries to rewrite the
+ URL. Returns either the rewritten URL or the URL with optional
+ protocol scheme additions.
+ """
+ rewritten_url = url
+ if add_https:
+ rewritten_url = "https://" + rewritten_url
+
+ rewritten_url = self.rewrite_identity_server_urls.get(
+ rewritten_url, rewritten_url
+ )
+ logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
+ return rewritten_url
+
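A compact model of rewrite_id_server_url as added above, with a hand-rolled rewrite table standing in for the rewrite_identity_server_urls config option:

    rewrite_rules = {"https://id.example.com": "http://localhost:8090"}

    def rewrite_id_server_url(url: str, add_https: bool = False) -> str:
        rewritten = "https://" + url if add_https else url
        return rewrite_rules.get(rewritten, rewritten)

    print(rewrite_id_server_url("id.example.com", add_https=True))  # rewritten
    print(rewrite_id_server_url("https://keep.example.net"))        # unchanged
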
async def requestEmailToken(
self,
- id_server: str,
+ id_server_url: str,
email: str,
client_secret: str,
send_attempt: int,
@@ -386,7 +455,7 @@ class IdentityHandler(BaseHandler):
validation.
Args:
- id_server: The identity server to proxy to
+ id_server_url: The identity server to proxy to
email: The email to send the message to
client_secret: The unique client_secret sent by the user
send_attempt: Which attempt this is
@@ -400,6 +469,11 @@ class IdentityHandler(BaseHandler):
"client_secret": client_secret,
"send_attempt": send_attempt,
}
+
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
if next_link:
params["next_link"] = next_link
@@ -414,7 +488,8 @@ class IdentityHandler(BaseHandler):
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
+ "%s/_matrix/identity/api/v1/validate/email/requestToken"
+ % (id_server_url,),
params,
)
return data
@@ -426,7 +501,7 @@ class IdentityHandler(BaseHandler):
async def requestMsisdnToken(
self,
- id_server: str,
+ id_server_url: str,
country: str,
phone_number: str,
client_secret: str,
@@ -437,7 +512,7 @@ class IdentityHandler(BaseHandler):
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
- id_server: The identity server to proxy to
+ id_server_url: The identity server to proxy to
country: The country code of the phone number
phone_number: The number to send the message to
            client_secret: The unique client_secret sent by the user
@@ -465,9 +540,13 @@ class IdentityHandler(BaseHandler):
"details and update your config file."
)
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
+ "%s/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ % (id_server_url,),
params,
)
except HttpResponseException as e:
@@ -559,6 +638,86 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
+ # TODO: The following two methods are used for proxying IS requests using
+ # the CS API. They should be consolidated with those in RoomMemberHandler
+ # https://github.com/matrix-org/synapse-dinsic/issues/25
+
+ async def proxy_lookup_3pid(
+ self, id_server: str, medium: str, address: str
+ ) -> JsonDict:
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server: The server name (including port, if required)
+ of the identity server to use.
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
+ {"medium": medium, "address": address},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
+ async def proxy_bulk_lookup_3pid(
+ self, id_server: str, threepids: List[List[str]]
+ ) -> JsonDict:
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server: The server name (including port, if required)
+ of the identity server to use.
+            threepids: The third party identifiers to look up, as a list
+                of two-element lists: [medium, address].
+
+ Returns:
+ The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
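
For reference, the v1 bulk lookup proxied above exchanges JSON of roughly the following shape (a sketch based on the r0.1.0 identity service spec linked in the docstring; the exact fields are up to the identity server):

```python
# Hypothetical request/response bodies for the proxied bulk lookup
# above, following the v1 identity service lookup format.
request_body = {
    "threepids": [
        ["email", "alice@example.com"],
        ["msisdn", "447700900001"],
    ]
}

# The identity server returns the triples it could resolve, with the
# bound Matrix ID appended to each [medium, address] pair.
example_response = {
    "threepids": [
        ["email", "alice@example.com", "@alice:example.com"],
    ]
}
```
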
async def lookup_3pid(
self,
id_server: str,
@@ -579,10 +738,13 @@ class IdentityHandler(BaseHandler):
Returns:
the matrix ID of the 3pid, or None if it is not recognized.
"""
+ # Rewrite id_server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
if id_access_token is not None:
try:
results = await self._lookup_3pid_v2(
- id_server, id_access_token, medium, address
+ id_server_url, id_access_token, medium, address
)
return results
@@ -600,16 +762,17 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return await self._lookup_3pid_v1(id_server, medium, address)
+ return await self._lookup_3pid_v1(id_server, id_server_url, medium, address)
async def _lookup_3pid_v1(
- self, id_server: str, medium: str, address: str
+ self, id_server: str, id_server_url: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server: The server name (including port, if required)
of the identity server to use.
+            id_server_url: The actual, reachable URL of the id server
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
@@ -617,8 +780,8 @@ class IdentityHandler(BaseHandler):
the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
{"medium": medium, "address": address},
)
@@ -635,13 +798,12 @@ class IdentityHandler(BaseHandler):
return None
async def _lookup_3pid_v2(
- self, id_server: str, id_access_token: str, medium: str, address: str
+ self, id_server_url: str, id_access_token: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
- id_server: The server name (including port, if required)
- of the identity server to use.
+ id_server_url: The protocol scheme and domain of the id server
id_access_token: The access token to authenticate to the identity server with
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
@@ -651,8 +813,8 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
+ hash_details = await self.http_client.get_json(
+ "%s/_matrix/identity/v2/hash_details" % (id_server_url,),
{"access_token": id_access_token},
)
except RequestTimedOutError:
@@ -660,15 +822,14 @@ class IdentityHandler(BaseHandler):
if not isinstance(hash_details, dict):
logger.warning(
- "Got non-dict object when checking hash details of %s%s: %s",
- id_server_scheme,
- id_server,
+ "Got non-dict object when checking hash details of %s: %s",
+ id_server_url,
hash_details,
)
raise SynapseError(
400,
- "Non-dict object from %s%s during v2 hash_details request: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Non-dict object from %s during v2 hash_details request: %s"
+ % (id_server_url, hash_details),
)
# Extract information from hash_details
@@ -682,8 +843,8 @@ class IdentityHandler(BaseHandler):
):
raise SynapseError(
400,
- "Invalid hash details received from identity server %s%s: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Invalid hash details received from identity server %s: %s"
+ % (id_server_url, hash_details),
)
# Check if any of the supported lookup algorithms are present
@@ -705,7 +866,7 @@ class IdentityHandler(BaseHandler):
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
- id_server,
+ id_server_url,
supported_lookup_algorithms,
)
raise SynapseError(
@@ -718,8 +879,8 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = await self.blacklisting_http_client.post_json_get_json(
- "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
+ lookup_results = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/v2/lookup" % (id_server_url,),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
@@ -807,15 +968,17 @@ class IdentityHandler(BaseHandler):
if self._web_client_location:
invite_config["org.matrix.web_client_location"] = self._web_client_location
+ # Rewrite the identity server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
- base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+ base_url = "%s/_matrix/identity" % (id_server_url,)
if id_access_token:
- key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (
+ id_server_url,
)
# Attempt a v2 lookup
@@ -834,9 +997,8 @@ class IdentityHandler(BaseHandler):
raise e
if data is None:
- key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+ id_server_url,
)
url = base_url + "/api/v1/store-invite"
@@ -848,10 +1010,7 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
- "Error trying to call /store-invite on %s%s: %s",
- id_server_scheme,
- id_server,
- e,
+ "Error trying to call /store-invite on %s: %s", id_server_url, e,
)
if data is None:
@@ -864,10 +1023,9 @@ class IdentityHandler(BaseHandler):
)
except HttpResponseException as e:
logger.warning(
- "Error calling /store-invite on %s%s with fallback "
+ "Error calling /store-invite on %s with fallback "
"encoding: %s",
- id_server_scheme,
- id_server,
+ id_server_url,
e,
)
raise e
@@ -888,6 +1046,42 @@ class IdentityHandler(BaseHandler):
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
+ async def bind_email_using_internal_sydent_api(
+ self, id_server_url: str, email: str, user_id: str,
+ ):
+ """Bind an email to a fully qualified user ID using the internal API of an
+ instance of Sydent.
+
+ Args:
+ id_server_url: The URL of the Sydent instance
+ email: The email address to bind
+ user_id: The user ID to bind the email to
+
+ Raises:
+ HTTPResponseException: On a non-2xx HTTP response.
+ """
+ # Extract the domain name from the IS URL as we store IS domains instead of URLs
+ id_server = urllib.parse.urlparse(id_server_url).hostname
+ if not id_server:
+ # We were unable to determine the hostname, bail out
+ return
+
+ # id_server_url is assumed to have no trailing slashes
+ url = id_server_url + "/_matrix/identity/internal/bind"
+ body = {
+ "address": email,
+ "medium": "email",
+ "mxid": user_id,
+ }
+
+ # Bind the threepid
+ await self.http_client.post_json_get_json(url, body)
+
+ # Remember where we bound the threepid
+ await self.store.add_user_bound_threepid(
+ user_id=user_id, medium="email", address=email, id_server=id_server,
+ )
+
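
A usage sketch of the helper above, assuming an async caller and a reachable Sydent instance (the URL, email, and user ID are illustrative):

```python
# Illustrative only (inside an async context). id_server_url must
# include a scheme and have no trailing slash, since the path is
# appended by plain concatenation; only the hostname is stored
# alongside the bound threepid.
await identity_handler.bind_email_using_internal_sydent_api(
    id_server_url="https://sydent.example.com",
    email="alice@example.com",
    user_id="@alice:example.com",
)
```
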
def create_id_access_token_header(id_access_token: str) -> List[str]:
"""Create an Authorization header for passing to SimpleHttpClient as the header value
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..f3694e0973 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,6 +41,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
+from synapse.config.api import DEFAULT_ROOM_STATE_TYPES
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
@@ -432,6 +434,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._external_cache = hs.get_external_cache()
+
async def create_event(
self,
requester: Requester,
@@ -492,7 +496,7 @@ class EventCreationHandler:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
- if membership in {Membership.JOIN, Membership.INVITE}:
+ if membership in {Membership.JOIN, Membership.INVITE, Membership.KNOCK}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
@@ -918,8 +922,8 @@ class EventCreationHandler:
room_version = await self.store.get_room_version_id(event.room_id)
if event.internal_metadata.is_out_of_band_membership():
- # the only sort of out-of-band-membership events we expect to see here
- # are invite rejections we have generated ourselves.
+ # the only sort of out-of-band-membership events we expect to see here are
+ # invite rejections and rescinded knocks that we have generated ourselves.
assert event.type == EventTypes.Member
assert event.content["membership"] == Membership.LEAVE
else:
@@ -939,6 +943,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context)
+ await self.cache_joined_hosts_for_event(event)
+
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +984,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
+ async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+ """Precalculate the joined hosts at the event, when using Redis, so that
+ external federation senders don't have to recalculate it themselves.
+ """
+
+ if not self._external_cache.is_enabled():
+ return
+
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We always set the state group -> joined hosts cache, even if
+ # we already set it, so that the expiry time is reset.
+
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+
+ if state_entry.state_group:
+ joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
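
The two-level mapping above lets a federation sender recover the joined hosts without redoing state resolution. A sketch of the read side, assuming an `ExternalCache.get` counterpart to the `set` calls used here:

```python
# Illustrative read side of the two-level cache populated above.
# Either level can expire independently, so both lookups may miss.
async def get_cached_joined_hosts(cache, event_id):
    state_group = await cache.get("event_to_prev_state_group", event_id)
    if state_group is None:
        return None  # first level expired: recompute via state handler
    # Second level: state group -> list of joined hosts.
    return await cache.get("get_joined_hosts", str(state_group))
```
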
+
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:
@@ -1125,6 +1169,13 @@ class EventCreationHandler:
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
+ if event.content["membership"] == Membership.KNOCK:
+ event.unsigned[
+ "knock_room_state"
+ ] = await self.store.get_stripped_room_state_from_event_context(
+ context, DEFAULT_ROOM_STATE_TYPES,
+ )
+
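
For context, the `knock_room_state` attached above is a list of stripped state events restricted to `DEFAULT_ROOM_STATE_TYPES`; roughly (contents illustrative):

```python
# Illustrative shape of the stripped state attached to a knock: each
# entry carries only type, state_key, sender, and content.
knock_room_state = [
    {
        "type": "m.room.name",
        "state_key": "",
        "sender": "@admin:example.com",
        "content": {"name": "Example room"},
    },
    {
        "type": "m.room.join_rules",
        "state_key": "",
        "sender": "@admin:example.com",
        "content": {"join_rule": "knock"},
    },
]
```
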
if event.type == EventTypes.Redaction:
original_event = await self.store.get_event(
event.redacts,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c02b951031..4b102ff9a9 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +15,11 @@
# limitations under the License.
import logging
import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
+
+from signedjson.sign import sign_json
+
+from twisted.internet import reactor
from synapse.api.errors import (
AuthError,
@@ -24,7 +29,11 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import (
JsonDict,
Requester,
@@ -54,6 +63,8 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -64,11 +75,98 @@ class ProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+ self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to
+
if hs.config.run_background_tasks:
self.clock.looping_call(
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._do_assign_profile_replication_batches)
+ reactor.callWhenRunning(self._start_replicate_profiles)
+            # Add a looping call to replicate_profiles: this handles retries
+            # if the replication was unsuccessful when a user updated their
+            # profile.
+ self.clock.looping_call(
+ self._start_replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ def _do_assign_profile_replication_batches(self):
+ return run_as_background_process(
+ "_assign_profile_replication_batches",
+ self._assign_profile_replication_batches,
+ )
+
+ def _start_replicate_profiles(self):
+ return run_as_background_process(
+ "_replicate_profiles", self._replicate_profiles
+ )
+
+ async def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = await self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ async def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = await self.store.get_replication_hosts()
+ latest_batch = await self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ await self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ async def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = await self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ await self.http_client.post_json_get_json(url, signed_body)
+ await self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info(
+ "Successfully replicated profile batch %d to %s", batchnum, host
+ )
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
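
The replication loop above advances each target host one batch at a time, from the host's last acknowledged batch to the newest allocated one. A condensed sketch of that bookkeeping (hosts and numbers illustrative):

```python
# Condensed sketch of the batch walk in _replicate_profiles above.
host_batches = {"id.example.com": 3}   # last batch pushed, per host
latest_batch = 6                       # newest batch number allocated

for host in ["id.example.com", "id2.example.com"]:
    last_pushed = host_batches.get(host, -1)
    for batchnum in range(last_pushed + 1, latest_batch + 1):
        # _replicate_host_profile_batch(host, batchnum) POSTs the
        # signed batch and records success; an exception aborts this
        # host until the next looping call retries it.
        pass
```
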
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
@@ -210,8 +308,16 @@ class ProfileHandler(BaseHandler):
target_user, authenticated_entity=requester.authenticated_entity,
)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
await self.store.set_profile_displayname(
- target_user.localpart, displayname_to_set
+ target_user.localpart, displayname_to_set, new_batchnum
)
if self.hs.config.user_directory_search_all_users:
@@ -222,6 +328,46 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ async def set_active(
+ self, users: List[UserID], active: bool, hide: bool,
+ ):
+ """
+ Sets the 'active' flag on a set of user profiles. If set to false, the
+ accounts are considered deactivated or hidden.
+
+        If 'hide' is true, then we interpret active=False as a request to try to
+        hide the users rather than deactivating them. This means withholding their
+        profiles from replication (and marking them as inactive) rather than
+        clearing the profiles from the HS DB.
+
+ Note that unlike set_displayname and set_avatar_url, this does *not*
+ perform authorization checks! This is because the only place it's used
+ currently is in account deactivation where we've already done these
+ checks anyway.
+
+ Args:
+ users: The users to modify
+ active: Whether to set the user to active or inactive
+            hide: Whether to hide the user (withhold from replication). If
+                False and active is False, the user will have their profile
+                erased
+ """
+ if len(self.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ await self.store.set_profiles_active(users, active, hide, new_batchnum)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user):
try:
@@ -290,14 +436,56 @@ class ProfileHandler(BaseHandler):
if new_avatar_url == "":
avatar_url_to_set = None
+ # Enforce a max avatar size if one is defined
+ if avatar_url_to_set and (
+ self.max_avatar_size or self.allowed_avatar_mimetypes
+ ):
+ media_id = self._validate_and_parse_media_id_from_avatar_url(
+ avatar_url_to_set
+ )
+
+ # Check that this media exists locally
+ media_info = await self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
        # Same as in set_displayname
if by_admin:
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
await self.store.set_profile_avatar_url(
- target_user.localpart, avatar_url_to_set
+ target_user.localpart, avatar_url_to_set, new_batchnum
)
if self.hs.config.user_directory_search_all_users:
@@ -308,6 +496,23 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
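
The parser above leans on the fixed shape of an mxc URL: splitting on "/" yields exactly four pieces. For example (illustrative input):

```python
# Why len(avatar_pieces) must be 4 for a well-formed mxc URL.
mxc = "mxc://example.com/SEsfnsuifSDFSSEF"
pieces = mxc.split("/")
# -> ["mxc:", "", "example.com", "SEsfnsuifSDFSSEF"]
assert len(pieces) == 4 and pieces[0] == "mxc:"
media_id = pieces[-1]  # "SEsfnsuifSDFSSEF"
```
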
async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 49b085269b..ab4d5ccc1c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -49,6 +49,7 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_identity_handler()
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
@@ -57,6 +58,8 @@ class RegistrationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -77,6 +80,16 @@ class RegistrationHandler(BaseHandler):
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
):
+ """
+
+ Args:
+ localpart (str|None): The user's localpart
+ guest_access_token (str|None): A guest's access token
+ assigned_user_id (str|None): An existing User ID for this user if pre-calculated
+
+ Returns:
+ Deferred
+ """
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -119,6 +132,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
+
+ # Retrieve guest user information from provided access token
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if (
not user_data.is_guest
@@ -237,6 +252,12 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
+ if default_display_name:
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
@@ -246,8 +267,6 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
- # If a default display name is not given, generate one.
- generate_display_name = default_display_name is None
# This breaks on successful registration *or* errors after 10 failures.
while True:
# Fail after being unable to find a suitable ID a few times
@@ -258,7 +277,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
- if generate_display_name:
+ if default_display_name is None:
default_display_name = localpart
try:
await self.register_with_store(
@@ -270,6 +289,11 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, default_display_name, by_admin=True
+ )
+
# Successfully registered
break
except SynapseError:
@@ -301,7 +325,15 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- await self._register_email_threepid(user_id, threepid_dict, None)
+ await self.register_email_threepid(user_id, threepid_dict, None)
+
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ await self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ await self.profile_handler.set_active([user], False, True)
return user_id
@@ -501,7 +533,10 @@ class RegistrationHandler(BaseHandler):
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart: str, as_token: str) -> str:
+ async def appservice_register(
+ self, user_localpart: str, as_token: str, password_hash: str, display_name: str
+ ):
+ # FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -518,12 +553,26 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
+ display_name = display_name or user.localpart
+
await self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
+ )
+
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, display_name, by_admin=True
)
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = await self.store.get_profileinfo(user_localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
return user_id
def check_user_id_not_appservice_exclusive(
@@ -552,6 +601,37 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
+ async def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+        other registration UI auth flows (e.g. validated 3pids).
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+                # that the displayname is correctly set by the master server
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
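
For reference, `shadow_register` above reads two keys from the `shadow_server` config. A hedged sketch of what that structure looks like (key names taken from the code above; values illustrative):

```python
# Illustrative shadow_server configuration consumed above; real
# values are deployment specific.
shadow_server = {
    "hs_url": "https://shadow.example.com",
    "as_token": "SECRET_APPSERVICE_TOKEN",
}

# The request target built in shadow_register above:
url = "%s/_matrix/client/r0/register?access_token=%s" % (
    shadow_server["hs_url"],
    shadow_server["as_token"],
)
```
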
def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
@@ -704,6 +784,7 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
+
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
@@ -711,7 +792,32 @@ class RegistrationHandler(BaseHandler):
):
await self.store.upsert_monthly_active_user(user_id)
- await self._register_email_threepid(user_id, threepid, access_token)
+ await self.register_email_threepid(user_id, threepid, access_token)
+
+ if self.hs.config.bind_new_user_emails_to_sydent:
+ # Attempt to call Sydent's internal bind API on the given identity server
+ # to bind this threepid
+ id_server_url = self.hs.config.bind_new_user_emails_to_sydent
+
+ logger.debug(
+ "Attempting the bind email of %s to identity server: %s using "
+ "internal Sydent bind API.",
+ user_id,
+ self.hs.config.bind_new_user_emails_to_sydent,
+ )
+
+ try:
+ await self.identity_handler.bind_email_using_internal_sydent_api(
+ id_server_url, threepid["address"], user_id
+ )
+ except Exception as e:
+ logger.warning(
+ "Failed to bind email of '%s' to Sydent instance '%s' ",
+ "using Sydent internal bind API: %s",
+ user_id,
+ id_server_url,
+ e,
+ )
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
@@ -731,7 +837,7 @@ class RegistrationHandler(BaseHandler):
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def _register_email_threepid(
+ async def register_email_threepid(
self, user_id: str, threepid: dict, token: Optional[str]
) -> None:
"""Add an email address as a 3pid identifier
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ee27d99135..d037742081 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
+ self._invite_burst_count = (
+ hs.config.ratelimiting.rc_invites_per_room.burst_count
+ )
+
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
) -> str:
@@ -359,7 +363,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not await self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -610,8 +626,14 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -662,6 +684,9 @@ class RoomCreationHandler(BaseHandler):
invite_3pid_list = []
invite_list = []
+ if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+ raise SynapseError(400, "Cannot invite so many users at once")
+
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
@@ -796,6 +821,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -813,6 +839,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
id_access_token=id_access_token,
)
@@ -890,6 +917,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=ratelimit,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 14f14db449..373b9dcd0d 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -170,6 +170,7 @@ class RoomListHandler(BaseHandler):
"world_readable": room["history_visibility"]
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
+ "join_rule": room["join_rules"],
}
# Filter out Nones – rather omit the field altogether
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e001e418f9..a92f7ba012 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import abc
import logging
import random
@@ -31,7 +31,15 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
+from synapse.types import (
+ JsonDict,
+ Requester,
+ RoomAlias,
+ RoomID,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -85,6 +93,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
+ self._invites_per_room_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+ )
+ self._invites_per_user_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+ )
+
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
@@ -111,6 +130,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def remote_knock(
+ self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ ) -> Tuple[str, int]:
+ """Try and knock on a room that this server is not in
+
+ Args:
+ remote_room_hosts: List of servers that can be used to knock via.
+ room_id: Room that we are trying to knock on.
+ user: User who is trying to knock.
+ content: A dict that should be used as the content of the knock event.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
async def remote_reject_invite(
self,
invite_event_id: str,
@@ -134,6 +167,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Rescind a local knock made on a remote room.
+
+ Args:
+ knock_event_id: The ID of the knock event to rescind.
+ txn_id: An optional transaction ID supplied by the client.
+ requester: The user making the request, according to the access token.
+ content: The content of the generated leave event.
+
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
@@ -144,6 +198,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
+ def ratelimit_invite(self, room_id: str, invitee_user_id: str):
+ """Ratelimit invites by room and by target user.
+ """
+ self._invites_per_room_limiter.ratelimit(room_id)
+ self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
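
An invite therefore has to pass two independent buckets: one keyed on the room, one on the invitee. A toy sketch of that double limit (the stand-in limiter below is illustrative, not synapse's `Ratelimiter`):

```python
from collections import defaultdict

class ToyLimiter:
    """Counting stand-in for a token-bucket rate limiter."""
    def __init__(self, burst_count: int):
        self.burst_count = burst_count
        self.counts = defaultdict(int)

    def ratelimit(self, key: str) -> None:
        self.counts[key] += 1
        if self.counts[key] > self.burst_count:
            raise RuntimeError("rate limit exceeded for %s" % key)

per_room = ToyLimiter(burst_count=10)   # hypothetical config values
per_user = ToyLimiter(burst_count=5)

def ratelimit_invite(room_id: str, invitee: str) -> None:
    # Both buckets must permit the invite; whichever is exhausted
    # first rejects it.
    per_room.ratelimit(room_id)
    per_user.ratelimit(invitee)
```
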
async def _local_membership_update(
self,
requester: Requester,
@@ -279,6 +339,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -319,6 +380,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
@@ -335,6 +397,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
"""Helper for update_membership.
@@ -387,8 +450,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(403, "This room has been blocked on this server")
if effective_membership_state == Membership.INVITE:
+ target_id = target.to_string()
+ if ratelimit:
+ self.ratelimit_invite(room_id, target_id)
+
# block any attempts to invite the server notices mxid
- if target.to_string() == self._server_notices_mxid:
+ if target_id == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
block_invite = False
@@ -411,8 +478,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
+ is_published = await self.store.is_room_published(room_id)
+
if not await self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(),
+ target_id,
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -490,6 +564,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ inviter = await self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if not is_host_in_room:
if ratelimit:
time_now_s = self.clock.time()
@@ -527,50 +620,76 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
- # perhaps we've been invited
+ # Figure out the user's current membership state for the room
(
current_membership_type,
current_membership_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
target.to_string(), room_id
)
- if (
- current_membership_type != Membership.INVITE
- or not current_membership_event_id
- ):
+ if not current_membership_type or not current_membership_event_id:
logger.info(
"%s sent a leave request to %s, but that is not an active room "
- "on this server, and there is no pending invite",
+ "on this server, or there is no pending invite or knock",
target,
room_id,
)
raise SynapseError(404, "Not a known room")
- invite = await self.store.get_event(current_membership_event_id)
- logger.info(
- "%s rejects invite to %s from %s", target, room_id, invite.sender
- )
+ # perhaps we've been invited
+ if current_membership_type == Membership.INVITE:
+ invite = await self.store.get_event(current_membership_event_id)
+ logger.info(
+ "%s rejects invite to %s from %s",
+ target,
+ room_id,
+ invite.sender,
+ )
+
+ if not self.hs.is_mine_id(invite.sender):
+ # send the rejection to the inviter's HS (with fallback to
+ # local event)
+ return await self.remote_reject_invite(
+ invite.event_id, txn_id, requester, content,
+ )
- if not self.hs.is_mine_id(invite.sender):
- # send the rejection to the inviter's HS (with fallback to
- # local event)
- return await self.remote_reject_invite(
- invite.event_id, txn_id, requester, content,
+ # the inviter was on our server, but has now left. Carry on
+ # with the normal rejection codepath, which will also send the
+ # rejection out to any other servers we believe are still in the room.
+
+ # thanks to overzealous cleaning up of event_forward_extremities in
+ # `delete_old_current_state_events`, it's possible to end up with no
+ # forward extremities here. If that happens, let's just hang the
+ # rejection off the invite event.
+ #
+ # see: https://github.com/matrix-org/synapse/issues/7139
+ if len(latest_event_ids) == 0:
+ latest_event_ids = [invite.event_id]
+
+ # or perhaps this is a remote room that a local user has knocked on
+ elif current_membership_type == Membership.KNOCK:
+ knock = await self.store.get_event(current_membership_event_id)
+ return await self.remote_rescind_knock(
+ knock.event_id, txn_id, requester, content
)
- # the inviter was on our server, but has now left. Carry on
- # with the normal rejection codepath, which will also send the
- # rejection out to any other servers we believe are still in the room.
+ elif effective_membership_state == Membership.KNOCK:
+ if not is_host_in_room:
+ # The knock needs to be sent over federation instead
+ remote_room_hosts.append(get_domain_from_id(room_id))
- # thanks to overzealous cleaning up of event_forward_extremities in
- # `delete_old_current_state_events`, it's possible to end up with no
- # forward extremities here. If that happens, let's just hang the
- # rejection off the invite event.
- #
- # see: https://github.com/matrix-org/synapse/issues/7139
- if len(latest_event_ids) == 0:
- latest_event_ids = [invite.event_id]
+ content["membership"] = Membership.KNOCK
+
+ profile = self.profile_handler
+ if "displayname" not in content:
+ content["displayname"] = await profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = await profile.get_avatar_url(target)
+
+ return await self.remote_knock(
+ remote_room_hosts, room_id, target, content
+ )
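
The knock sent over federation above is a membership event whose content carries the knock membership plus the sender's profile, injected when the client omitted it. Roughly (values illustrative):

```python
# Illustrative content of the remote knock assembled above.
content = {
    "membership": "knock",
    "displayname": "Alice",                   # filled from profile
    "avatar_url": "mxc://example.com/abc12",  # filled from profile
}
# remote_knock() then asks each candidate host in turn to send the
# knock into the room on our behalf.
```
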
return await self._local_membership_update(
requester=requester,
@@ -786,6 +905,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
id_server: str,
requester: Requester,
txn_id: Optional[str],
+ new_room: bool = False,
id_access_token: Optional[str] = None,
) -> int:
"""Invite a 3PID to a room.
@@ -833,6 +953,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Codes.FORBIDDEN,
)
+ can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
+
if not self._enable_lookup:
raise SynapseError(
403, "Looking up third-party identifiers is denied from this server"
@@ -842,6 +972,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
id_server, medium, address, id_access_token
)
+ is_published = await self.store.is_room_published(room_id)
+
+ if not await self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
# Note that update_membership with an action of "invite" can raise
# a ShadowBanError, but this was done above already.
@@ -1131,6 +1274,35 @@ class RoomMemberMasterHandler(RoomMemberHandler):
invite_event, txn_id, requester, content
)
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rescinds a local knock made on a remote room
+
+ Args:
+ knock_event_id: The ID of the knock event to rescind.
+ txn_id: The transaction ID to use.
+ requester: The originator of the request.
+ content: The content of the leave event.
+
+ Implements RoomMemberHandler.remote_rescind_knock
+ """
+ # TODO: We don't yet support rescinding knocks over federation
+ # as we don't know which homeserver to send it to. An obvious
+ # candidate is the remote homeserver we originally knocked through,
+ # however we don't currently store that information.
+
+ # Just rescind the knock locally
+ knock_event = await self.store.get_event(knock_event_id)
+ return await self._generate_local_out_of_band_leave(
+ knock_event, txn_id, requester, content
+ )
+
async def _generate_local_out_of_band_leave(
self,
previous_membership_event: EventBase,
@@ -1191,6 +1363,32 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return result_event.event_id, result_event.internal_metadata.stream_ordering
+ async def remote_knock(
+ self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ ) -> Tuple[str, int]:
+ """Sends a knock to a room. Attempts to do so via one remote out of a given list.
+
+ Args:
+ remote_room_hosts: A list of homeservers to try knocking through.
+ room_id: The ID of the room to knock on.
+ user: The user to knock on behalf of.
+ content: The content of the knock event.
+
+ Returns:
+ A tuple of (event ID, stream ID).
+ """
+ # filter ourselves out of remote_room_hosts
+ remote_room_hosts = [
+ host for host in remote_room_hosts if host != self.hs.hostname
+ ]
+
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ return await self.federation_handler.do_knock(
+ remote_room_hosts, room_id, user.to_string(), content=content
+ )
+
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index f2e88f6a5b..3de63e885e 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,10 +21,12 @@ from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
+ ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
+ ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
logger = logging.getLogger(__name__)
@@ -33,7 +36,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
+ self._remote_knock_client = ReplRemoteKnock.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs)
+ self._remote_rescind_client = ReplRescindKnock.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
async def _remote_join(
@@ -79,6 +84,49 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
)
return ret["event_id"], ret["stream_id"]
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rescinds a local knock made on a remote room
+
+ Args:
+ knock_event_id: the knock event
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the request, according to the access token
+ content: additional content to include in the leave event.
+ Normally an empty dict.
+
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event)
+ """
+ ret = await self._remote_rescind_client(
+ knock_event_id=knock_event_id,
+ txn_id=txn_id,
+ requester=requester,
+ content=content,
+ )
+ return ret["event_id"], ret["stream_id"]
+
+ async def remote_knock(
+ self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ ) -> Tuple[str, int]:
+ """Sends a knock to a room.
+
+ Implements RoomMemberHandler.remote_knock
+ """
+ ret = await self._remote_knock_client(
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user=user,
+ content=content,
+ )
+ return ret["event_id"], ret["stream_id"]
+
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
import itertools
import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
return historical_room_ids
- async def search(self, user, content, batch=None):
+ async def search(
+ self, user: UserID, content: JsonDict, batch: Optional[str] = None
+ ) -> JsonDict:
"""Performs a full text search for a user.
Args:
- user (UserID)
- content (dict): Search parameters
- batch (str): The next_batch parameter. Used for pagination.
+ user
+ content: Search parameters
+ batch: The next_batch parameter. Used for pagination.
Returns:
dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
        # If doing a subset of all rooms search, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
- historical_room_ids = []
+ historical_room_ids = [] # type: List[str]
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
- room_groups = {} # Holds result of grouping by room, if applicable
- sender_group = {} # Holds result of grouping by sender, if applicable
+ # Holds result of grouping by room, if applicable
+ room_groups = {} # type: Dict[str, JsonDict]
+ # Holds result of grouping by sender, if applicable
+ sender_group = {} # type: Dict[str, JsonDict]
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
- room_events = []
+ room_events = [] # type: List[EventBase]
i = 0
pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
state_results = {}
if include_state:
- rooms = {e.room_id for e in allowed_events}
- for room_id in rooms:
+ for room_id in {e.room_id for e in allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
- state_results.values()
-
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
- for room_id, state in state_results.items():
+ for room_id, state_events in state_results.items():
s[room_id] = await self._event_serializer.serialize_events(
- state, time_now
+ state_events, time_now
)
rooms_cat_res["state"] = s
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..cef6b3ae48 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,24 +14,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- self._password_policy_handler = hs.get_password_policy_handler()
async def set_password(
self,
@@ -38,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
password_hash: str,
logout_devices: bool,
requester: Optional[Requester] = None,
- ):
+ ) -> None:
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class StateDeltasHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ async def _get_key_change(
+ self,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ key_name: str,
+ public_value: str,
+ ) -> Optional[bool]:
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
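The semantics are easy to misread, so here is a standalone sketch of the check (not the actual implementation, which fetches the two events by ID from the store before comparing their content):

    from typing import Optional

    def _key_change_sketch(prev_content, content, key_name, public_value) -> Optional[bool]:
        # None-safe lookups: either event may be absent
        prev_match = bool(prev_content) and prev_content.get(key_name) == public_value
        now_match = bool(content) and content.get(key_name) == public_value
        if prev_match == now_match:
            return None       # no change in either direction
        return now_match      # True: now matches; False: stopped matching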
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..0b5e62da1b 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2020 Sorunome
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +39,7 @@ class StatsHandler:
Heavily derived from UserDirectoryHandler
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -44,7 +52,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
- self.pos = None
+ self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -56,7 +64,7 @@ class StatsHandler:
# we start populating stats
self.clock.call_later(0, self.notify_new_event)
- def notify_new_event(self):
+ def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.stats_enabled or self._is_processing:
@@ -72,7 +80,7 @@ class StatsHandler:
run_as_background_process("stats.notify_new_event", process)
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
# If self.pos is None then it means we haven't fetched it from the DB
if self.pos is None:
self.pos = await self.store.get_stats_positions()
@@ -110,10 +118,10 @@ class StatsHandler:
)
for room_id, fields in room_count.items():
- room_deltas.setdefault(room_id, {}).update(fields)
+ room_deltas.setdefault(room_id, Counter()).update(fields)
for user_id, fields in user_count.items():
- user_deltas.setdefault(user_id, {}).update(fields)
+ user_deltas.setdefault(user_id, Counter()).update(fields)
logger.debug("room_deltas: %s", room_deltas)
logger.debug("user_deltas: %s", user_deltas)
@@ -131,19 +139,20 @@ class StatsHandler:
self.pos = max_pos
- async def _handle_deltas(self, deltas):
+ async def _handle_deltas(
+ self, deltas: Iterable[JsonDict]
+ ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
"""Called with the state deltas to process
Returns:
- tuple[dict[str, Counter], dict[str, counter]]
Two dicts: the room deltas and the user deltas,
mapping from room/user ID to changes in the various fields.
"""
- room_to_stats_deltas = {}
- user_to_stats_deltas = {}
+ room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
+ user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
- room_to_state_updates = {}
+ room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
for delta in deltas:
typ = delta["type"]
@@ -173,7 +182,7 @@ class StatsHandler:
)
continue
- event_content = {}
+ event_content = {} # type: JsonDict
sender = None
if event_id is not None:
@@ -225,6 +234,8 @@ class StatsHandler:
room_stats_delta["left_members"] -= 1
elif prev_membership == Membership.BAN:
room_stats_delta["banned_members"] -= 1
+ elif prev_membership == Membership.KNOCK:
+ room_stats_delta["knocked_members"] -= 1
else:
raise ValueError(
"%r is not a valid prev_membership" % (prev_membership,)
@@ -246,6 +257,8 @@ class StatsHandler:
room_stats_delta["left_members"] += 1
elif membership == Membership.BAN:
room_stats_delta["banned_members"] += 1
+ elif membership == Membership.KNOCK:
+ room_stats_delta["knocked_members"] += 1
else:
raise ValueError("%r is not a valid membership" % (membership,))
@@ -257,13 +270,13 @@ class StatsHandler:
)
if has_changed_joinedness:
- delta = +1 if membership == Membership.JOIN else -1
+ membership_delta = +1 if membership == Membership.JOIN else -1
user_to_stats_deltas.setdefault(user_id, Counter())[
"joined_rooms"
- ] += delta
+ ] += membership_delta
- room_stats_delta["local_users_in_room"] += delta
+ room_stats_delta["local_users_in_room"] += membership_delta
elif typ == EventTypes.Create:
room_state["is_federatable"] = (
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5c7590f38e..e8947e0f9b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -151,6 +151,16 @@ class InvitedSyncResult:
@attr.s(slots=True, frozen=True)
+class KnockedSyncResult:
+ room_id = attr.ib(type=str)
+ knock = attr.ib(type=EventBase)
+
+ def __bool__(self) -> bool:
+ """Knocked rooms should always be reported to the client"""
+ return True
+
+
+@attr.s(slots=True, frozen=True)
class GroupsSyncResult:
join = attr.ib(type=JsonDict)
invite = attr.ib(type=JsonDict)
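The explicit __bool__ mirrors InvitedSyncResult: sections of the sync response are dropped when falsy, and knock entries must always survive that check. A minimal standalone illustration (the class name is a stand-in):

    import attr

    @attr.s(slots=True, frozen=True)
    class _Result:
        room_id = attr.ib(type=str)

        def __bool__(self) -> bool:
            return True  # always report, regardless of contents

    assert bool(_Result(room_id="!room:example.org"))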
@@ -183,6 +193,7 @@ class _RoomChanges:
room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
invited = attr.ib(type=List[InvitedSyncResult])
+ knocked = attr.ib(type=List[KnockedSyncResult])
newly_joined_rooms = attr.ib(type=List[str])
newly_left_rooms = attr.ib(type=List[str])
@@ -196,6 +207,7 @@ class SyncResult:
account_data: List of account_data events for the user.
joined: JoinedSyncResult for each joined room.
invited: InvitedSyncResult for each invited room.
+ knocked: KnockedSyncResult for each knocked on room.
archived: ArchivedSyncResult for each archived room.
to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed
@@ -211,6 +223,7 @@ class SyncResult:
account_data = attr.ib(type=List[JsonDict])
joined = attr.ib(type=List[JoinedSyncResult])
invited = attr.ib(type=List[InvitedSyncResult])
+ knocked = attr.ib(type=List[KnockedSyncResult])
archived = attr.ib(type=List[ArchivedSyncResult])
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
@@ -227,6 +240,7 @@ class SyncResult:
self.presence
or self.joined
or self.invited
+ or self.knocked
or self.archived
or self.account_data
or self.to_device
@@ -999,7 +1013,7 @@ class SyncHandler:
res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
+ newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@@ -1008,7 +1022,9 @@ class SyncHandler:
if self.hs_config.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
- sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
+ sync_result_builder,
+ newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users,
)
logger.debug("Fetching to-device data")
@@ -1017,7 +1033,7 @@ class SyncHandler:
device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
- newly_joined_or_invited_users=newly_joined_or_invited_users,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@@ -1051,6 +1067,7 @@ class SyncHandler:
account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined,
invited=sync_result_builder.invited,
+ knocked=sync_result_builder.knocked,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
@@ -1110,7 +1127,7 @@ class SyncHandler:
self,
sync_result_builder: "SyncResultBuilder",
newly_joined_rooms: Set[str],
- newly_joined_or_invited_users: Set[str],
+ newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
) -> DeviceLists:
@@ -1119,8 +1136,9 @@ class SyncHandler:
Args:
sync_result_builder
newly_joined_rooms: Set of rooms user has joined since previous sync
- newly_joined_or_invited_users: Set of users that have joined or
- been invited to a room since previous sync.
+ newly_joined_or_invited_or_knocked_users: Set of users that have joined,
+ been invited to a room or are knocking on a room since
+ previous sync.
newly_left_rooms: Set of rooms user has left since previous sync
newly_left_users: Set of users that have left a room we're in since
previous sync
@@ -1131,7 +1149,9 @@ class SyncHandler:
# We're going to mutate these fields, so lets copy them rather than
# assume they won't get used later.
- newly_joined_or_invited_users = set(newly_joined_or_invited_users)
+ newly_joined_or_invited_or_knocked_users = set(
+ newly_joined_or_invited_or_knocked_users
+ )
newly_left_users = set(newly_left_users)
if since_token and since_token.device_list_key:
@@ -1170,11 +1190,11 @@ class SyncHandler:
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
joined_users = await self.state.get_current_users_in_room(room_id)
- newly_joined_or_invited_users.update(joined_users)
+ newly_joined_or_invited_or_knocked_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
- users_that_have_changed.update(newly_joined_or_invited_users)
+ users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
user_signatures_changed = await self.store.get_users_whose_signatures_changed(
user_id, since_token.device_list_key
@@ -1419,6 +1439,7 @@ class SyncHandler:
room_entries = room_changes.room_entries
invited = room_changes.invited
+ knocked = room_changes.knocked
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
@@ -1439,9 +1460,10 @@ class SyncHandler:
await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
+ sync_result_builder.knocked.extend(knocked)
- # Now we want to get any newly joined or invited users
- newly_joined_or_invited_users = set()
+ # Now we want to get any newly joined, invited or knocking users
+ newly_joined_or_invited_or_knocked_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@@ -1453,19 +1475,22 @@ class SyncHandler:
if (
event.membership == Membership.JOIN
or event.membership == Membership.INVITE
+ or event.membership == Membership.KNOCK
):
- newly_joined_or_invited_users.add(event.state_key)
+ newly_joined_or_invited_or_knocked_users.add(
+ event.state_key
+ )
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
- newly_left_users -= newly_joined_or_invited_users
+ newly_left_users -= newly_joined_or_invited_or_knocked_users
return (
set(newly_joined_rooms),
- newly_joined_or_invited_users,
+ newly_joined_or_invited_or_knocked_users,
set(newly_left_rooms),
newly_left_users,
)
@@ -1521,6 +1546,7 @@ class SyncHandler:
newly_left_rooms = []
room_entries = []
invited = []
+ knocked = []
for room_id, events in mem_change_events_by_room_id.items():
logger.debug(
"Membership changes in %s: [%s]",
@@ -1600,9 +1626,17 @@ class SyncHandler:
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
if event.sender not in ignored_users:
- room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
- if room_sync:
- invited.append(room_sync)
+ invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+ if invite_room_sync:
+ invited.append(invite_room_sync)
+
+ # Only bother if our latest membership in the room is knock (and we haven't
+ # been accepted/rejected in the meantime).
+ should_knock = non_joins[-1].membership == Membership.KNOCK
+ if should_knock:
+ knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
+ if knock_room_sync:
+ knocked.append(knock_room_sync)
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
@@ -1706,7 +1740,9 @@ class SyncHandler:
)
room_entries.append(entry)
- return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
+ return _RoomChanges(
+ room_entries, invited, knocked, newly_joined_rooms, newly_left_rooms,
+ )
async def _get_all_rooms(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
@@ -1726,6 +1762,7 @@ class SyncHandler:
membership_list = (
Membership.INVITE,
+ Membership.KNOCK,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN,
@@ -1737,6 +1774,7 @@ class SyncHandler:
room_entries = []
invited = []
+ knocked = []
for event in room_list:
if event.membership == Membership.JOIN:
@@ -1756,8 +1794,11 @@ class SyncHandler:
continue
invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
+ elif event.membership == Membership.KNOCK:
+ knock = await self.store.get_event(event.event_id)
+ knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock))
elif event.membership in (Membership.LEAVE, Membership.BAN):
- # Always send down rooms we were banned or kicked from.
+ # Always send down rooms we were banned from or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if user_id == event.sender:
@@ -1778,7 +1819,7 @@ class SyncHandler:
)
)
- return _RoomChanges(room_entries, invited, [], [])
+ return _RoomChanges(room_entries, invited, knocked, [], [])
async def _generate_room_entry(
self,
@@ -2067,6 +2108,7 @@ class SyncResultBuilder:
account_data (list)
joined (list[JoinedSyncResult])
invited (list[InvitedSyncResult])
+ knocked (list[KnockedSyncResult])
archived (list[ArchivedSyncResult])
groups (GroupsSyncResult|None)
to_device (list)
@@ -2082,6 +2124,7 @@ class SyncResultBuilder:
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
+ knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..3f0dfc7a74 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
import logging
import random
from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
)
# map room IDs to serial numbers
- self._room_serials = {}
+ self._room_serials = {} # type: Dict[str, int]
# map room IDs to sets of users currently typing
- self._room_typing = {}
+ self._room_typing = {} # type: Dict[str, Set[str]]
- self._member_last_federation_poke = {}
+ self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)
- def _reset(self):
+ def _reset(self) -> None:
"""Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
- def _handle_timeouts(self):
+ def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
for member in members:
self._handle_timeout_for_member(now, member)
- def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
# each person typing.
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
- def is_typing(self, member):
+ def is_typing(self, member: RoomMember) -> bool:
return member.user_id in self._room_typing.get(member.room_id, [])
- async def _push_remote(self, member, typing):
+ async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
return
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
- ):
+ ) -> None:
"""Should be called whenever we receive updates for typing stream.
"""
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
async def _send_changes_in_typing_to_remotes(
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
- ):
+ ) -> None:
"""Process a change in typing of a room from replication, sending EDUs
for any local users.
"""
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
if self.is_mine_id(user_id):
await self._push_remote(RoomMember(room_id, user_id), False)
- def get_current_token(self):
+ def get_current_token(self) -> int:
return self._latest_room_serial
class TypingWriterHandler(FollowerTypingHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- self._member_typing_until = {} # clock time we expect to stop
+ # clock time we expect to stop
+ self._member_typing_until = {} # type: Dict[RoomMember, int]
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", self._latest_room_serial
)
- def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
super()._handle_timeout_for_member(now, member)
if not self.is_typing(member):
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
return
- async def started_typing(self, target_user, requester, room_id, timeout):
+ async def started_typing(
+ self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+ ) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
if was_present:
# No point sending another notification
- return None
+ return
self._push_update(member=member, typing=True)
- async def stopped_typing(self, target_user, requester, room_id):
+ async def stopped_typing(
+ self, target_user: UserID, requester: Requester, room_id: str
+ ) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
- def user_left_room(self, user, room_id):
+ def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
self._stopped_typing(member)
- def _stopped_typing(self, member):
+ def _stopped_typing(self, member: RoomMember) -> None:
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
- return None
+ return
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
self._push_update(member=member, typing=False)
- def _push_update(self, member, typing):
+ def _push_update(self, member: RoomMember, typing: bool) -> None:
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_as_background_process(
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self._push_update_local(member=member, typing=typing)
- async def _recv_edu(self, origin, content):
+ async def _recv_edu(self, origin: str, content: JsonDict) -> None:
room_id = content["room_id"]
user_id = content["user_id"]
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
self._push_update_local(member=member, typing=content["typing"])
- def _push_update_local(self, member, typing):
+ def _push_update_local(self, member: RoomMember, typing: bool) -> None:
room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
room_set.add(member.user_id)
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
last_id
- )
+ ) # type: Optional[Iterable[str]]
if changed_rooms is None:
changed_rooms = self._room_serials
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
- ):
+ ) -> None:
# The writing process should never get updates from replication.
raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
# We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
#
self.get_typing_handler = hs.get_typing_handler
- def _make_event_for(self, room_id):
+ def _make_event_for(self, room_id: str) -> JsonDict:
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": "m.typing",
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
- async def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(
+ self, from_key: int, room_ids: Iterable[str], **kwargs
+ ) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
- def get_current_key(self):
+ def get_current_key(self) -> int:
return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d4651c8348..8aedf5072e 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
- # If still None then the initial background update hasn't happened yet
- if self.pos is None:
- return None
-
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
@@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
if change: # The user joined
event = await self.store.get_event(event_id, allow_none=True)
+ # We don't expect this event to be missing, but we don't
+ # want the entire background process to break if it is.
+ if event is None:
+ continue
+
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 37ccf5ab98..8eb93ba73e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -289,8 +289,7 @@ class SimpleHttpClient:
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
- http_proxy: Optional[bytes] = None,
- https_proxy: Optional[bytes] = None,
+ use_proxy: bool = False,
):
"""
Args:
@@ -300,8 +299,8 @@ class SimpleHttpClient:
we may not request.
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
- http_proxy: proxy server to use for http connections. host[:port]
- https_proxy: proxy server to use for https connections. host[:port]
+ use_proxy: Whether proxy settings should be discovered from
+ conventional environment variables and used.
"""
self.hs = hs
@@ -345,8 +344,7 @@ class SimpleHttpClient:
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
- http_proxy=http_proxy,
- https_proxy=https_proxy,
+ use_proxy=use_proxy,
)
if self._ip_blacklist:
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 856e28454f..b797e3ce80 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -19,9 +19,10 @@ from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.internet.protocol import connectionDone
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
from twisted.web import http
+from twisted.web.http_headers import Headers
logger = logging.getLogger(__name__)
@@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint:
Args:
reactor: the Twisted reactor to use for the connection
- proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
- proxy
- host (bytes): hostname that we want to CONNECT to
- port (int): port that we want to connect to
+ proxy_endpoint: the endpoint to use to connect to the proxy
+ host: hostname that we want to CONNECT to
+ port: port that we want to connect to
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, reactor, proxy_endpoint, host, port):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ proxy_endpoint: IStreamClientEndpoint,
+ host: bytes,
+ port: int,
+ headers: Headers,
+ ):
self._reactor = reactor
self._proxy_endpoint = proxy_endpoint
self._host = host
self._port = port
+ self._headers = headers
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
- def connect(self, protocolFactory):
- f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ def connect(self, protocolFactory: ClientFactory):
+ f = HTTPProxiedClientFactory(
+ self._host, self._port, protocolFactory, self._headers
+ )
d = self._proxy_endpoint.connect(f)
# once the tcp socket connects successfully, we need to wait for the
# CONNECT to complete.
@@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
HTTP Protocol object and run the rest of the connection.
Args:
- dst_host (bytes): hostname that we want to CONNECT to
- dst_port (int): port that we want to connect to
- wrapped_factory (protocol.ClientFactory): The original Factory
+ dst_host: hostname that we want to CONNECT to
+ dst_port: port that we want to connect to
+ wrapped_factory: The original Factory
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, dst_host, dst_port, wrapped_factory):
+ def __init__(
+ self,
+ dst_host: bytes,
+ dst_port: int,
+ wrapped_factory: ClientFactory,
+ headers: Headers,
+ ):
self.dst_host = dst_host
self.dst_port = dst_port
self.wrapped_factory = wrapped_factory
+ self.headers = headers
self.on_connection = defer.Deferred()
def startedConnecting(self, connector):
@@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
return HTTPConnectProtocol(
- self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ self.dst_host,
+ self.dst_port,
+ wrapped_protocol,
+ self.on_connection,
+ self.headers,
)
def clientConnectionFailed(self, connector, reason):
@@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol):
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
Args:
- host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ host: The original HTTP(s) hostname or IPv4 or IPv6 address literal
to put in the CONNECT request
- port (int): The original HTTP(s) port to put in the CONNECT request
+ port: The original HTTP(s) port to put in the CONNECT request
- wrapped_protocol (interfaces.IProtocol): the original protocol (probably
- HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+ wrapped_protocol: the original protocol (probably HTTPChannel or
+ TLSMemoryBIOProtocol, but could be anything really)
- connected_deferred (Deferred): a Deferred which will be callbacked with
+ connected_deferred: a Deferred which will be callbacked with
wrapped_protocol when the CONNECT completes
+
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ def __init__(
+ self,
+ host: bytes,
+ port: int,
+ wrapped_protocol: Protocol,
+ connected_deferred: defer.Deferred,
+ headers: Headers,
+ ):
self.host = host
self.port = port
self.wrapped_protocol = wrapped_protocol
self.connected_deferred = connected_deferred
- self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.headers = headers
+
+ self.http_setup_client = HTTPConnectSetupClient(
+ self.host, self.port, self.headers
+ )
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
def connectionMade(self):
@@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol):
if buf:
self.wrapped_protocol.dataReceived(buf)
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes):
# if we've set up the HTTP protocol, we can send the data there
if self.wrapped_protocol.connected:
return self.wrapped_protocol.dataReceived(data)
@@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient):
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
Args:
- host (bytes): The hostname to send in the CONNECT message
- port (int): The port to send in the CONNECT message
+ host: The hostname to send in the CONNECT message
+ port: The port to send in the CONNECT message
+ headers: Extra headers to send with the CONNECT message
"""
- def __init__(self, host, port):
+ def __init__(self, host: bytes, port: int, headers: Headers):
self.host = host
self.port = port
+ self.headers = headers
self.on_connected = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+
+ # Send any additional specified headers
+ for name, values in self.headers.getAllRawHeaders():
+ for value in values:
+ self.sendHeader(name, value)
+
self.endHeaders()
- def handleStatus(self, version, status, message):
+ def handleStatus(self, version: bytes, status: bytes, message: bytes):
logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200":
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
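On the wire, the extra headers land in the CONNECT preamble before the tunnel is established. A rough sketch (credentials hypothetical):

    from twisted.web.http_headers import Headers

    headers = Headers()
    headers.addRawHeader(b"Proxy-Authorization", b"Basic dXNlcjpwYXNz")

    # connectionMade then emits, roughly:
    #   CONNECT matrix.org:443 HTTP/1.0
    #   Proxy-Authorization: Basic dXNlcjpwYXNz
    # followed by a blank line; a 200 response is expected before the
    # wrapped protocol (e.g. TLS) starts through the tunnel.
    for name, values in headers.getAllRawHeaders():
        for value in values:
            print(name, value)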
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index b730d2c634..ee65a6668b 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -12,9 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import logging
import re
+from typing import Optional, Tuple
+from urllib.request import getproxies_environment, proxy_bypass_environment
+import attr
from zope.interface import implementer
from twisted.internet import defer
@@ -22,6 +26,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.python.failure import Failure
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
from twisted.web.error import SchemeNotSupported
+from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
@@ -31,6 +36,22 @@ logger = logging.getLogger(__name__)
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+@attr.s
+class ProxyCredentials:
+ username_password = attr.ib(type=bytes)
+
+ def as_proxy_authorization_value(self) -> bytes:
+ """
+ Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
+
+ Returns:
+ The base64-encoded authentication string, prefixed with the
+ authorization type, for use as a Proxy-Authorization header value.
+ """
+ # Encode as base64 and prepend the authorization type
+ return b"Basic " + base64.encodebytes(self.username_password)
+
+
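One hedged observation on the encoding choice: base64.encodebytes appends a trailing newline (and wraps input every 76 characters), whereas base64.b64encode does not, so the consumer of this value has to tolerate that newline:

    import base64

    print(base64.encodebytes(b"user:pass"))  # b'dXNlcjpwYXNz\n'
    print(base64.b64encode(b"user:pass"))    # b'dXNlcjpwYXNz'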
@implementer(IAgent)
class ProxyAgent(_AgentBase):
"""An Agent implementation which will use an HTTP proxy if one was requested
@@ -58,6 +79,9 @@ class ProxyAgent(_AgentBase):
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
non-persistent pool instance will be created.
+
+ use_proxy (bool): Whether proxy settings should be discovered from
+ conventional environment variables and used.
"""
def __init__(
@@ -68,8 +92,7 @@ class ProxyAgent(_AgentBase):
connectTimeout=None,
bindAddress=None,
pool=None,
- http_proxy=None,
- https_proxy=None,
+ use_proxy=False,
):
_AgentBase.__init__(self, reactor, pool)
@@ -84,6 +107,18 @@ class ProxyAgent(_AgentBase):
if bindAddress is not None:
self._endpoint_kwargs["bindAddress"] = bindAddress
+ http_proxy = None
+ https_proxy = None
+ no_proxy = None
+ if use_proxy:
+ proxies = getproxies_environment()
+ http_proxy = proxies["http"].encode() if "http" in proxies else None
+ https_proxy = proxies["https"].encode() if "https" in proxies else None
+ no_proxy = proxies["no"] if "no" in proxies else None
+
+ # Parse credentials from https proxy connection string if present
+ self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
+
self.http_proxy_endpoint = _http_proxy_endpoint(
http_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
@@ -92,6 +127,8 @@ class ProxyAgent(_AgentBase):
https_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
+ self.no_proxy = no_proxy
+
self._policy_for_https = contextFactory
self._reactor = reactor
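Roughly how the discovery above behaves, assuming conventional environment variables (the proxy string is hypothetical):

    import os
    from urllib.request import getproxies_environment

    os.environ["https_proxy"] = "user:pass@proxy.example.com:8888"
    proxies = getproxies_environment()
    https_proxy = proxies["https"].encode() if "https" in proxies else None
    assert https_proxy == b"user:pass@proxy.example.com:8888"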
@@ -139,18 +176,43 @@ class ProxyAgent(_AgentBase):
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
request_path = parsed_uri.originForm
- if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ should_skip_proxy = False
+ if self.no_proxy is not None:
+ should_skip_proxy = proxy_bypass_environment(
+ parsed_uri.host.decode(), proxies={"no": self.no_proxy},
+ )
+
+ if (
+ parsed_uri.scheme == b"http"
+ and self.http_proxy_endpoint
+ and not should_skip_proxy
+ ):
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint)
endpoint = self.http_proxy_endpoint
request_path = uri
- elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ elif (
+ parsed_uri.scheme == b"https"
+ and self.https_proxy_endpoint
+ and not should_skip_proxy
+ ):
+ connect_headers = Headers()
+
+ # Determine whether we need to set Proxy-Authorization headers
+ if self.https_proxy_creds:
+ # Set a Proxy-Authorization header
+ connect_headers.addRawHeader(
+ b"Proxy-Authorization",
+ self.https_proxy_creds.as_proxy_authorization_value(),
+ )
+
endpoint = HTTPConnectProxyEndpoint(
self.proxy_reactor,
self.https_proxy_endpoint,
parsed_uri.host,
parsed_uri.port,
+ headers=connect_headers,
)
else:
# not using a proxy
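The bypass check delegates to the stdlib's no-proxy matching, which handles exact hosts and dot-prefixed domain suffixes (hosts hypothetical):

    from urllib.request import proxy_bypass_environment

    no_proxy = "localhost,.internal.example.com"
    assert proxy_bypass_environment("localhost", proxies={"no": no_proxy})
    assert proxy_bypass_environment("host.internal.example.com", proxies={"no": no_proxy})
    assert not proxy_bypass_environment("matrix.org", proxies={"no": no_proxy})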
@@ -179,12 +241,16 @@ class ProxyAgent(_AgentBase):
)
-def _http_proxy_endpoint(proxy, reactor, **kwargs):
+def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs):
"""Parses an http proxy setting and returns an endpoint for the proxy
Args:
- proxy (bytes|None): the proxy setting
+ proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>]
+ Note that compared to other apps, this function currently lacks support
+ for specifying a protocol scheme (i.e. protocol://...).
+
reactor: reactor to be used to connect to the proxy
+
kwargs: other args to be passed to HostnameEndpoint
Returns:
@@ -194,16 +260,43 @@ def _http_proxy_endpoint(proxy, reactor, **kwargs):
if proxy is None:
return None
- # currently we only support hostname:port. Some apps also support
- # protocol://<host>[:port], which allows a way of requiring a TLS connection to the
- # proxy.
-
+ # Parse the connection string
host, port = parse_host_port(proxy, default_port=1080)
return HostnameEndpoint(reactor, host, port, **kwargs)
-def parse_host_port(hostport, default_port=None):
- # could have sworn we had one of these somewhere else...
+def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
+ """
+ Parses the username and password from a proxy declaration, e.g.
+ username:password@hostname:port.
+
+ Args:
+ proxy: The proxy connection string.
+
+ Returns:
+ An instance of ProxyCredentials and the proxy connection string with any credentials
+ stripped, i.e. u:p@host:port -> host:port. If no credentials were found, None
+ is returned in place of the ProxyCredentials instance.
+ """
+ if proxy and b"@" in proxy:
+ # We use rsplit here as the password could contain an @ character
+ credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
+ return ProxyCredentials(credentials), proxy_without_credentials
+
+ return None, proxy
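A hypothetical call showing both outcomes of the credential split:

    assert parse_username_password(b"user:p@ss@proxy.example.com:8888") == (
        ProxyCredentials(b"user:p@ss"),  # rsplit keeps '@' inside the password
        b"proxy.example.com:8888",
    )
    assert parse_username_password(b"proxy.example.com:8888") == (
        None,
        b"proxy.example.com:8888",
    )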
+
+
+def parse_host_port(hostport: bytes, default_port: Optional[int] = None) -> Tuple[bytes, Optional[int]]:
+ """
+ Parse the hostname and port from a proxy connection byte string.
+
+ Args:
+ hostport: The proxy connection string. Must be in the form 'host[:port]'.
+ default_port: The default port to return if one is not found in `hostport`.
+
+ Returns:
+ A tuple containing the hostname and port. Uses `default_port` if one was not found.
+ """
if b":" in hostport:
host, port = hostport.rsplit(b":", 1)
try:
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index b361b7cbaf..9bfe151b5f 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,8 +14,8 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
-
import logging
+from typing import Dict, List, Optional, Union
from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder
@@ -147,16 +147,67 @@ def parse_string(
)
+def parse_list_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: Union[bytes, str],
+ encoding: Optional[str] = "ascii",
+):
+ """Parse and optionally decode a list of values from request query parameters.
+
+ Args:
+ args: A dictionary of query parameters from a request.
+ name: The name of the query parameter to extract values from. If given as bytes,
+ will be decoded as "ascii".
+ encoding: An optional encoding used to decode each parameter value.
+
+ Raises:
+ KeyError: If the given `name` does not exist in `args`.
+ SynapseError: If an argument was not encoded with the specified `encoding`.
+ """
+ if not isinstance(name, bytes):
+ name = name.encode("ascii")
+ args_list = args[name]
+
+ if encoding:
+ # Decode each argument value
+ try:
+ args_list = [value.decode(encoding) for value in args_list]
+ except ValueError:
+ raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+
+ return args_list
+
+
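A hypothetical call to the new helper, with `args` shaped like twisted's Request.args:

    args = {b"user_id": [b"@alice:example.org", b"@bob:example.org"]}
    user_ids = parse_list_from_args(args, "user_id")
    # -> ["@alice:example.org", "@bob:example.org"]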
def parse_string_from_args(
- args,
- name,
- default=None,
- required=False,
- allowed_values=None,
- param_type="string",
- encoding="ascii",
+ args: Dict[bytes, List[bytes]],
+ name: Union[bytes, str],
+ default: Optional[str] = None,
+ required: Optional[bool] = False,
+ allowed_values: Optional[List[bytes]] = None,
+ param_type: Optional[str] = "string",
+ encoding: Optional[str] = "ascii",
):
+ """Parse and optionally decode a single value from request query parameters.
+ Args:
+ args: A dictionary of query parameters from a request.
+ name: The name of the query parameter to extract values from. If given as bytes,
+ will be decoded as "ascii".
+ default: A default value to return if the given argument `name` was not found.
+ required: If True and the given argument `name` was not found, a SynapseError
+ is raised; any `default` is ignored in that case.
+ allowed_values: A list of allowed values. If specified and the found value is
+ not in this list, a SynapseError is raised.
+ param_type: The expected type of the query parameter's value.
+ encoding: An optional encoding used to decode each parameter value.
+
+ Returns:
+ The found argument value.
+
+ Raises:
+ SynapseError: If the given name was not found in the request arguments,
+ the argument's values were encoded incorrectly or a required value was missing.
+ """
if not isinstance(name, bytes):
name = name.encode("ascii")
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index ab586c318c..0538350f38 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -791,7 +791,7 @@ def tag_args(func):
@wraps(func)
def _tag_args_inner(*args, **kwargs):
- argspec = inspect.getargspec(func)
+ argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i])
set_tag("args", args[len(argspec.args) :])
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 6211506990..4d284de133 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -495,7 +495,11 @@ BASE_APPEND_UNDERRIDE_RULES = [
"_id": "_message",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
@@ -509,7 +513,11 @@ BASE_APPEND_UNDERRIDE_RULES = [
"_id": "_encrypted",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
{
"rule_id": "global/underride/.im.vector.jitsi",
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 4d875dcb91..745b1dde94 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -668,6 +668,15 @@ class Mailer:
def safe_markup(raw_html: str) -> jinja2.Markup:
+ """
+ Sanitise a raw HTML string, stripping all but an allowed set of tags and
+ attributes, and linkify any bare URLs.
+
+ Args:
+ raw_html: Unsafe HTML.
+
+ Returns:
+ A Markup object ready to safely use in a Jinja template.
+ """
return jinja2.Markup(
bleach.linkify(
bleach.clean(
@@ -684,8 +693,13 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
def safe_text(raw_text: str) -> jinja2.Markup:
"""
- Process text: treat it as HTML but escape any tags (ie. just escape the
- HTML) then linkify it.
+ Sanitise text (escape any HTML tags), and then linkify any bare URLs.
+
+ Args:
+ raw_text: Unsafe text which might include HTML markup.
+
+ Returns:
+ A Markup object ready to safely use in a Jinja template.
"""
return jinja2.Markup(
bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 7e50341d74..04c2c1482c 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -17,7 +17,7 @@ import logging
import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.types import StateMap
@@ -63,7 +63,7 @@ async def calculate_room_name(
m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True
)
- if m_room_name and m_room_name.content and m_room_name.content["name"]:
+ if m_room_name and m_room_name.content and m_room_name.content.get("name"):
return m_room_name.content["name"]
# does it have a canonical alias?
@@ -74,15 +74,11 @@ async def calculate_room_name(
if (
canon_alias
and canon_alias.content
- and canon_alias.content["alias"]
+ and canon_alias.content.get("alias")
and _looks_like_an_alias(canon_alias.content["alias"])
):
return canon_alias.content["alias"]
- # at this point we're going to need to search the state by all state keys
- # for an event type, so rearrange the data structure
- room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
-
if not fallback_to_members:
return None
@@ -94,7 +90,7 @@ async def calculate_room_name(
if (
my_member_event is not None
- and my_member_event.content["membership"] == "invite"
+ and my_member_event.content.get("membership") == Membership.INVITE
):
if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = await store.get_event(
@@ -111,6 +107,10 @@ async def calculate_room_name(
else:
return "Room Invite"
+ # at this point we're going to need to search the state by all state keys
+ # for an event type, so rearrange the data structure
+ room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
+
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids:
@@ -120,8 +120,8 @@ async def calculate_room_name(
all_members = [
ev
for ev in member_events.values()
- if ev.content["membership"] == "join"
- or ev.content["membership"] == "invite"
+ if ev.content.get("membership") == Membership.JOIN
+ or ev.content.get("membership") == Membership.INVITE
]
# Sort the member events oldest-first so that we name people in the
# order they joined (it should at least be deterministic rather than
@@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
def name_from_member_event(member_event: EventBase) -> str:
- if (
- member_event.content
- and "displayname" in member_event.content
- and member_event.content["displayname"]
- ):
+ if member_event.content and member_event.content.get("displayname"):
return member_event.content["displayname"]
return member_event.state_key
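The .get rewrite is behaviour-preserving (the old code already guarded with an `in` check); a standalone check with a minimal stand-in for EventBase:

    class _FakeEvent:
        def __init__(self, content, state_key):
            self.content = content
            self.state_key = state_key

    def _name(ev):  # mirrors name_from_member_event above
        if ev.content and ev.content.get("displayname"):
            return ev.content["displayname"]
        return ev.state_key

    assert _name(_FakeEvent({"displayname": "Alice"}, "@a:hs")) == "Alice"
    assert _name(_FakeEvent({}, "@a:hs")) == "@a:hs"                   # missing key
    assert _name(_FakeEvent({"displayname": ""}, "@a:hs")) == "@a:hs"  # empty value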
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index eed16dbfb5..3e843c97fe 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,7 +62,7 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
- self._account_validity = hs.config.account_validity
+ self._account_validity_enabled = hs.config.account_validity_enabled
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
@@ -225,7 +225,7 @@ class PusherPool:
for u in users_affected:
# Don't push if the user account has expired
- if self._account_validity.enabled:
+ if self._account_validity_enabled:
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
@@ -255,7 +255,7 @@ class PusherPool:
for u in users_affected:
# Don't push if the user account has expired
- if self._account_validity.enabled:
+ if self._account_validity_enabled:
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index bfd46a3730..60e6793c8d 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -112,6 +112,8 @@ CONDITIONAL_REQUIREMENTS = {
"redis": ["txredisapi>=1.4.7", "hiredis"],
}
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 84e002f934..84afc4c1e4 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -98,6 +98,73 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id}
+class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
+ """Perform a remote knock for the given user on the given room
+
+ Request format:
+
+ POST /_synapse/replication/remote_knock/:room_id/:user_id
+
+ {
+ "requester": ...,
+ "remote_room_hosts": [...],
+ "content": { ... }
+ }
+ """
+
+ NAME = "remote_knock"
+ PATH_ARGS = ("room_id", "user_id")
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.federation_handler = hs.get_federation_handler()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore
+ requester: Requester,
+ room_id: str,
+ user_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ ):
+ """
+ Args:
+ requester: The user making the request, according to the access token.
+ room_id: The ID of the room to knock on.
+ user_id: The ID of the knocking user.
+ remote_room_hosts: Servers to try and send the knock via.
+ content: The event content to use for the knock event.
+ """
+ return {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "content": content,
+ }
+
+ async def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str,
+ ):
+ content = parse_json_object_from_request(request)
+
+ remote_room_hosts = content["remote_room_hosts"]
+ event_content = content["content"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ request.requester = requester
+
+ logger.debug("remote_knock: %s on room: %s", user_id, room_id)
+
+ event_id, stream_id = await self.federation_handler.do_knock(
+ remote_room_hosts, room_id, user_id, event_content
+ )
+
+ return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
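For orientation, a sketch of how a worker would typically invoke this endpoint; the surrounding handler wiring is assumed, not shown in this diff:

    # in a worker-side handler's __init__:
    self._remote_knock_client = ReplicationRemoteKnockRestServlet.make_client(hs)

    # later, to proxy the knock to the main process:
    ret = await self._remote_knock_client(
        requester=requester,
        room_id=room_id,
        user_id=user_id,
        remote_room_hosts=remote_room_hosts,
        content=content,
    )
    event_id, stream_id = ret["event_id"], ret["stream_id"]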
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects an out-of-band invite we have received from a remote server
@@ -166,6 +233,70 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id}
+class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
+ """Rescinds a local knock made on a remote room
+
+ Request format:
+
+ POST /_synapse/replication/remote_rescind_knock/:event_id
+
+ {
+ "txn_id": ...,
+ "requester": ...,
+ "content": { ... }
+ }
+ """
+
+ NAME = "remote_rescind_knock"
+ PATH_ARGS = ("knock_event_id",)
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.member_handler = hs.get_room_member_handler()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ):
+ """
+ Args:
+ knock_event_id: The ID of the knock to be rescinded.
+ txn_id: An optional transaction ID supplied by the client.
+ requester: The user making the rescind request, according to the access token.
+ content: The content to include in the rescind event.
+ """
+ return {
+ "txn_id": txn_id,
+ "requester": requester.serialize(),
+ "content": content,
+ }
+
+ async def _handle_request( # type: ignore
+ self, request: Request, knock_event_id: str,
+ ):
+ content = parse_json_object_from_request(request)
+
+ txn_id = content["txn_id"]
+ event_content = content["content"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ request.requester = requester
+
+ # hopefully we're now on the master, so this won't recurse!
+ event_id, stream_id = await self.member_handler.remote_rescind_knock(
+ knock_event_id, txn_id, requester, event_content,
+ )
+
+ return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+set_counter = Counter(
+ "synapse_external_cache_set",
+ "Number of times we set a cache",
+ labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+ "synapse_external_cache_get",
+ "Number of times we get a cache",
+ labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+ """A cache backed by an external Redis. Does nothing if no Redis is
+ configured.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._redis_connection = hs.get_outbound_redis_connection()
+
+ def _get_redis_key(self, cache_name: str, key: str) -> str:
+ return "cache_v1:%s:%s" % (cache_name, key)
+
+ def is_enabled(self) -> bool:
+ """Whether the external cache is used or not.
+
+ It's safe to use the cache when this returns false; the methods will
+ just no-op. But the check is useful to avoid doing unnecessary work.
+ """
+ return self._redis_connection is not None
+
+ async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+ """Add the key/value to the named cache, with the expiry time given.
+ """
+
+ if self._redis_connection is None:
+ return
+
+ set_counter.labels(cache_name).inc()
+
+ # txredisapi requires the value to be a string, bytes or a number, so we
+ # encode stuff in JSON.
+ encoded_value = json_encoder.encode(value)
+
+ logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+ return await make_deferred_yieldable(
+ self._redis_connection.set(
+ self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+ )
+ )
+
+ async def get(self, cache_name: str, key: str) -> Optional[Any]:
+ """Look up a key/value in the named cache.
+ """
+
+ if self._redis_connection is None:
+ return None
+
+ result = await make_deferred_yieldable(
+ self._redis_connection.get(self._get_redis_key(cache_name, key))
+ )
+
+ logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+ get_counter.labels(cache_name, result is not None).inc()
+
+ if not result:
+ return None
+
+ # txredisapi (with convertNumbers) magically converts integer replies
+ # back to Python ints, so there is nothing left to JSON-decode.
+ if isinstance(result, int):
+ return result
+
+ return json_decoder.decode(result)
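
ExternalCache above is deliberately tolerant of Redis being absent. A minimal usage sketch (not part of this patch; the hs.get_external_cache() accessor and the "example_cache" name are assumptions for illustration):

    async def get_with_external_cache(hs, key: str):
        cache = hs.get_external_cache()  # assumed HomeServer accessor

        if cache.is_enabled():
            cached = await cache.get("example_cache", key)
            if cached is not None:
                return cached

        value = {"key": key}  # stand-in for an expensive computation

        if cache.is_enabled():
            # Values are JSON-encoded internally, so anything
            # JSON-serialisable works; expire after one minute.
            await cache.set("example_cache", key, value, expiry_ms=60 * 1000)

        return value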
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
TypingStream,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
back out to connections.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
- lazyConnection,
- )
-
- logger.info(
- "Connecting to redis (host=%r port=%r)",
- hs.config.redis_host,
- hs.config.redis_port,
)
# First let's ensure that we have a ReplicationStreamer started.
@@ -299,13 +296,7 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = lazyConnection(
- reactor=hs.get_reactor(),
- host=hs.config.redis_host,
- port=hs.config.redis_port,
- password=hs.config.redis.redis_password,
- reconnect=True,
- )
+ outbound_redis_connection = hs.get_outbound_redis_connection()
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
import txredisapi
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
+ wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation.
Attributes:
- handler: The command handler to handle incoming commands.
- stream_name: The *redis* stream name to subscribe to and publish from
- (not anything to do with Synapse replication streams).
- outbound_redis_connection: The connection to redis to use to send
+ synapse_handler: The command handler to handle incoming commands.
+ synapse_stream_name: The *redis* stream name to subscribe to and publish
+ from (not anything to do with Synapse replication streams).
+ synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
- handler = None # type: ReplicationCommandHandler
- stream_name = None # type: str
- outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ synapse_handler = None # type: ReplicationCommandHandler
+ synapse_stream_name = None # type: str
+ synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
- logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
- await make_deferred_yieldable(self.subscribe(self.stream_name))
+ logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+ await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
- self.handler.new_connection(self)
+ self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# other side won't know we've connected and so won't issue a REPLICATE.
- self.handler.send_positions_to_connection(self)
+ self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd: received command
"""
- cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+ cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
- self.handler.lost_connection(self)
+ self.synapse_handler.lost_connection(self)
# mark the logging context as finished
self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
await make_deferred_yieldable(
- self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+ self.synapse_outbound_redis_connection.publish(
+ self.synapse_stream_name, encoded_string
+ )
+ )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+ """A subclass of RedisFactory that periodically sends pings to ensure that
+ we detect dead connections.
+ """
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ uuid: str,
+ dbid: Optional[int],
+ poolsize: int,
+ isLazy: bool = False,
+ handler: Type = txredisapi.ConnectionHandler,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ replyTimeout: int = 30,
+ convertNumbers: Optional[bool] = True,
+ ):
+ super().__init__(
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=poolsize,
+ isLazy=isLazy,
+ handler=handler,
+ charset=charset,
+ password=password,
+ replyTimeout=replyTimeout,
+ convertNumbers=convertNumbers,
)
+ hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+ @wrap_as_background_process("redis_ping")
+ async def _send_ping(self):
+ for connection in self.pool:
+ try:
+ await make_deferred_yieldable(connection.ping())
+ except Exception:
+ logger.warning("Failed to send ping to a redis connection")
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):
- super().__init__()
-
- # This sets the password on the RedisFactory base class (as
- # SubscriberFactory constructor doesn't pass it through).
- self.password = hs.config.redis.redis_password
+ super().__init__(
+ hs,
+ uuid="subscriber",
+ dbid=None,
+ poolsize=1,
+ replyTimeout=30,
+ password=hs.config.redis.redis_password,
+ )
- self.handler = hs.get_tcp_replication()
- self.stream_name = hs.hostname
+ self.synapse_handler = hs.get_tcp_replication()
+ self.synapse_stream_name = hs.hostname
- self.outbound_redis_connection = outbound_redis_connection
+ self.synapse_outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
- p = super().buildProtocol(addr) # type: RedisSubscriber
+ p = super().buildProtocol(addr)
+ p = cast(RedisSubscriber, p)
# We do this here rather than add to the constructor of `RedisSubscriber`
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
- p.handler = self.handler
- p.outbound_redis_connection = self.outbound_redis_connection
- p.stream_name = self.stream_name
- p.password = self.password
+ p.synapse_handler = self.synapse_handler
+ p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+ p.synapse_stream_name = self.synapse_stream_name
return p
def lazyConnection(
- reactor,
+ hs: "HomeServer",
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
reconnect: bool = True,
- charset: str = "utf-8",
password: Optional[str] = None,
- connectTimeout: Optional[int] = None,
- replyTimeout: Optional[int] = None,
- convertNumbers: bool = True,
+ replyTimeout: int = 30,
) -> txredisapi.RedisProtocol:
- """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
- reactor.
+ """Creates a connection to Redis that is lazily set up and reconnects if the
+ connections is lost.
"""
- isLazy = True
- poolsize = 1
-
uuid = "%s:%d" % (host, port)
- factory = txredisapi.RedisFactory(
- uuid,
- dbid,
- poolsize,
- isLazy,
- txredisapi.ConnectionHandler,
- charset,
- password,
- replyTimeout,
- convertNumbers,
+ factory = SynapseRedisFactory(
+ hs,
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=1,
+ isLazy=True,
+ handler=txredisapi.ConnectionHandler,
+ password=password,
+ replyTimeout=replyTimeout,
)
factory.continueTrying = reconnect
- for x in range(poolsize):
- reactor.connectTCP(host, port, factory, connectTimeout)
+
+ reactor = hs.get_reactor()
+ reactor.connectTCP(host, port, factory, 30)
return factory.handler
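
The ping loop is the substantive difference between SynapseRedisFactory and the stock txredisapi factory: without it, a silently dropped TCP connection is only noticed when the next real command times out. A standalone sketch of the same pattern (illustrative asyncio rather than Twisted; only the 30-second cadence comes from the patch):

    import asyncio
    import logging

    async def keepalive(connections, interval: float = 30.0) -> None:
        # Ping every pooled connection periodically so dead TCP links are
        # noticed promptly rather than on the next real command.
        while True:
            await asyncio.sleep(interval)
            for conn in connections:
                try:
                    await conn.ping()
                except Exception:
                    logging.warning("Failed to send ping to a redis connection")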
diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html
new file mode 100644
index 0000000000..b751359bdf
--- /dev/null
+++ b/synapse/res/templates/account_previously_renewed.html
@@ -0,0 +1 @@
+<html><body>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body></html>
diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
index 894da030af..e8c0f52f05 100644
--- a/synapse/res/templates/account_renewed.html
+++ b/synapse/res/templates/account_renewed.html
@@ -1 +1 @@
-<html><body>Your account has been successfully renewed.</body><html>
+<html><body>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body></html>
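
Both templates assume a format_ts Jinja filter and a millisecond expiration_ts, matching timestamps elsewhere in Synapse. A hedged sketch of such a filter and the rendering (the filter body is an assumption, not this patch's wiring):

    import time

    import jinja2

    def format_ts(value: int, format: str) -> str:
        # expiration_ts is a Unix timestamp in milliseconds
        return time.strftime(format, time.localtime(value / 1000))

    env = jinja2.Environment(
        loader=jinja2.DictLoader(
            {"renewed": 'valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}'}
        )
    )
    env.filters["format_ts"] = format_ts
    # e.g. "valid until 01-01-2030" (local time)
    print(env.get_template("renewed").render(expiration_ts=1893456000000))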
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index a75c73a142..0f704b0818 100644
--- a/synapse/res/templates/sso_auth_bad_user.html
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -12,7 +12,7 @@
<header>
<h1>That doesn't look right</h1>
<p>
- <strong>We were unable to validate your {{ server_name | e }} account</strong>
+ <strong>We were unable to validate your {{ server_name }} account</strong>
via single sign‑on (SSO), because the SSO Identity
Provider returned different details than when you logged in.
</p>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 69b93d65c1..f127804823 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -22,7 +22,7 @@
<header>
<h1>There was an error</h1>
<p>
- <strong id="errormsg">{{ error_description | e }}</strong>
+ <strong id="errormsg">{{ error_description }}</strong>
</p>
<p>
If you are seeing this page after clicking a link sent to you via email,
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 5b38481012..62a640dad2 100644
--- a/synapse/res/templates/sso_login_idp_picker.html
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -3,22 +3,22 @@
<head>
<meta charset="UTF-8">
<link rel="stylesheet" href="/_matrix/static/client/login/style.css">
- <title>{{server_name | e}} Login</title>
+ <title>{{ server_name }} Login</title>
</head>
<body>
<div id="container">
- <h1 id="title">{{server_name | e}} Login</h1>
+ <h1 id="title">{{ server_name }} Login</h1>
<div class="login_flow">
<p>Choose one of the following identity providers:</p>
<form>
- <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
+ <input type="hidden" name="redirectUrl" value="{{ redirect_url }}">
<ul class="radiobuttons">
{% for p in providers %}
<li>
- <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
- <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
+ <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}">
+ <label for="prov{{ loop.index }}">{{ p.idp_name }}</label>
{% if p.idp_icon %}
- <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
+ <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/>
{% endif %}
</li>
{% endfor %}
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index ce4f573848..d1328a6969 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -12,11 +12,11 @@
<header>
{% if new_user %}
<h1>Your account is now ready</h1>
- <p>You've made your account on {{ server_name | e }}.</p>
+ <p>You've made your account on {{ server_name }}.</p>
{% else %}
<h1>Log in</h1>
{% endif %}
- <p>Continue to confirm you trust <strong>{{ display_url | e }}</strong>.</p>
+ <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p>
</header>
<main>
{% if user_profile.avatar_url %}
@@ -24,13 +24,13 @@
<img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
<div class="profile-details">
{% if user_profile.display_name %}
- <div class="display-name">{{ user_profile.display_name | e }}</div>
+ <div class="display-name">{{ user_profile.display_name }}</div>
{% endif %}
- <div class="user-id">{{ user_id | e }}</div>
+ <div class="user-id">{{ user_id }}</div>
</div>
</div>
{% endif %}
- <a href="{{ redirect_url | e }}" class="primary-button">Continue</a>
+ <a href="{{ redirect_url }}" class="primary-button">Continue</a>
</main>
</body>
</html>
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 40f5c32db2..ee3a9af569 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -39,6 +39,7 @@ from synapse.rest.client.v2_alpha import (
filter,
groups,
keys,
+ knock,
notifications,
openid,
password_policy,
@@ -119,8 +120,10 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
+ knock.register_servlets(hs, client_resource)
# moving to /_synapse/admin
admin.register_servlets_for_client_rest_resource(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6f7dc06503..57e0a10837 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,6 +38,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
+ ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
MakeRoomAdminRestServlet,
@@ -51,6 +54,7 @@ from synapse.rest.admin.users import (
PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
+ ShadowBanRestServlet,
UserAdminServlet,
UserMediaRestServlet,
UserMembershipRestServlet,
@@ -230,6 +234,8 @@ def register_servlets(hs, http_server):
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server)
+ ShadowBanRestServlet(hs).register(http_server)
+ ForwardExtremitiesRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ab7cc9102a..fcbb91f736 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
@@ -23,6 +23,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
+ parse_list_from_args,
parse_string,
)
from synapse.http.site import SynapseRequest
@@ -323,10 +324,8 @@ class JoinRoomAliasServlet(RestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
- remote_room_hosts = [
- x.decode("ascii") for x in request.args[b"server_name"]
- ] # type: Optional[List[str]]
- except Exception:
+ remote_room_hosts = parse_list_from_args(request.args, "server_name")
+ except KeyError:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
@@ -431,7 +430,17 @@ class MakeRoomAdminRestServlet(RestServlet):
if not admin_users:
raise SynapseError(400, "No local admin user in room")
- admin_user_id = admin_users[-1]
+ admin_user_id = None
+
+ for admin_user in reversed(admin_users):
+ if room_state.get((EventTypes.Member, admin_user)):
+ admin_user_id = admin_user
+ break
+
+ if not admin_user_id:
+ raise SynapseError(
+ 400, "No local admin user in room",
+ )
pl_content = power_levels.content
else:
@@ -499,3 +508,60 @@ class MakeRoomAdminRestServlet(RestServlet):
)
return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+ """Allows a server admin to get or clear forward extremities.
+
+ Clearing does not require restarting the server.
+
+ Clear forward extremities:
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+ Get forward_extremities:
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.store = hs.get_datastore()
+
+ async def resolve_room_id(self, room_identifier: str) -> str:
+ """Resolve to a room ID, if necessary."""
+ if RoomID.is_valid(room_identifier):
+ resolved_room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+ resolved_room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+ if not resolved_room_id:
+ raise SynapseError(
+ 400, "Unknown room ID or room alias %s" % room_identifier
+ )
+ return resolved_room_id
+
+ async def on_DELETE(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+ return 200, {"deleted": deleted_count}
+
+ async def on_GET(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ extremities = await self.store.get_forward_extremities_for_room(room_id)
+ return 200, {"count": len(extremities), "results": extremities}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f39e3d6d5c..68c3c64a0d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
@@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
- if len(users) >= limit:
+ if (start + limit) < total:
ret["next_token"] = str(start + len(users))
return 200, ret
@@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
)
return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+ """An admin API for shadow-banning a user.
+
+ A shadow-banned user receives successful responses to their client-server
+ API requests, but the events are not propagated into rooms.
+
+ Shadow-banning a user should be used as a tool of last resort and may lead
+ to confusing or broken behaviour for the client.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+ {}
+
+ 200 OK
+ {}
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be shadow-banned")
+
+ await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+ return 200, {}
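
The next_token change above means the token is only present while start + limit < total, so clients can simply loop until it disappears. A sketch with placeholder URL and token:

    import requests

    BASE = "https://homeserver.example.com"  # placeholder
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder

    users = []
    start = 0
    while True:
        r = requests.get(
            f"{BASE}/_synapse/admin/v2/users",
            params={"from": start, "limit": 100},
            headers=HEADERS,
        )
        body = r.json()
        users.extend(body["users"])
        if "next_token" not in body:
            break  # start + limit reached the total: last page
        start = int(body["next_token"])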
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 23a529f8e3..94bfe2d1b0 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -49,9 +49,7 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(
- state, self.clock.time_msec(), include_user_id=False
- )
+ state = format_user_presence_state(state, self.clock.time_msec())
return 200, state
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 85a66458c5..b5fa1cc464 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -65,8 +67,27 @@ class ProfileDisplaynameRestServlet(RestServlet):
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
+ return 200, {}
+
+ def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -75,6 +96,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -113,8 +135,27 @@ class ProfileAvatarURLRestServlet(RestServlet):
user, requester, new_avatar_url, is_admin
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
+
+ return 200, {}
+
+ def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f95627ee61..c8b128583f 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -15,10 +15,9 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
-
import logging
import re
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@@ -37,6 +36,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
+ parse_list_from_args,
parse_string,
)
from synapse.logging.opentracing import set_tag
@@ -283,10 +283,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
- remote_room_hosts = [
- x.decode("ascii") for x in request.args[b"server_name"]
- ] # type: Optional[List[str]]
- except Exception:
+ remote_room_hosts = parse_list_from_args(request.args, "server_name")
+ except KeyError:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
@@ -739,7 +737,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
)
except ShadowBanError:
# Pretend the request succeeded.
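
parse_list_from_args itself is added elsewhere in this patch (in synapse/http/servlet.py, not shown in these hunks). From the call sites above (KeyError when the parameter is absent, decoded strings otherwise) it presumably looks roughly like this sketch:

    from typing import Dict, List

    def parse_list_from_args(args: Dict[bytes, List[bytes]], name: str) -> List[str]:
        """Return every value supplied for a query parameter, decoded as ASCII.

        Raises KeyError if the parameter was not supplied at all; the call
        sites above catch that to fall back to a default.
        """
        return [x.decode("ascii") for x in args[name.encode("ascii")]]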
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 65e68d641b..aa170c215f 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import random
+import re
from http import HTTPStatus
from typing import TYPE_CHECKING
from urllib.parse import urlparse
@@ -38,6 +39,7 @@ from synapse.http.servlet import (
)
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
@@ -54,7 +56,7 @@ logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
@@ -103,6 +105,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -164,6 +168,7 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore()
self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -189,24 +194,29 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- try:
- params, session_id = await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "modify your account password",
- )
- except InteractiveAuthIncompleteError as e:
- # The user needs to provide more steps to complete auth, but
- # they're not required to provide the password again.
- #
- # If a password is available now, hash the provided password and
- # store it for later.
- if new_password:
- password_hash = await self.auth_handler.hash(new_password)
- await self.auth_handler.set_session_data(
- e.session_id,
- UIAuthSessionDataConstants.PASSWORD_HASH,
- password_hash,
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ try:
+ (
+ params,
+ session_id,
+ ) = await self.auth_handler.validate_user_via_ui_auth(
+ requester, request, body, "modify your account password",
)
- raise
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
user_id = requester.user.to_string()
else:
requester = None
@@ -276,8 +286,28 @@ class PasswordRestServlet(RestServlet):
user_id, password_hash, logout_devices, requester
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ await self.shadow_password(params, shadow_user.to_string())
+
+ return 200, {}
+
+ def on_OPTIONS(self, _):
return 200, {}
+ async def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -372,13 +402,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -430,7 +462,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
@@ -451,13 +483,17 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -621,7 +657,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -640,6 +677,29 @@ class ThreepidRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ # skip validation if this is a shadow 3PID from an AS
+ if requester.app_service:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
+
+ await self.auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return 200, {}
+
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
if threepid_creds is None:
raise SynapseError(
@@ -661,12 +721,35 @@ class ThreepidRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ async def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
@@ -677,6 +760,7 @@ class ThreepidAddRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -708,12 +792,33 @@ class ThreepidAddRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ async def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
@@ -783,6 +888,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
@@ -807,6 +913,12 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ await self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
@@ -814,6 +926,74 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+ async def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_identity_handler()
+
+ async def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ await self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+
+ return 200, ret
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_identity_handler()
+
+ async def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ await self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = await self.identity_handler.proxy_bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ return 200, ret
+
def assert_valid_next_link(hs: "HomeServer", next_link: str):
"""
@@ -880,4 +1060,6 @@ def register_servlets(hs, http_server):
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
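
A hedged usage sketch for the two proxy endpoints added above (placeholders throughout):

    import requests

    BASE = "https://homeserver.example.com"  # placeholder
    HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder

    # Single lookup, proxied through the homeserver to the identity server.
    r = requests.get(
        f"{BASE}/_matrix/client/unstable/account/3pid/lookup",
        params={
            "medium": "email",
            "address": "alice@example.com",
            "id_server": "id.example.com",
        },
        headers=HEADERS,
    )
    print(r.json())

    # Bulk lookup takes a list of [medium, address] pairs.
    r = requests.post(
        f"{BASE}/_matrix/client/unstable/account/3pid/bulk_lookup",
        json={
            "id_server": "id.example.com",
            "threepids": [
                ["email", "alice@example.com"],
                ["msisdn", "447700900000"],
            ],
        },
        headers=HEADERS,
    )
    print(r.json())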
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 3f28c0bc3e..c9f13e4ac5 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -38,6 +39,9 @@ class AccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler()
+ self.notifier = hs.get_notifier()
+ self._is_worker = hs.config.worker_app is not None
+ self._profile_handler = hs.get_profile_handler()
async def on_PUT(self, request, user_id, account_data_type):
requester = await self.auth.get_user_by_req(request)
@@ -46,7 +50,15 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
- await self.handler.add_account_data_for_user(user_id, account_data_type, body)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ await self._profile_handler.set_active([user], not hide_profile, True)
+
+ max_id = await self.handler.add_account_data_for_user(
+ user_id, account_data_type, body
+ )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {}
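
The special-cased account data type is set like any other account data; a sketch with placeholders:

    import requests

    BASE = "https://homeserver.example.com"  # placeholder
    HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder
    user = "@alice:example.com"  # placeholder

    # Setting im.vector.hide_profile also flips the user's "active" flag
    # via the profile handler, in addition to storing the account data.
    r = requests.put(
        f"{BASE}/_matrix/client/r0/user/{user}/account_data/im.vector.hide_profile",
        json={"hide_profile": True},
        headers=HEADERS,
    )
    print(r.status_code)  # 200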
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index bd7f9ae203..40c5bd4d8c 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -37,24 +37,38 @@ class AccountValidityRenewServlet(RestServlet):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
- self.success_html = hs.config.account_validity.account_renewed_html_content
- self.failure_html = hs.config.account_validity.invalid_token_html_content
+ self.account_renewed_template = (
+ hs.config.account_validity_account_renewed_template
+ )
+ self.account_previously_renewed_template = (
+ hs.config.account_validity_account_previously_renewed_template
+ )
+ self.invalid_token_template = hs.config.account_validity_invalid_token_template
async def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
- token_valid = await self.account_activity_handler.renew_account(
+ (
+ token_valid,
+ token_stale,
+ expiration_ts,
+ ) = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)
if token_valid:
status_code = 200
- response = self.success_html
+ response = self.account_renewed_template.render(expiration_ts=expiration_ts)
+ elif token_stale:
+ status_code = 200
+ response = self.account_previously_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
else:
status_code = 404
- response = self.failure_html
+ response = self.invalid_token_template.render(expiration_ts=expiration_ts)
respond_with_html(request, status_code, response)
@@ -72,10 +86,12 @@ class AccountValiditySendMailServlet(RestServlet):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
- self.account_validity = self.hs.config.account_validity
+ self.account_validity_renew_by_email_enabled = (
+ self.hs.config.account_validity_renew_by_email_enabled
+ )
async def on_POST(self, request):
- if not self.account_validity.renew_by_email_enabled:
+ if not self.account_validity_renew_by_email_enabled:
raise AuthError(
403, "Account renewal via email is disabled on this server."
)
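
renew_account now reports three outcomes rather than a boolean: freshly renewed, previously renewed (stale token), and invalid. A sketch of driving the endpoint (placeholder URL and token; the path is the unstable account_validity prefix used upstream):

    import requests

    BASE = "https://homeserver.example.com"  # placeholder

    r = requests.get(
        f"{BASE}/_matrix/client/unstable/account_validity/renew",
        params={"token": "<renewal_token>"},  # placeholder
    )
    # 200 + account_renewed template for a fresh token,
    # 200 + account_previously_renewed template for an already-used token,
    # 404 + invalid_token template otherwise.
    print(r.status_code)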
diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/v2_alpha/knock.py
new file mode 100644
index 0000000000..8439da447e
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/knock.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Sorunome
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional, Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.constants import Membership
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_json_object_from_request,
+ parse_list_from_args,
+)
+from synapse.logging.opentracing import set_tag
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import JsonDict, RoomAlias, RoomID
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class KnockRoomAliasServlet(RestServlet):
+ """
+ POST /xyz.amorgan.knock/{roomIdOrAlias}
+ """
+
+ PATTERNS = client_patterns(
+ "/xyz.amorgan.knock/(?P<room_identifier>[^/]*)", releases=()
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.txns = HttpTransactionCache(hs)
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(
+ self, request: Request, room_identifier: str, txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+
+ content = parse_json_object_from_request(request)
+ event_content = None
+ if "reason" in content:
+ event_content = {"reason": content["reason"]}
+
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ try:
+ remote_room_hosts = parse_list_from_args(request.args, "server_name")
+ except KeyError:
+ remote_room_hosts = None
+ elif RoomAlias.is_valid(room_identifier):
+ handler = self.room_member_handler
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+ room_id = room_id_obj.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=requester.user,
+ room_id=room_id,
+ action=Membership.KNOCK,
+ txn_id=txn_id,
+ third_party_signed=None,
+ remote_room_hosts=remote_room_hosts,
+ content=event_content,
+ )
+
+ return 200, {"room_id": room_id}
+
+ def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
+ set_tag("txn_id", txn_id)
+
+ return self.txns.fetch_or_execute_request(
+ request, self.on_POST, request, room_identifier, txn_id
+ )
+
+
+def register_servlets(hs, http_server):
+ KnockRoomAliasServlet(hs).register(http_server)
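
A hedged sketch of knocking through the unstable endpoint (placeholders; the path prefix matches the servlet pattern above):

    import requests
    from urllib.parse import quote

    BASE = "https://homeserver.example.com"  # placeholder
    HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder
    room = "!abcdefg:example.com"  # placeholder room ID or alias

    r = requests.post(
        f"{BASE}/_matrix/client/unstable/xyz.amorgan.knock/{quote(room)}",
        json={"reason": "Let me in, please?"},
        headers=HEADERS,
    )
    print(r.json())  # {"room_id": "!abcdefg:example.com"}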
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b093183e79..5d461efb6a 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +18,7 @@
import hmac
import logging
import random
+import re
from typing import List, Union
import synapse
@@ -119,13 +121,15 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -198,13 +202,19 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
@@ -354,15 +364,9 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
- ip = request.getClientIP()
- with self.ratelimiter.ratelimit(ip) as wait_deferred:
- await wait_deferred
-
- username = parse_string(request, "username", required=True)
-
- await self.registration_handler.check_username(username)
-
- return 200, {"available": True}
+ # We are not interested in logging in via a username in this deployment.
+ # Simply allow anything here as it won't be used later.
+ return 200, {"available": True}
class RegisterRestServlet(RestServlet):
@@ -412,18 +416,27 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
- # Pull out the provided username and do basic sanity checks early since
- # the auth layer will store these in sessions.
+ # We don't care about usernames for this deployment. In fact, the act
+ # of checking whether they exist already can leak metadata about
+ # which users are already registered.
+ #
+ # Usernames are derived from the provided email instead, so any
+ # username supplied here is simply ignored.
+ #
+ # (we do still allow appservices to set them below)
desired_username = None
- if "username" in body:
- if not isinstance(body["username"], str) or len(body["username"]) > 512:
- raise SynapseError(400, "Invalid username")
- desired_username = body["username"]
+
+ desired_display_name = body.get("display_name")
appservice = None
if self.auth.has_access_token(request):
appservice = self.auth.get_appservice_by_req(request)
+ # We need to retrieve the password early in order to pass it to
+ # application service registration
+ # This is specific to shadow server registration of users via an AS
+ password = body.pop("password", None)
+
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -432,7 +445,7 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
- desired_username = body.get("user", desired_username)
+ desired_username = body.get("user", body.get("username"))
# XXX we should check that desired_username is valid. Currently
# we give appservices carte blanche for any insanity in mxids,
@@ -445,7 +458,7 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Desired Username is missing or not a string")
result = await self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username, password, desired_display_name, access_token, body
)
return 200, result
@@ -454,16 +467,6 @@ class RegisterRestServlet(RestServlet):
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
- # For regular registration, convert the provided username to lowercase
- # before attempting to register it. This should mean that people who try
- # to register with upper-case in their usernames don't get a nasty surprise.
- #
- # Note that we treat usernames case-insensitively in login, so they are
- # free to carry on imagining that their username is CrAzYh4cKeR if that
- # keeps them happy.
- if desired_username is not None:
- desired_username = desired_username.lower()
-
# Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
@@ -472,7 +475,6 @@ class RegisterRestServlet(RestServlet):
# Note that we remove the password from the body since the auth layer
# will store the body in the session and we don't want a plaintext
# password store there.
- password = body.pop("password", None)
if password is not None:
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
@@ -502,14 +504,6 @@ class RegisterRestServlet(RestServlet):
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
- # Ensure that the username is valid.
- if desired_username is not None:
- await self.registration_handler.check_username(
- desired_username,
- guest_access_token=guest_access_token,
- assigned_user_id=registered_user_id,
- )
-
# Check if the user-interactive authentication flows are complete, if
# not this will raise a user-interactive auth error.
try:
@@ -546,7 +540,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (await check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -554,6 +548,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existing_user_id = await self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existing_user_id is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case get ready and generate the
+ # desired_username.
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
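+                    # e.g. (illustrative) "jane-doe" -> "jane-doe1" -> "jane-doe2"
+                    # if the earlier candidates are already taken.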
+ while True:
+ try:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+                            # Something else went wrong; give up and let the
+                            # final check_username call below surface the error.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = _map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -568,7 +636,12 @@ class RegisterRestServlet(RestServlet):
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
- desired_username = params.get("username", None)
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ else:
+            # Keep the desired_username derived from the 3PID above.
+ pass
+
guest_access_token = params.get("guest_access_token", None)
if desired_username is not None:
@@ -617,6 +690,7 @@ class RegisterRestServlet(RestServlet):
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
user_agent_ips=entries,
@@ -629,6 +703,14 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
+ if self.hs.config.shadow_server:
+ await self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
@@ -652,14 +734,40 @@ class RegisterRestServlet(RestServlet):
return 200, return_dict
- async def _do_appservice_registration(self, username, as_token, body):
+ async def _do_appservice_registration(
+ self, username, password, display_name, as_token, body
+ ):
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
+
+ if password:
+ # Hash the password
+ #
+            # In mainline, password hashing was moved further along in the
+            # registration flow, but we need the hash here for the appservice
+            # use case of shadow servers.
+ password = await self.auth_handler.hash(password)
user_id = await self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password, display_name
)
- return await self._create_registration_details(
+ result = await self._create_registration_details(
user_id, body, is_appservice_ghost=True,
)
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ await self.registration_handler.register_email_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ await self.registration_handler.register_msisdn_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ return result
+
async def _create_registration_details(
self, user_id, params, is_appservice_ghost=False
):
@@ -715,6 +823,60 @@ class RegisterRestServlet(RestServlet):
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+    # Split the name by whitespace, then by hyphens, capitalising each part
+    # before joining everything back together.
+    capitalized_name = " ".join(
+        "-".join(part.capitalize() for part in space_part.split("-"))
+        for space_part in name.split()
+    )
+    return capitalized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
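+
+    Examples (illustrative; derived from the rules below, addresses assumed):
+        "jane.doe@culture.gouv.fr"  -> "Jane Doe [Culture]"
+        "bob@matrix.org"            -> "Bob [Tchap Admin]"
+        "alice.smith@example.com"   -> "Alice Smith [Example]"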
+ """
+ # Split the part before and after the @ in the email.
+ # Replace all . with spaces in the first part
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+    # If this is a ...gouv.fr address, set the org to whatever comes before
+    # gouv.fr. If there isn't anything (an @gouv.fr email), simply mark the
+    # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 8e52e4cca4..582c999abd 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
+from typing import Any, Callable, Dict, List
-from synapse.api.constants import PresenceState
+from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
@@ -24,7 +24,7 @@ from synapse.events.utils import (
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import SyncConfig
+from synapse.handlers.sync import KnockedSyncResult, SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
from synapse.util import json_decoder
@@ -213,6 +213,10 @@ class SyncRestServlet(RestServlet):
sync_result.invited, time_now, access_token_id, event_formatter
)
+ knocked = await self.encode_knocked(
+ sync_result.knocked, time_now, access_token_id, event_formatter
+ )
+
archived = await self.encode_archived(
sync_result.archived,
time_now,
@@ -230,11 +234,16 @@ class SyncRestServlet(RestServlet):
"left": list(sync_result.device_lists.left),
},
"presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
- "rooms": {"join": joined, "invite": invited, "leave": archived},
+ "rooms": {
+ Membership.JOIN: joined,
+ Membership.INVITE: invited,
+ Membership.KNOCK: knocked,
+ Membership.LEAVE: archived,
+ },
"groups": {
- "join": sync_result.groups.join,
- "invite": sync_result.groups.invite,
- "leave": sync_result.groups.leave,
+ Membership.JOIN: sync_result.groups.join,
+ Membership.INVITE: sync_result.groups.invite,
+ Membership.LEAVE: sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
@@ -296,7 +305,7 @@ class SyncRestServlet(RestServlet):
Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
- sync results for rooms this user is joined to
+ sync results for rooms this user is invited to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
@@ -315,7 +324,7 @@ class SyncRestServlet(RestServlet):
time_now,
token_id=token_id,
event_format=event_formatter,
- is_invite=True,
+ include_stripped_room_state=True,
)
unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
@@ -325,6 +334,60 @@ class SyncRestServlet(RestServlet):
return invited
+ async def encode_knocked(
+ self,
+ rooms: List[KnockedSyncResult],
+ time_now: int,
+ token_id: int,
+ event_formatter: Callable[[Dict], Dict],
+ ) -> Dict[str, Dict[str, Any]]:
+ """
+ Encode the rooms we've knocked on in a sync result.
+
+ Args:
+ rooms: list of sync results for rooms this user is knocking on
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing of transaction IDs
+ event_formatter: function to convert from federation format to client format
+
+ Returns:
+            A dictionary of rooms the user has knocked on, keyed by room ID, in our response format.
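+
+            An illustrative shape of the returned mapping (room ID assumed):
+                {
+                    "!abc:example.org": {
+                        "knock_state": {"events": [...stripped state events..., knock_event]}
+                    }
+                }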
+ """
+ knocked = {}
+ for room in rooms:
+ knock = await self._event_serializer.serialize_event(
+ room.knock,
+ time_now,
+ token_id=token_id,
+ event_format=event_formatter,
+ include_stripped_room_state=True,
+ )
+
+ # Extract the `unsigned` key from the knock event.
+ # This is where we (cheekily) store the knock state events
+ unsigned = knock.setdefault("unsigned", {})
+
+ # Duplicate the dictionary in order to avoid modifying the original
+ unsigned = dict(unsigned)
+
+ # Extract the stripped room state from the unsigned dict
+ # This is for clients to get a little bit of information about
+ # the room they've knocked on, without revealing any sensitive information
+ knocked_state = list(unsigned.pop("knock_room_state", []))
+
+ # Append the actual knock membership event itself as well. This provides
+ # the client with:
+ #
+ # * A knock state event that they can use for easier internal tracking
+ # * The rough timestamp of when the knock occurred contained within the event
+ knocked_state.append(knock)
+
+ # Build the `knock_state` dictionary, which will contain the state of the
+ # room that the client has knocked on
+ knocked[room.room_id] = {"knock_state": {"events": knocked_state}}
+
+ return knocked
+
async def encode_archived(
self, rooms, time_now, token_id, event_fields, event_formatter
):
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index ad598cefe0..eeddfa31f8 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@
# limitations under the License.
import logging
+from typing import Dict
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from signedjson.sign import sign_json
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import UserID
from ._base import client_patterns
@@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
"""Searches for users in directory
@@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
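+            # Defer the search to the configured identity server: sign the
+            # request body with our signing key so the identity server can
+            # verify that the request really came from this homeserver.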
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = await self.http_client.post_json_get_json(url, signed_body)
+ return 200, resp
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,5 +95,124 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
+class SingleUserInfoServlet(RestServlet):
+ """
+ Deprecated and replaced by `/users/info`
+
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(SingleUserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ async def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = await self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ return 200, res
+
+ user_id_to_info = await self.store.get_info_for_users([user_id])
+ return 200, user_id_to_info[user_id]
+
+ async def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+            dict: Deactivation and expiration information for the given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
+ return user_ids_to_info_dict[user_id]
+
+
+class UserInfoServlet(RestServlet):
+ """Bulk version of `/user/{user_id}/info` endpoint
+
+ GET /users/info HTTP/1.1
+
+    Returns a dictionary of user_id to info dictionary. Supports remote users.
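+
+    An illustrative request body (user IDs assumed):
+        {"user_ids": ["@alice:example.org", "@bob:other.org"]}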
+ """
+
+ PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+
+ async def on_POST(self, request):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ # Extract the user_ids from the request
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, required=["user_ids"])
+
+ user_ids = body["user_ids"]
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # Separate local and remote users
+ local_user_ids = set()
+ remote_server_to_user_ids = {} # type: Dict[str, set]
+ for user_id in user_ids:
+ user = UserID.from_string(user_id)
+
+ if self.hs.is_mine(user):
+ local_user_ids.add(user_id)
+ else:
+                remote_server_to_user_ids.setdefault(user.domain, set()).add(user_id)
+
+ # Retrieve info of all local users
+ user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)
+
+ # Request info of each remote user from their remote homeserver
+ for server_name, user_id_set in remote_server_to_user_ids.items():
+ # Make a request to the given server about their own users
+ res = await self.transport_layer.get_info_of_users(
+ server_name, list(user_id_set)
+ )
+
+ user_id_to_info_dict.update(res)
+
+ return 200, user_id_to_info_dict
+
+
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ SingleUserInfoServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index d24a199318..c9b9e7f5ff 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -72,9 +72,12 @@ class VersionsRestServlet(RestServlet):
# MSC2326.
"org.matrix.label_based_filtering": True,
# Implements support for cross signing as described in MSC1756
- "org.matrix.e2e_cross_signing": True,
+ # "org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
+ # Tchap does not currently assume this rule for r0.5.0
+ # XXX: Remove this when it does
+ "m.lazy_load_members": True,
# Implements additional endpoints as described in MSC2666
"uk.half-shot.msc2666": True,
# Whether new rooms will be set to encrypted or not (based on presets).
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 31a41e4a27..f71a03a12d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -300,6 +300,7 @@ class FileInfo:
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
+        thumbnail_length (int): The size of the thumbnail, in bytes.
"""
def __init__(
@@ -312,6 +313,7 @@ class FileInfo:
thumbnail_height=None,
thumbnail_method=None,
thumbnail_type=None,
+ thumbnail_length=None,
):
self.server_name = server_name
self.file_id = file_id
@@ -321,6 +323,7 @@ class FileInfo:
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
+ self.thumbnail_length = thumbnail_length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a632099167..5ac307a62d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -146,8 +146,7 @@ class PreviewUrlResource(DirectServeJsonResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
- http_proxy=os.getenvb(b"http_proxy"),
- https_proxy=os.getenvb(b"HTTPS_PROXY"),
+ use_proxy=True,
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
@@ -386,7 +385,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Check whether the URL should be downloaded as oEmbed content instead.
- Params:
+ Args:
url: The URL to check.
Returns:
@@ -403,7 +402,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Request content from an oEmbed endpoint.
- Params:
+ Args:
endpoint: The oEmbed API endpoint.
url: The URL to pass to the API.
@@ -692,27 +691,51 @@ class PreviewUrlResource(DirectServeJsonResource):
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to parse the HTML document into the OG response. If errors
+ occur during processing of the document, an empty response is returned.
+
+ Args:
+ body: The HTML document, as bytes.
+        media_uri: The URI used to download the body.
+ request_encoding: The character encoding of the body, as a string.
+
+ Returns:
+ The OG response as a dictionary.
+ """
# If there's no body, nothing useful is going to be found.
if not body:
return {}
from lxml import etree
+ # Create an HTML parser. If this fails, log and return no metadata.
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body, parser)
- og = _calc_og(tree, media_uri)
+ except LookupError:
+        # The requested encoding wasn't recognised; blindly fall back to utf-8.
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+ except Exception as e:
+ logger.warning("Unable to create HTML parser: %s" % (e,))
+ return {}
+
+ def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+        # Parse the body and compute the OG response; decode errors propagate to the caller.
+ tree = etree.fromstring(body_attempt, parser)
+ return _calc_og(tree, media_uri)
+
+    # Attempt to parse the body. If decoding fails, retry with a lenient utf-8 decode.
+ try:
+ return _attempt_calc_og(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
- parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
- og = _calc_og(tree, media_uri)
-
- return og
+ return _attempt_calc_og(body.decode("utf-8", "ignore"))
-def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d6880f2e6e..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,7 +16,7 @@
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request
@@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource):
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
- if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
- )
-
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
-
- responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
- else:
- logger.info("Couldn't find any generated thumbnails")
- respond_404(request)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ url_cache=media_info["url_cache"],
+ server_name=None,
+ )
async def _select_or_generate_local_thumbnail(
self,
@@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_info["filesystem_id"],
+ url_cache=None,
+ server_name=server_name,
+ )
+ async def _select_and_respond_with_thumbnail(
+ self,
+ request: Request,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str] = None,
+ server_name: Optional[str] = None,
+ ) -> None:
+ """
+ Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ request: The incoming request.
+            desired_width: The desired width; the returned thumbnail may be larger than this.
+            desired_height: The desired height; the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+ """
if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
+ file_info = self._select_thumbnail(
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ file_id,
+ url_cache,
+ server_name,
)
- file_info = FileInfo(
- server_name=server_name,
- file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
+ if not file_info:
+ logger.info("Couldn't find a thumbnail matching the desired inputs")
+ respond_404(request)
+ return
responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(
+ request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+ )
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
@@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int,
desired_method: str,
desired_type: str,
- thumbnail_infos,
- ) -> dict:
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str],
+ server_name: Optional[str],
+ ) -> Optional[FileInfo]:
+ """
+ Choose an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+            desired_width: The desired width; the returned thumbnail may be larger than this.
+            desired_height: The desired height; the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+
+ Returns:
+ The thumbnail which best matches the desired parameters.
+ """
+ desired_method = desired_method.lower()
+
+ # The chosen thumbnail.
+ thumbnail_info = None
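+        # The candidate lists below hold quality tuples ending in the thumbnail
+        # info; tuples sort lexicographically, so `min()` returns the candidate
+        # that best matches the desired parameters.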
+
d_w = desired_width
d_h = desired_height
- if desired_method.lower() == "crop":
+ if desired_method == "crop":
+ # Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = []
+ # Other thumbnails.
crop_info_list2 = []
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "crop":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
- if t_method == "crop":
- aspect_quality = abs(d_w * t_h - d_h * t_w)
- min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info["thumbnail_type"]
- length_quality = info["thumbnail_length"]
- if t_w >= d_w or t_h >= d_h:
- crop_info_list.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info["thumbnail_type"]
+ length_quality = info["thumbnail_length"]
+ if t_w >= d_w or t_h >= d_h:
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
- else:
- crop_info_list2.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ )
+ else:
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
+ )
if crop_info_list:
- return min(crop_info_list)[-1]
- else:
- return min(crop_info_list2)[-1]
- else:
+ thumbnail_info = min(crop_info_list)[-1]
+ elif crop_info_list2:
+ thumbnail_info = min(crop_info_list2)[-1]
+ elif desired_method == "scale":
+ # Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
+ # Other thumbnails.
info_list2 = []
+
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "scale":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
- if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+ if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
- elif t_method == "scale":
+ else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
if info_list:
- return min(info_list)[-1]
- else:
- return min(info_list2)[-1]
+ thumbnail_info = min(info_list)[-1]
+ elif info_list2:
+ thumbnail_info = min(info_list2)[-1]
+
+ if thumbnail_info:
+ return FileInfo(
+ file_id=file_id,
+ url_cache=url_cache,
+ server_name=server_name,
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ thumbnail_length=thumbnail_info["thumbnail_length"],
+ )
+
+ # No matching thumbnail was found.
+ return None
diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
new file mode 100644
index 0000000000..6f2a1931c5
--- /dev/null
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.config._base import ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DomainRuleChecker(object):
+ """
+    A SpamChecker implementation that prevents users in one domain from
+    inviting users in other domains to rooms, based on its configuration.
+
+ Takes a config in the format:
+
+ spam_checker:
+ module: "rulecheck.DomainRuleChecker"
+ config:
+ domain_mapping:
+ "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
+ "other_inviter_domain": [ "invitee_domain_permitted" ]
+ default: False
+
+ # Only let local users join rooms if they were explicitly invited.
+ can_only_join_rooms_with_invite: false
+
+ # Only let local users create rooms if they are inviting only one
+ # other user, and that user matches the rules above.
+ can_only_create_one_to_one_rooms: false
+
+ # Only let local users invite during room creation, regardless of the
+ # domain mapping rules above.
+ can_only_invite_during_room_creation: false
+
+ # Prevent local users from inviting users from certain domains to
+ # rooms published in the room directory.
+ domains_prevented_from_being_invited_to_published_rooms: []
+
+ # Allow third party invites
+ can_invite_by_third_party_id: true
+
+    Don't forget to include your own domain in the mapping if local users
+    should be able to invite each other.
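+
+    With the illustrative mapping above, a user on "inviter_domain" may invite
+    users on "invitee_domain_permitted" or "other_domain_permitted"; invites
+    from domains with no mapping entry fall back to `default`.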
+ """
+
+ def __init__(self, config):
+ self.domain_mapping = config["domain_mapping"] or {}
+ self.default = config["default"]
+
+ self.can_only_join_rooms_with_invite = config.get(
+ "can_only_join_rooms_with_invite", False
+ )
+ self.can_only_create_one_to_one_rooms = config.get(
+ "can_only_create_one_to_one_rooms", False
+ )
+ self.can_only_invite_during_room_creation = config.get(
+ "can_only_invite_during_room_creation", False
+ )
+ self.can_invite_by_third_party_id = config.get(
+ "can_invite_by_third_party_id", True
+ )
+ self.domains_prevented_from_being_invited_to_published_rooms = config.get(
+ "domains_prevented_from_being_invited_to_published_rooms", []
+ )
+
+ def check_event_for_spam(self, event):
+ """Implements synapse.events.SpamChecker.check_event_for_spam
+ """
+ return False
+
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
+ """Implements synapse.events.SpamChecker.user_may_invite
+ """
+ if self.can_only_invite_during_room_creation and not new_room:
+ return False
+
+ if not self.can_invite_by_third_party_id and third_party_invite:
+ return False
+
+ # This is a third party invite (without a bound mxid), so unless we have
+ # banned all third party invites (above) we allow it.
+ if not invitee_userid:
+ return True
+
+ inviter_domain = self._get_domain_from_id(inviter_userid)
+ invitee_domain = self._get_domain_from_id(invitee_userid)
+
+ if inviter_domain not in self.domain_mapping:
+ return self.default
+
+ if (
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
+ ):
+ return False
+
+ return invitee_domain in self.domain_mapping[inviter_domain]
+
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
+ """Implements synapse.events.SpamChecker.user_may_create_room
+ """
+
+ if cloning:
+ return True
+
+ if not self.can_invite_by_third_party_id and third_party_invite_list:
+ return False
+
+ number_of_invites = len(invite_list) + len(third_party_invite_list)
+
+ if self.can_only_create_one_to_one_rooms and number_of_invites != 1:
+ return False
+
+ return True
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias
+ """
+ return True
+
+ def user_may_publish_room(self, userid, room_id):
+ """Implements synapse.events.SpamChecker.user_may_publish_room
+ """
+ return True
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Implements synapse.events.SpamChecker.user_may_join_room
+ """
+ if self.can_only_join_rooms_with_invite and not is_invited:
+ return False
+
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ """Implements synapse.events.SpamChecker.parse_config
+ """
+ if "default" in config:
+ return config
+ else:
+ raise ConfigError("No default set for spam_config DomainRuleChecker")
+
+ @staticmethod
+ def _get_domain_from_id(mxid):
+ """Parses a string and returns the domain part of the mxid.
+
+ Args:
+ mxid (str): a valid mxid
+
+ Returns:
+ str: the domain part of the mxid
+
+ """
+ idx = mxid.find(":")
+ if idx == -1:
+ raise Exception("Invalid ID: %r" % (mxid,))
+ return mxid[idx + 1 :]
diff --git a/synapse/server.py b/synapse/server.py
index 9cdda83aa1..6ffb7e0fd9 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -24,7 +24,6 @@
import abc
import functools
import logging
-import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
import twisted.internet.base
@@ -103,6 +102,7 @@ from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -128,6 +128,8 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
+ from txredisapi import RedisProtocol
+
from synapse.handlers.oidc_handler import OidcHandler
from synapse.handlers.saml_handler import SamlHandler
@@ -357,11 +359,7 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
An HTTP client that uses configured HTTP(S) proxies.
"""
- return SimpleHttpClient(
- self,
- http_proxy=os.getenvb(b"http_proxy"),
- https_proxy=os.getenvb(b"HTTPS_PROXY"),
- )
+ return SimpleHttpClient(self, use_proxy=True)
@cache_in_self
def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
@@ -373,8 +371,7 @@ class HomeServer(metaclass=abc.ABCMeta):
self,
ip_whitelist=self.config.ip_range_whitelist,
ip_blacklist=self.config.ip_range_blacklist,
- http_proxy=os.getenvb(b"http_proxy"),
- https_proxy=os.getenvb(b"HTTPS_PROXY"),
+ use_proxy=True,
)
@cache_in_self
@@ -716,6 +713,33 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)
+ @cache_in_self
+ def get_external_cache(self) -> ExternalCache:
+ return ExternalCache(self)
+
+ @cache_in_self
+ def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+ if not self.config.redis.redis_enabled:
+ return None
+
+ # We only want to import redis module if we're using it, as we have
+ # `txredisapi` as an optional dependency.
+ from synapse.replication.tcp.redis import lazyConnection
+
+ logger.info(
+ "Connecting to redis (host=%r port=%r) for external cache",
+ self.config.redis_host,
+ self.config.redis_port,
+ )
+
+ return lazyConnection(
+ hs=self,
+ host=self.config.redis_host,
+ port=self.config.redis_port,
+ password=self.config.redis.redis_password,
+ reconnect=True,
+ )
+
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 84f59c7d85..3bd9ff8ca0 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
+ entry = None
else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
- # XXX: can we update the state cache entry for the new state group? or
- # could we set a flag on resolve_state_groups_for_events to tell it to
- # always make a state group?
+ # Assign the new state group to the cached state entry.
+ #
+ # Note that this can race in that we could generate multiple state
+ # groups for the same state entry, but that is just inefficient
+ # rather than dangerous.
+ if entry and entry.state_group is None:
+ entry.state_group = state_group_before_event
#
# now if it's not a state event, we're done
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index c7220bc778..d2ba4bd2fc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -262,6 +262,12 @@ class LoggingTransaction:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+ """Similar to `executemany`, except `txn.rowcount` will not be correct
+ afterwards.
+
+        More efficient than `executemany` on PostgreSQL.
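+
+        An illustrative call (table and column names assumed):
+            txn.execute_batch(
+                "UPDATE events SET sender = ? WHERE event_id = ?",
+                [("@alice:example.org", "$event1"), ("@bob:example.org", "$event2")],
+            )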
+ """
+
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..5d0845588c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
UIAuthStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
+ EventForwardExtremitiesStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..659d8f245f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
- txn.executemany(sql, ((row[0], row[1]) for row in rows))
+ txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", count)
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Delete older entries in the table, as we really only care about
# when the latest change happened.
- txn.executemany(
+ txn.execute_batch(
"""
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..309f1e865b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, dict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ ) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Take a list of one time keys out of the database.
Args:
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..438383abe1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?, ?)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
_gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
],
)
- txn.executemany(
+ txn.execute_batch(
"""
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5db7d7aaa8..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -473,8 +473,9 @@ class PersistEventsStore:
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
)
- @staticmethod
+ @classmethod
def _add_chain_cover_index(
+ cls,
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
@@ -614,60 +615,17 @@ class PersistEventsStore:
if not events_to_calc_chain_id_for:
return
- # We now calculate the chain IDs/sequence numbers for the events. We
- # do this by looking at the chain ID and sequence number of any auth
- # event with the same type/state_key and incrementing the sequence
- # number by one. If there was no match or the chain ID/sequence
- # number is already taken we generate a new chain.
- #
- # We need to do this in a topologically sorted order as we want to
- # generate chain IDs/sequence numbers of an event's auth events
- # before the event itself.
- chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
- new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
- for event_id in sorted_topologically(
- events_to_calc_chain_id_for, event_to_auth_chain
- ):
- existing_chain_id = None
- for auth_id in event_to_auth_chain.get(event_id, []):
- if event_to_types.get(event_id) == event_to_types.get(auth_id):
- existing_chain_id = chain_map[auth_id]
- break
-
- new_chain_tuple = None
- if existing_chain_id:
- # We found a chain ID/sequence number candidate, check its
- # not already taken.
- proposed_new_id = existing_chain_id[0]
- proposed_new_seq = existing_chain_id[1] + 1
- if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
- already_allocated = db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_auth_chains",
- keyvalues={
- "chain_id": proposed_new_id,
- "sequence_number": proposed_new_seq,
- },
- retcol="event_id",
- allow_none=True,
- )
- if already_allocated:
- # Mark it as already allocated so we don't need to hit
- # the DB again.
- chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
- else:
- new_chain_tuple = (
- proposed_new_id,
- proposed_new_seq,
- )
-
- if not new_chain_tuple:
- new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
-
- chains_tuples_allocated.add(new_chain_tuple)
-
- chain_map[event_id] = new_chain_tuple
- new_chain_tuples[event_id] = new_chain_tuple
+ # Allocate chain ID/sequence numbers to each new event.
+ new_chain_tuples = cls._allocate_chain_ids(
+ txn,
+ db_pool,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ events_to_calc_chain_id_for,
+ chain_map,
+ )
+ chain_map.update(new_chain_tuples)
db_pool.simple_insert_many_txn(
txn,
@@ -794,6 +752,137 @@ class PersistEventsStore:
],
)
+ @staticmethod
+ def _allocate_chain_ids(
+ txn,
+ db_pool: DatabasePool,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ events_to_calc_chain_id_for: Set[str],
+ chain_map: Dict[str, Tuple[int, int]],
+ ) -> Dict[str, Tuple[int, int]]:
+ """Allocates, but does not persist, chain ID/sequence numbers for the
+ events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+ for info on args)
+ """
+
+ # We now calculate the chain IDs/sequence numbers for the events. We do
+ # this by looking at the chain ID and sequence number of any auth event
+ # with the same type/state_key and incrementing the sequence number by
+ # one. If there was no match or the chain ID/sequence number is already
+ # taken we generate a new chain.
+ #
+ # We try to reduce the number of times that we hit the database by
+ # batching up calls, to make this more efficient when persisting large
+ # numbers of state events (e.g. during joins).
+ #
+ # We do this by:
+ # 1. Calculating for each event which auth event will be used to
+ # inherit the chain ID, i.e. converting the auth chain graph to a
+ # tree that we can allocate chains on. We also keep track of which
+ # existing chain IDs have been referenced.
+ # 2. Fetching the max allocated sequence number for each referenced
+ # existing chain ID, generating a map from chain ID to the max
+ # allocated sequence number.
+ # 3. Iterating over the tree and allocating a chain ID/seq no. to the
+ # new event, by incrementing the sequence number from the
+ # referenced event's chain ID/seq no. and checking that the
+ # incremented sequence number hasn't already been allocated (by
+ # looking in the map generated in the previous step). We generate a
+ # new chain if the sequence number has already been allocated.
+ #
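+        # As an illustrative example: if a new event inherits from an auth
+        # event with chain (3, 5), we propose (3, 6); if the map from step 2
+        # shows that sequence number 6 is already allocated on chain 3, we
+        # start a new chain at sequence number 1 instead.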
+
+ existing_chains = set() # type: Set[int]
+ tree = [] # type: List[Tuple[str, Optional[str]]]
+
+ # We need to do this in a topologically sorted order as we want to
+ # generate chain IDs/sequence numbers of an event's auth events before
+ # the event itself.
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ for auth_id in event_to_auth_chain.get(event_id, []):
+ if event_to_types.get(event_id) == event_to_types.get(auth_id):
+ existing_chain_id = chain_map.get(auth_id)
+ if existing_chain_id:
+ existing_chains.add(existing_chain_id[0])
+
+ tree.append((event_id, auth_id))
+ break
+ else:
+ tree.append((event_id, None))
+
+ # Fetch the current max sequence number for each existing referenced chain.
+ sql = """
+ SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+ WHERE %s
+ GROUP BY chain_id
+ """
+ clause, args = make_in_list_sql_clause(
+ db_pool.engine, "chain_id", existing_chains
+ )
+ txn.execute(sql % (clause,), args)
+
+ chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
+
+ # Allocate the new events chain ID/sequence numbers.
+ #
+ # To reduce the number of calls to the database we don't allocate a
+ # chain ID number in the loop, instead we use a temporary `object()` for
+ # each new chain ID. Once we've done the loop we generate the necessary
+ # number of new chain IDs in one call, replacing all temporary
+ # objects with real allocated chain IDs.
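+        #
+        # e.g. (illustrative) two brand-new chains first get placeholders o1
+        # and o2; after the loop, get_next_mult_txn(txn, 2) might return
+        # [17, 18] and we map {o1: 17, o2: 18} before building the result.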
+
+ unallocated_chain_ids = set() # type: Set[object]
+ new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
+ for event_id, auth_event_id in tree:
+ # If we reference an auth_event_id we fetch the allocated chain ID,
+ # either from the existing `chain_map` or the newly generated
+ # `new_chain_tuples` map.
+ existing_chain_id = None
+ if auth_event_id:
+ existing_chain_id = new_chain_tuples.get(auth_event_id)
+ if not existing_chain_id:
+ existing_chain_id = chain_map[auth_event_id]
+
+ new_chain_tuple = None # type: Optional[Tuple[Any, int]]
+ if existing_chain_id:
+ # We found a chain ID/sequence number candidate, check its
+ # not already taken.
+ proposed_new_id = existing_chain_id[0]
+ proposed_new_seq = existing_chain_id[1] + 1
+
+ if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+ new_chain_tuple = (
+ proposed_new_id,
+ proposed_new_seq,
+ )
+
+ # If we need to start a new chain we allocate a temporary chain ID.
+ if not new_chain_tuple:
+ new_chain_tuple = (object(), 1)
+ unallocated_chain_ids.add(new_chain_tuple[0])
+
+ new_chain_tuples[event_id] = new_chain_tuple
+ chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+ # Generate new chain IDs for all unallocated chain IDs.
+ newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+ txn, len(unallocated_chain_ids)
+ )
+
+ # Map from potentially temporary chain ID to real chain ID
+ chain_id_to_allocated_map = dict(
+ zip(unallocated_chain_ids, newly_allocated_chain_ids)
+ ) # type: Dict[Any, int]
+ chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+ return {
+ event_id: (chain_id_to_allocated_map[chain_id], seq)
+ for event_id, (chain_id, seq) in new_chain_tuples.items()
+ }
+
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..5ca4fa6817 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
@@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
- for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
- clump = update_rows[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, update_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
@@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
- for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
- clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, rows_to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..0ac1da9c35
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+ async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+ """Delete any extra forward extremities for a room.
+
+ Invalidates the "get_latest_event_ids_in_room" cache if any forward
+ extremities were deleted.
+
+        Returns the number of forward extremities deleted.
+ """
+
+ def delete_forward_extremities_for_room_txn(txn):
+ # First we need to get the event_id to not delete
+ sql = """
+ SELECT event_id FROM event_forward_extremities
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id = ?
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (room_id,))
+ rows = txn.fetchall()
+ try:
+ event_id = rows[0][0]
+ logger.debug(
+ "Found event_id %s as the forward extremity to keep for room %s",
+ event_id,
+ room_id,
+ )
+            except IndexError:
+ msg = "No forward extremity event found for room %s" % room_id
+ logger.warning(msg)
+ raise SynapseError(400, msg)
+
+ # Now delete the extra forward extremities
+ sql = """
+ DELETE FROM event_forward_extremities
+ WHERE event_id != ? AND room_id = ?
+ """
+
+ txn.execute(sql, (event_id, room_id))
+ logger.info(
+ "Deleted %s extra forward extremities for room %s",
+ txn.rowcount,
+ room_id,
+ )
+
+ if txn.rowcount > 0:
+ # Invalidate the cache
+ self._invalidate_cache_and_stream(
+ txn, self.get_latest_event_ids_in_room, (room_id,),
+ )
+
+ return txn.rowcount
+
+ return await self.db_pool.runInteraction(
+ "delete_forward_extremities_for_room",
+ delete_forward_extremities_for_room_txn,
+ )
+
+ async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+ """Get list of forward extremities for a room."""
+
+ def get_forward_extremities_for_room_txn(txn):
+ sql = """
+ SELECT event_id, state_group, depth, received_ts
+ FROM event_forward_extremities
+ INNER JOIN event_to_state_groups USING (event_id)
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id = ?
+ """
+
+ txn.execute(sql, (room_id,))
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
+ )
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 71d823be72..5c4c251871 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -525,7 +525,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
- state_types_to_include: List[EventTypes],
+ state_types_to_include: List[str],
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..e017177655 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
sql,
(
(time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)
- txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+ txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..92e65aa640 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
+ async def count_daily_e2ee_messages(self):
+ """
+ Returns an estimate of the number of end-to-end encrypted messages sent
+ in the last day.
+
+ If it has been significantly less or more than one day since the day-ago
+ stream ordering marker was last updated, the estimate will be off
+ accordingly.
+ """
+
+ def _count_messages(txn):
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+ async def count_daily_sent_e2ee_messages(self):
+ def _count_messages(txn):
+ # This is good enough: if you have silly characters in your own
+ # hostname then that's your own fault.
+ like_clause = "%:" + self.hs.hostname
+
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND sender LIKE ?
+ AND stream_ordering > ?
+ """
+
+ txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction(
+ "count_daily_sent_e2ee_messages", _count_messages
+ )
+
+ async def count_daily_active_e2ee_rooms(self):
+ def _count(txn):
+ sql = """
+ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction(
+ "count_daily_active_e2ee_rooms", _count
+ )
+
async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
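
The three counters added above share one query shape: count rows of type
m.room.encrypted whose stream ordering is newer than a stored day-old cutoff,
optionally restricted to local senders by matching the sender against
"%:<hostname>". A reduced, runnable sketch against SQLite with made-up data:

    import sqlite3

    db = sqlite3.connect(":memory:")
    db.execute("CREATE TABLE events (type TEXT, sender TEXT, stream_ordering INTEGER)")
    db.executemany(
        "INSERT INTO events VALUES (?, ?, ?)",
        [
            ("m.room.encrypted", "@a:example.org", 10),
            ("m.room.encrypted", "@b:remote.net", 11),
            ("m.room.message", "@a:example.org", 12),
        ],
    )
    stream_ordering_day_ago = 9  # stand-in for the store's tracked cutoff

    (total,) = db.execute(
        "SELECT COALESCE(COUNT(*), 0) FROM events"
        " WHERE type = 'm.room.encrypted' AND stream_ordering > ?",
        (stream_ordering_day_ago,),
    ).fetchone()
    (sent_locally,) = db.execute(
        "SELECT COALESCE(COUNT(*), 0) FROM events"
        " WHERE type = 'm.room.encrypted' AND sender LIKE ? AND stream_ordering > ?",
        ("%:example.org", stream_ordering_day_ago),
    ).fetchone()
    print(total, sent_locally)  # 2 1
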
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 54ef0f1f54..4360dc0afc 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.roommember import ProfileInfo
+from synapse.types import UserID
+from synapse.util.caches.descriptors import cached
+
+BATCH_SIZE = 100
class ProfileWorkerStore(SQLBaseStore):
@@ -39,6 +44,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
+ @cached(max_entries=5000)
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -47,6 +53,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -55,6 +62,58 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ async def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return await self.db_pool.runInteraction(
+ "get_latest_profile_replication_batch_number", f
+ )
+
+ async def get_profile_batch(self, batchnum):
+ return await self.db_pool.simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ async def assign_profile_batch(self):
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return await self.db_pool.runInteraction("assign_profile_batch", f)
+
+ async def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return await self.db_pool.runInteraction("get_replication_hosts", f)
+
+ async def update_replication_batch_for_host(
+ self, host: str, last_synced_batch: int
+ ):
+ return await self.db_pool.simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
async def get_from_remote_profile_cache(
self, user_id: str
) -> Optional[Dict[str, Any]]:
@@ -72,32 +131,95 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_displayname(
- self, user_localpart: str, new_displayname: Optional[str]
+ self, user_localpart: str, new_displayname: Optional[str], batchnum: int
) -> None:
- await self.db_pool.simple_update_one(
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
async def set_profile_avatar_url(
- self, user_localpart: str, new_avatar_url: Optional[str]
+ self, user_localpart: str, new_avatar_url: Optional[str], batchnum: int
) -> None:
- await self.db_pool.simple_update_one(
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ async def set_profiles_active(
+ self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ ) -> None:
+ """Given a set of users, set active and hidden flags on them.
+
+ Args:
+ users: A list of UserIDs
+ active: Whether to set the users to active or inactive
+ hide: Whether to hide the users (withhold from replication). If
+ False and active is False, users will have their profiles
+ erased.
+ batchnum: The batch number, used for profile replication
+ """
+ # Convert list of localparts to list of tuples containing localparts
+ user_localparts = [(user.localpart,) for user in users]
+
+ # Generate list of value tuples for each user
+ value_names = ("active", "batch")
+ values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
+
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile information
+ value_names += ("avatar_url", "displayname")
+ values = [v + (None, None) for v in values]
+
+ return await self.db_pool.runInteraction(
+ "set_profiles_active",
+ self.db_pool.simple_upsert_many_txn,
+ table="profiles",
+ key_names=("user_id",),
+ key_values=user_localparts,
+ value_names=value_names,
+ value_values=values,
+ )
+
+ async def add_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> None:
+ """Ensure we are caching the remote user's profiles.
+
+ This should only be called when `is_subscribed_remote_profile_for_user`
+ would return true for the user.
+ """
+ await self.db_pool.simple_upsert(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="add_remote_profile_cache",
)
async def update_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> int:
- return await self.db_pool.simple_update(
+ return await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- updatevalues={
+ values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
@@ -166,6 +288,17 @@ class ProfileWorkerStore(SQLBaseStore):
class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
async def add_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> None:
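
The profile replication scheme introduced here works in numbered batches:
assign_profile_batch stamps up to BATCH_SIZE rows that have no batch yet with
MAX(batch) + 1, and profile_replication_status records the last batch each
target host has synced. A sketch of the batch-stamping UPDATE against SQLite,
with illustrative users:

    import sqlite3

    BATCH_SIZE = 100  # mirrors the constant in the store

    db = sqlite3.connect(":memory:")
    db.execute("CREATE TABLE profiles (user_id TEXT UNIQUE, batch INTEGER)")
    db.executemany("INSERT INTO profiles VALUES (?, NULL)", [("alice",), ("bob",)])

    # Stamp up to BATCH_SIZE un-batched rows with the next batch number.
    db.execute(
        "UPDATE profiles SET batch = "
        "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
        "WHERE user_id IN ("
        " SELECT user_id FROM profiles WHERE batch IS NULL LIMIT ?)",
        (BATCH_SIZE,),
    )
    print(db.execute("SELECT user_id, batch FROM profiles").fetchall())
    # [('alice', 0), ('bob', 0)]
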
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)
# Update backward extremeties
- txn.executemany(
+ txn.execute_batch(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..2687ef3e43 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self.db_pool.simple_delete_one_txn(
+ # It is expected that there is exactly one pusher to delete, but
+ # if it isn't there (or there are multiple) delete them all.
+ self.db_pool.simple_delete_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 14c0878d81..269eb6e6e7 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -82,12 +82,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
- self._account_validity = hs.config.account_validity
- if hs.config.run_background_tasks and self._account_validity.enabled:
- self._clock.call_later(
- 0.0, self._set_expiration_date_when_missing,
+ self._account_validity_enabled = hs.config.account_validity_enabled
+ if self._account_validity_enabled:
+ self._account_validity_period = hs.config.account_validity_period
+ self._account_validity_startup_job_max_delta = (
+ hs.config.account_validity_startup_job_max_delta
)
+ if hs.config.run_background_tasks:
+ self._clock.call_later(
+ 0.0, self._set_expiration_date_when_missing,
+ )
+
# Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks:
self._clock.looping_call(
@@ -183,6 +189,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts: int,
email_sent: bool,
renewal_token: Optional[str] = None,
+ token_used_ts: Optional[int] = None,
) -> None:
"""Updates the account validity properties of the given account, with the
given values.
@@ -196,6 +203,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
period.
renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token.
+ token_used_ts: A timestamp of when the current token was used to renew
+ the account.
"""
def set_account_validity_for_user_txn(txn):
@@ -207,6 +216,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"expiration_ts_ms": expiration_ts,
"email_sent": email_sent,
"renewal_token": renewal_token,
+ "token_used_ts_ms": token_used_ts,
},
)
self._invalidate_cache_and_stream(
@@ -217,10 +227,41 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"set_account_validity_for_user", set_account_validity_for_user_txn
)
+ async def get_expired_users(self):
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
+
+ Returns:
+ List[UserID]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ # We need to use pattern matching as profiles.user_id is confusingly just the
+ # user's localpart, whereas account_validity.user_id is a full user ID
+ sql = """
+ SELECT av.user_id from account_validity AS av
+ LEFT JOIN profiles as p
+ ON av.user_id LIKE '%%' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+
+ return [UserID.from_string(row[0]) for row in rows]
+
+ res = await self.db_pool.runInteraction(
+ "get_expired_users", get_expired_users_txn, self._clock.time_msec()
+ )
+
+ return res
+
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
- """Defines a renewal token for a given user.
+ """Defines a renewal token for a given user, and clears the token_used timestamp.
Args:
user_id: ID of the user to set the renewal token for.
@@ -233,26 +274,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
- updatevalues={"renewal_token": renewal_token},
+ updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None},
desc="set_renewal_token_for_user",
)
- async def get_user_from_renewal_token(self, renewal_token: str) -> str:
- """Get a user ID from a renewal token.
+ async def get_user_from_renewal_token(
+ self, renewal_token: str
+ ) -> Tuple[str, int, Optional[int]]:
+ """Get a user ID and renewal status from a renewal token.
Args:
renewal_token: The renewal token to perform the lookup with.
Returns:
- The ID of the user to which the token belongs.
+ A tuple containing the following values:
+ * The ID of a user to which the token belongs.
+ * An int representing the user's expiry timestamp as milliseconds since the
+ epoch, or 0 if the token was invalid.
+ * An optional int representing the timestamp of when the user last renewed
+ their account, as milliseconds since the epoch. None if the account
+ has not been renewed using the current token yet.
"""
- return await self.db_pool.simple_select_one_onecol(
+ ret_dict = await self.db_pool.simple_select_one(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
- retcol="user_id",
+ retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token",
)
+ return (
+ ret_dict["user_id"],
+ ret_dict["expiration_ts_ms"],
+ ret_dict["token_used_ts_ms"],
+ )
+
async def get_renewal_token_for_user(self, user_id: str) -> str:
"""Get the renewal token associated with a given user ID.
@@ -291,7 +346,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"get_users_expiring_soon",
select_users_txn,
self._clock.time_msec(),
- self.config.account_validity.renew_at,
+ self.config.account_validity_renew_at,
)
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@@ -323,6 +378,54 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="delete_account_validity_for_user",
)
+ async def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+ Dict[str, Dict[str, bool]]: A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = await self.db_pool.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self._clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
@@ -360,6 +463,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+ async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+ """Sets whether a user shadow-banned.
+
+ Args:
+ user: user ID of the user to update
+ shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+ """
+
+ def set_shadow_banned_txn(txn):
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="users",
+ keyvalues={"name": user.to_string()},
+ updatevalues={"shadow_banned": shadow_banned},
+ )
+ # In order for this to apply immediately, clear the cache for this user.
+ tokens = self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="access_tokens",
+ keyvalues={"user_id": user.to_string()},
+ retcol="token",
+ )
+ for token in tokens:
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_access_token, (token,)
+ )
+
+ await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
@@ -922,11 +1054,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
delta equal to 10% of the validity period.
"""
now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
+ expiration_ts = now_ms + self._account_validity_period
if use_delta:
expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts - self._account_validity_startup_job_max_delta,
expiration_ts,
)
@@ -1124,7 +1256,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids
"""
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
+ txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
if id_servers:
await self.db_pool.runInteraction(
@@ -1364,7 +1496,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
except self.database_engine.module.IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
- if self._account_validity.enabled:
+ if self._account_validity_enabled:
self.set_expiration_date_for_user_txn(txn, user_id)
if create_profile_with_displayname:
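
With token_used_ts_ms recorded, get_user_from_renewal_token gives callers
enough context to make renewal idempotent: a token whose token_used_ts_ms is
set has already been spent. A hedged sketch of how a caller might interpret
the returned tuple; interpret_renewal is an illustrative name, not Synapse's
actual handler code:

    from typing import Optional, Tuple

    def interpret_renewal(info: Tuple[str, int, Optional[int]], now_ms: int) -> str:
        user_id, expiration_ts_ms, token_used_ts_ms = info
        if token_used_ts_ms is not None:
            # The token was already spent; renewing again should be a no-op.
            return "%s: token already used at %d" % (user_id, token_used_ts_ms)
        if expiration_ts_ms <= now_ms:
            return "%s: account already expired" % user_id
        return "%s: ok to renew" % user_id

    print(interpret_renewal(("@alice:example.org", 2000000, None), 1000000))
    # @alice:example.org: ok to renew
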
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a9fcb5f59c..a98b423771 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -13,14 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import collections
import logging
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -178,11 +177,13 @@ class RoomWorkerStore(SQLBaseStore):
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR history_visibility = 'world_readable'
+ join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ OR history_visibility = 'world_readable'
)
AND joined_members > 0
""" % {
- "published_sql": published_sql
+ "published_sql": published_sql,
+ "knock_join_rule": JoinRules.KNOCK,
}
txn.execute(sql, query_args)
@@ -305,7 +306,7 @@ class RoomWorkerStore(SQLBaseStore):
sql = """
SELECT
room_id, name, topic, canonical_alias, joined_members,
- avatar, history_visibility, joined_members, guest_access
+ avatar, history_visibility, guest_access, join_rules
FROM (
%(published_sql)s
) published
@@ -313,7 +314,8 @@ class RoomWorkerStore(SQLBaseStore):
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR history_visibility = 'world_readable'
+ join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ OR history_visibility = 'world_readable'
)
AND joined_members > 0
%(where_clause)s
@@ -322,6 +324,7 @@ class RoomWorkerStore(SQLBaseStore):
"published_sql": published_sql,
"where_clause": where_clause,
"dir": "DESC" if forwards else "ASC",
+ "knock_join_rule": JoinRules.KNOCK,
}
if limit is not None:
@@ -356,6 +359,23 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ async def is_room_published(self, room_id: str) -> bool:
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id
+ Returns:
+ Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = await self.get_room(room_id)
+ if not room_info:
+ return False
+
+ # Check the is_public value
+ return room_info.get("is_public", False)
+
async def get_rooms_paginate(
self,
start: int,
@@ -564,6 +584,11 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
"""
+ # If the room retention feature is disabled, return a policy with no minimum
+ # or maximum, so that we don't incorrectly filter out events when sending
+ # them to the client.
+ if not self.config.retention_enabled:
+ return {"min_lifetime": None, "max_lifetime": None}
def get_retention_policy_for_room_txn(txn):
txn.execute(
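
The public-rooms queries above switch from positional to named %-substitution
so the knock join rule can be spliced in next to the existing fragments
without miscounting placeholders. A minimal illustration of the pattern, with
KNOCK standing in for JoinRules.KNOCK and a made-up published_sql:

    KNOCK = "knock"  # stand-in for JoinRules.KNOCK

    published_sql = "SELECT room_id FROM rooms WHERE is_public"  # illustrative
    sql = """
        SELECT room_id FROM (%(published_sql)s) published
        INNER JOIN room_stats_state USING (room_id)
        WHERE join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
           OR history_visibility = 'world_readable'
    """ % {
        "published_sql": published_sql,
        "knock_join_rule": KNOCK,
    }
    print(sql)
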
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
- INSERT_CLUMP_SIZE = 1000
-
def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
"""
- for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
- clump = to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(to_update_sql, clump)
+ txn.execute_batch(to_update_sql, to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..96051ac179
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * really ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..7542ab8cbd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql b/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql
new file mode 100644
index 0000000000..4836dac16e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track when users renew their account using the value of the 'renewal_token' column.
+-- This field should be set to NULL after a fresh token is generated.
+ALTER TABLE account_validity ADD token_used_ts_ms BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql b/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql
new file mode 100644
index 0000000000..658f55a384
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 Sorunome
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_stats_current ADD knocked_members INT NOT NULL DEFAULT '0';
+ALTER TABLE room_stats_historical ADD knocked_members BIGINT NOT NULL DEFAULT '0';
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users:
- cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+ cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
# Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
async def search_rooms(
self,
- room_ids: List[str],
+ room_ids: Collection[str],
search_term: str,
keys: List[str],
limit,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..4ad363fb0d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
# limitations under the License.
import logging
-from collections import Counter
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
+from typing_extensions import Counter
+
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
@@ -41,6 +42,7 @@ ABSOLUTE_STATS_FIELDS = {
"current_state_events",
"joined_members",
"invited_members",
+ "knocked_members",
"left_members",
"banned_members",
"local_users_in_room",
@@ -319,7 +321,9 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+ async def get_earliest_token_for_stats(
+ self, stats_type: str, id: str
+ ) -> Optional[int]:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -339,7 +343,7 @@ class StatsStore(StateDeltasStore):
)
async def bulk_update_stats_delta(
- self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+ self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
@@ -665,7 +669,7 @@ class StatsStore(StateDeltasStore):
async def get_changes_room_total_events_and_bytes(
self, min_pos: int, max_pos: int
- ) -> Dict[str, Dict[str, int]]:
+ ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
@@ -683,18 +687,19 @@ class StatsStore(StateDeltasStore):
max_pos,
)
- def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+ def get_changes_room_total_events_and_bytes_txn(
+ self, txn, low_pos: int, high_pos: int
+ ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Gets the total_events and total_event_bytes counts for rooms and
senders, in a range of stream_orderings (including backfilled events).
Args:
txn
- low_pos (int): Low stream ordering
- high_pos (int): High stream ordering
+ low_pos: Low stream ordering
+ high_pos: High stream ordering
Returns:
- tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
- room and user deltas for total_events/total_event_bytes in the
+ The room and user deltas for total_events/total_event_bytes in the
format of `stats_id` -> fields
"""
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ef11f1c3b3..336da218ed 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+ async def update_user_directory_stream_pos(self, stream_id: int) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
@@ -558,6 +558,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
+ self._prefer_local_users_in_search = (
+ hs.config.user_directory_search_prefer_local_users
+ )
+ self._server_name = hs.config.server_name
+
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
@@ -750,9 +755,24 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
"""
+ # We allow manipulating the ranking algorithm by injecting statements
+ # based on config options.
+ additional_ordering_statements = []
+ ordering_arguments = ()
+
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
+ # If enabled, this config option will rank local users higher than those on
+ # remote instances.
+ if self._prefer_local_users_in_search:
+ # This statement checks whether a given user's user ID contains a server name
+ # that matches the local server
+ statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)"
+ additional_ordering_statements.append(statement)
+
+ ordering_arguments += ("%:" + self._server_name,)
+
# We order by rank and then if they have profile info
# The ranking algorithm is hand tweaked for "best" results. Broadly
# the idea is we give a higher weight to exact matches.
@@ -763,7 +783,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
WHERE
- %s
+ %(where_clause)s
AND vector @@ to_tsquery('simple', ?)
ORDER BY
(CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -783,33 +803,54 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
8
)
)
+ %(order_case_statements)s
DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """ % (
- where_clause,
+ """ % {
+ "where_clause": where_clause,
+ "order_case_statements": " ".join(additional_ordering_statements),
+ }
+ args = (
+ join_args
+ + (full_query, exact_query, prefix_query)
+ + ordering_arguments
+ + (limit + 1,)
)
- args = join_args + (full_query, exact_query, prefix_query, limit + 1)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
+ # If enabled, this config option will rank local users higher than those on
+ # remote instances.
+ if self._prefer_local_users_in_search:
+ # This statement checks whether a given user's user ID contains a server name
+ # that matches the local server
+ #
+ # Note that we need to include a comma at the end for valid SQL
+ statement = "user_id LIKE ? DESC,"
+ additional_ordering_statements.append(statement)
+
+ ordering_arguments += ("%:" + self._server_name,)
+
sql = """
SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
WHERE
- %s
+ %(where_clause)s
AND value MATCH ?
ORDER BY
rank(matchinfo(user_directory_search)) DESC,
+ %(order_statements)s
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """ % (
- where_clause,
- )
- args = join_args + (search_query, limit + 1)
+ """ % {
+ "where_clause": where_clause,
+ "order_statements": " ".join(additional_ordering_statements),
+ }
+ args = join_args + (search_query,) + ordering_arguments + (limit + 1,)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
logger.info("[purge] removing redundant state groups")
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete),
)
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete),
)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index bb84c0d792..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
import heapq
import logging
import threading
-from collections import deque
+from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
-from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque() # type: Deque[int]
+
+ # We use this as an ordered set, as we want to efficiently append items,
+ # remove items and get the first item. Since we insert IDs in order, the
+ # insertion ordering will ensure it's in the correct order.
+ #
+ # The key and values are the same, but we never look at the values.
+ self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self):
"""
@@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -162,7 +167,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - self._step
+ return next(iter(self._unfinished_ids)) - self._step
return self._current
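
The deque-to-OrderedDict change is about removal cost: deque.remove() is
O(n), while popping a key from an OrderedDict is O(1), and insertion order
still makes next(iter(...)) the oldest outstanding ID. A minimal
illustration:

    from collections import OrderedDict

    unfinished = OrderedDict()  # type: OrderedDict[int, int]
    for next_id in (101, 102, 103):
        unfinished[next_id] = next_id

    unfinished.pop(102)            # an ID finishes out of order, in O(1)
    print(next(iter(unfinished)))  # 101: the oldest still-unfinished ID
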
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
...
@abc.abstractmethod
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ """Get the next `n` IDs in the sequence"""
+ ...
+
+ @abc.abstractmethod
def check_consistency(
self,
db_conn: "LoggingDatabaseConnection",
@@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ first_id = self._current_max_id + 1
+ self._current_max_id += n
+ return [first_id + i for i in range(n)]
+
def check_consistency(
self,
db_conn: Connection,
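
Stripped of the transaction argument and the lazy initialisation callback,
the core of get_next_mult_txn is a counter bump under a lock that hands back
the freshly reserved range, as in this sketch:

    import threading
    from typing import List

    class SimpleSequence:
        def __init__(self, start: int) -> None:
            self._lock = threading.Lock()
            self._current_max_id = start

        def get_next_mult(self, n: int) -> List[int]:
            with self._lock:
                first_id = self._current_max_id + 1
                self._current_max_id += n
                return [first_id + i for i in range(n)]

    seq = SimpleSequence(10)
    print(seq.get_next_mult(3))  # [11, 12, 13]
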
diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
new file mode 100644
index 0000000000..4589e4539b
--- /dev/null
+++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,971 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import email.utils
+import logging
+from typing import Dict, List, Optional, Tuple
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.config._base import ConfigError
+from synapse.events import EventBase
+from synapse.module_api import ModuleApi
+from synapse.types import Requester, StateMap, UserID, get_domain_from_id
+
+logger = logging.getLogger(__name__)
+
+ACCESS_RULES_TYPE = "im.vector.room.access_rules"
+
+
+class AccessRules:
+ DIRECT = "direct"
+ RESTRICTED = "restricted"
+ UNRESTRICTED = "unrestricted"
+
+
+VALID_ACCESS_RULES = (
+ AccessRules.DIRECT,
+ AccessRules.RESTRICTED,
+ AccessRules.UNRESTRICTED,
+)
+
+# Rules to which we need to apply the power levels restrictions.
+#
+# These are all of the rules that neither:
+# * forbid users from joining based on a server blacklist (which means that there
+# is no need to apply power level restrictions), nor
+# * target direct chats (since we allow both users to be room admins in this case).
+#
+# The power-level restrictions, when they are applied, prevent the following:
+# * the default power level for users (users_default) being set to anything other than 0.
+# * a non-default power level being assigned to any user which would be forbidden from
+# joining a restricted room.
+RULES_WITH_RESTRICTED_POWER_LEVELS = (AccessRules.UNRESTRICTED,)
+
+
+class RoomAccessRules:
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+ # Enable freezing a room when the last room admin leaves.
+ # Note that the departing admin must be a local user in order for this feature
+ # to work.
+ freeze_room_with_no_admin: false
+
+ Don't forget to consider whether you can invite users from your own domain.
+ """
+
+ def __init__(
+ self, config: Dict, module_api: ModuleApi,
+ ):
+ self.id_server = config["id_server"]
+ self.module_api = module_api
+
+ self.domains_forbidden_when_restricted = config.get(
+ "domains_forbidden_when_restricted", []
+ )
+
+ self.freeze_room_with_no_admin = config.get("freeze_room_with_no_admin", False)
+
+ @staticmethod
+ def parse_config(config: Dict) -> Dict:
+ """Parses and validates the options specified in the homeserver config.
+
+ Args:
+ config: The config dict.
+
+ Returns:
+ The config dict.
+
+ Raises:
+ ConfigError: If there was an issue with the provided module configuration.
+ """
+ if "id_server" not in config:
+ raise ConfigError("No IS for event rules TchapEventRules")
+
+ return config
+
+ async def on_create_room(
+ self, requester: Requester, config: Dict, is_requester_admin: bool,
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room.
+
+ Checks if an im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+
+ Checks if an m.room.power_levels event is being set during room creation.
+ If yes, make sure the event is allowed. Otherwise, set power_level_content_override
+ in the config dict to our modified version of the default room power levels.
+
+ Args:
+ requester: The user who is making the createRoom request.
+ config: The createRoom config dict provided by the user.
+ is_requester_admin: Whether the requester is a Synapse admin.
+
+ Returns:
+ Whether the request is allowed.
+
+ Raises:
+ SynapseError: If the createRoom config dict is invalid or its contents are blocked.
+ """
+ is_direct = config.get("is_direct")
+ preset = config.get("preset")
+ access_rule = None
+ join_rule = None
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and deny the request if not.
+ for event in config.get("initial_state", []):
+ if event["type"] == ACCESS_RULES_TYPE:
+ access_rule = event["content"].get("rule")
+
+ # Make sure the event has a valid content.
+ if access_rule is None:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule name is valid.
+ if access_rule not in VALID_ACCESS_RULES:
+ raise SynapseError(400, "Invalid access rule")
+
+ if (is_direct and access_rule != AccessRules.DIRECT) or (
+ access_rule == AccessRules.DIRECT and not is_direct
+ ):
+ raise SynapseError(400, "Invalid access rule")
+
+ if event["type"] == EventTypes.JoinRules:
+ join_rule = event["content"].get("join_rule")
+
+ if access_rule is None:
+ # If there's no access rules event in the initial state, create one with the
+ # default setting.
+ if is_direct:
+ default_rule = AccessRules.DIRECT
+ else:
+ # If the default value for non-direct chat changes, we should make another
+ # case here for rooms created with either a "public" join_rule or the
+ # "public_chat" preset to make sure those keep defaulting to "restricted"
+ default_rule = AccessRules.RESTRICTED
+
+ if not config.get("initial_state"):
+ config["initial_state"] = []
+
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
+ }
+ )
+
+ access_rule = default_rule
+
+ # Check that the preset in use is compatible with the access rule, whether it's
+ # user-defined or the default.
+ #
+ # Direct rooms may not have their join_rules set to JoinRules.PUBLIC.
+ if (
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule == AccessRules.DIRECT:
+ raise SynapseError(400, "Invalid access rule")
+
+ default_power_levels = self._get_default_power_levels(
+ requester.user.to_string()
+ )
+
+ # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed(
+ config.get("power_level_content_override", {}),
+ access_rule,
+ default_power_levels,
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content override")
+
+ custom_user_power_levels = config.get("power_level_content_override")
+
+ # Second loop for events we need to know the current rule to process.
+ for event in config.get("initial_state", []):
+ if event["type"] == EventTypes.PowerLevels:
+ allowed = self._is_power_level_content_allowed(
+ event["content"], access_rule, default_power_levels
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content")
+
+ custom_user_power_levels = event["content"]
+ if custom_user_power_levels:
+ # If the user is using their own power levels, but failed to provide an expected
+ # key in the power levels content dictionary, fill it in from the defaults instead
+ for key, value in default_power_levels.items():
+ custom_user_power_levels.setdefault(key, value)
+ else:
+ # If power levels were not overridden by the user, completely override with the
+ # defaults instead
+ config["power_level_content_override"] = default_power_levels
+
+ return True
+
+ # If power levels are not overridden by the user during room creation, the following
+ # rules are used instead. Changes from Synapse's default power levels are noted.
+ #
+ # The same power levels are currently applied regardless of room preset.
+ @staticmethod
+ def _get_default_power_levels(user_id: str) -> Dict:
+ return {
+ "users": {user_id: 100},
+ "users_default": 0,
+ "events": {
+ EventTypes.Name: 50,
+ EventTypes.PowerLevels: 100,
+ EventTypes.RoomHistoryVisibility: 100,
+ EventTypes.CanonicalAlias: 50,
+ EventTypes.RoomAvatar: 50,
+ EventTypes.Tombstone: 100,
+ EventTypes.ServerACL: 100,
+ EventTypes.RoomEncryption: 100,
+ },
+ "events_default": 0,
+ "state_default": 100, # Admins should be the only ones to perform other tasks
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 50, # All rooms should require mod to invite, even private
+ }
+
+ async def check_threepid_can_be_invited(
+ self, medium: str, address: str, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited.
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+
+ Args:
+ medium: The medium of the threepid.
+ address: The address of the threepid.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room the threepid is being invited to.
+
+ Returns:
+ Whether the threepid invite is allowed.
+ """
+ rule = self._get_rule_from_state(state_events)
+
+ if medium != "email":
+ return False
+
+ if rule != AccessRules.RESTRICTED:
+ # Only "restricted" requires filtering 3PID invites. We don't need to do
+ # anything for "direct" here, because only "restricted" requires filtering
+ # based on the HS the address is mapped to.
+ return True
+
+ parsed_address = email.utils.parseaddr(address)[1]
+ if parsed_address != address:
+ # Avoid reproducing the security issue described here:
+ # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2
+ # It's probably not worth it but let's just be overly safe here.
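+ # For example, "Alice <alice@example.com>" parses to "alice@example.com",
+ # which differs from the raw input, so it would be rejected here.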
+ return False
+
+ # Get the HS this address belongs to from the identity server.
+ res = await self.module_api.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
+ {"medium": medium, "address": address},
+ )
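+ # The response is expected to name the server the address maps to, e.g.
+ # {"hs": "example.com"} (the exact fields depend on the identity server).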
+
+ # Look for a domain that's not forbidden from being invited.
+ if not res.get("hs"):
+ return False
+ if res.get("hs") in self.domains_forbidden_when_restricted:
+ return False
+
+ return True
+
+ async def check_event_allowed(
+ self, event: EventBase, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed.
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+
+ Args:
+ event: The event to check.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room the event originated from.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ if event.type == ACCESS_RULES_TYPE:
+ return await self._on_rules_change(event, state_events)
+
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ if event.type == EventTypes.PowerLevels:
+ return self._is_power_level_content_allowed(
+ event.content, rule, on_room_creation=False
+ )
+
+ if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
+ return await self._on_membership_or_invite(event, rule, state_events)
+
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
+ if event.type == EventTypes.RoomAvatar:
+ return self._on_room_avatar_change(event, rule)
+
+ if event.type == EventTypes.Name:
+ return self._on_room_name_change(event, rule)
+
+ if event.type == EventTypes.Topic:
+ return self._on_room_topic_change(event, rule)
+
+ return True
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, state_events: StateMap[EventBase], new_visibility: str
+ ) -> bool:
+ """Implements
+ synapse.events.ThirdPartyEventRules.check_visibility_can_be_modified
+
+ Determines whether a room can be published to, or removed from, the public
+ room list. A room is published if its visibility is set to "public".
+ Otherwise, its visibility is "private". A room with an access rule other
+ than "restricted" may not be published.
+
+ Args:
+ room_id: The ID of the room.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ Whether the room is allowed to be published to, or removed from, the public
+ rooms directory.
+ """
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ # Allow adding a room to the public rooms list only if it is restricted
+ if new_visibility == "public":
+ return rule == AccessRules.RESTRICTED
+
+ # By default a room is created as "restricted", meaning it is allowed to be
+ # published to the public rooms directory.
+ return True
+
+ async def _on_rules_change(
+ self, event: EventBase, state_events: StateMap[EventBase]
+ ) -> bool:
+ """Checks whether an im.vector.room.access_rules event is forbidden or allowed.
+
+ Args:
+ event: The im.vector.room.access_rules event.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room before the event was sent.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ new_rule = event.content.get("rule")
+
+ # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES:
+ return False
+
+ # Make sure we don't apply "direct" if the room has more than two members.
+ if new_rule == AccessRules.DIRECT:
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ if len(existing_members) > 2 or len(threepid_tokens) > 1:
+ return False
+
+ if new_rule != AccessRules.RESTRICTED:
+ # Block this change if this room is currently listed in the public rooms
+ # directory
+ if await self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ event.room_id
+ ):
+ return False
+
+ prev_rules_event = state_events.get((ACCESS_RULES_TYPE, ""))
+
+ # Now that we know the new rule doesn't break the "direct" case, we can allow any
+ # new rule in rooms that had none before.
+ if prev_rules_event is None:
+ return True
+
+ prev_rule = prev_rules_event.content.get("rule")
+
+ # Currently, we can only go from "restricted" to "unrestricted".
+ return (
+ prev_rule == AccessRules.RESTRICTED and new_rule == AccessRules.UNRESTRICTED
+ )
+
+ async def _on_membership_or_invite(
+ self, event: EventBase, rule: str, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Applies the correct rule for incoming m.room.member and
+ m.room.third_party_invite events.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+ state_events: A dict mapping (event type, state key) to state event.
+ The state of the room before the event was sent.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ if rule == AccessRules.RESTRICTED:
+ ret = self._on_membership_or_invite_restricted(event)
+ elif rule == AccessRules.UNRESTRICTED:
+ ret = self._on_membership_or_invite_unrestricted(event, state_events)
+ elif rule == AccessRules.DIRECT:
+ ret = self._on_membership_or_invite_direct(event, state_events)
+ else:
+ # We currently apply the default (restricted) if we don't know the rule; we
+ # might want to change that in the future.
+ ret = self._on_membership_or_invite_restricted(event)
+
+ if event.type == "m.room.member":
+ # If this is an admin leaving, and they are the last admin in the room,
+ # raise the power levels of the room so that the room is 'frozen'.
+ #
+ # We have to freeze the room by puppeting an admin user, which we can
+ # only do for local users
+ if (
+ self.freeze_room_with_no_admin
+ and self._is_local_user(event.sender)
+ and event.membership == Membership.LEAVE
+ ):
+ await self._freeze_room_if_last_admin_is_leaving(event, state_events)
+
+ return ret
+
+ async def _freeze_room_if_last_admin_is_leaving(
+ self, event: EventBase, state_events: StateMap[EventBase]
+ ):
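+ """Freezes the room by raising the power level required to send events to
+ 100 when the last admin in the room leaves, so that no further activity
+ can take place in it.
+
+ Args:
+ event: The m.room.member leave event that triggered this check.
+ state_events: A dict mapping (event type, state key) to state event.
+ The state of the room before the event was sent.
+ """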
+ power_level_state_event = state_events.get(
+ (EventTypes.PowerLevels, "")
+ ) # type: Optional[EventBase]
+ if not power_level_state_event:
+ return
+ power_level_content = power_level_state_event.content
+
+ # Do some validation checks on the power level state event
+ if (
+ not isinstance(power_level_content, dict)
+ or "users" not in power_level_content
+ or not isinstance(power_level_content["users"], dict)
+ ):
+ # We can't use this power level event to determine whether the room should be
+ # frozen. Bail out.
+ return
+
+ user_id = event.get("sender")
+ if not user_id:
+ return
+
+ # Get every admin user defined in the room's state
+ admin_users = {
+ user
+ for user, power_level in power_level_content["users"].items()
+ if power_level >= 100
+ }
+
+ if user_id not in admin_users:
+ # This user is not an admin, ignore them
+ return
+
+ if any(
+ event_type == EventTypes.Member
+ and member_event.membership in [Membership.JOIN, Membership.INVITE]
+ and state_key in admin_users
+ and state_key != user_id
+ for (event_type, state_key), member_event in state_events.items()
+ ):
+ # There's another admin user in, or invited to, the room
+ return
+
+ # Freeze the room by raising the required power level to send events to 100
+ logger.info("Freezing room '%s'", event.room_id)
+
+ # Modify the existing power levels to raise all required types to 100
+ #
+ # This changes a power level state event's content from something like:
+ # {
+ # "redact": 50,
+ # "state_default": 50,
+ # "ban": 50,
+ # "notifications": {
+ # "room": 50
+ # },
+ # "events": {
+ # "m.room.avatar": 50,
+ # "m.room.encryption": 50,
+ # "m.room.canonical_alias": 50,
+ # "m.room.name": 50,
+ # "im.vector.modular.widgets": 50,
+ # "m.room.topic": 50,
+ # "m.room.tombstone": 50,
+ # "m.room.history_visibility": 100,
+ # "m.room.power_levels": 100
+ # },
+ # "users_default": 0,
+ # "events_default": 0,
+ # "users": {
+ # "@admin:example.com": 100,
+ # },
+ # "kick": 50,
+ # "invite": 0
+ # }
+ #
+ # to
+ #
+ # {
+ # "redact": 100,
+ # "state_default": 100,
+ # "ban": 100,
+ # "notifications": {
+ # "room": 50
+ # },
+ # "events": {}
+ # "users_default": 0,
+ # "events_default": 100,
+ # "users": {
+ # "@admin:example.com": 100,
+ # },
+ # "kick": 100,
+ # "invite": 100
+ # }
+ new_content = {}
+ for key, value in power_level_content.items():
+ # Do not change "users_default", as that key specifies the default power
+ # level of new users
+ if isinstance(value, int) and key != "users_default":
+ value = 100
+ new_content[key] = value
+
+ # Set some values in case they are missing from the original
+ # power levels event content
+ new_content.update(
+ {
+ # Clear out any special-cased event keys
+ "events": {},
+ # Ensure state_default and events_default keys exist and are 100.
+ # Otherwise a lower PL user could potentially send state events that
+ # aren't explicitly mentioned elsewhere in the power level dict
+ "state_default": 100,
+ "events_default": 100,
+ # Membership events default to 50 if they aren't present. Set them
+ # to 100 here, as they would be set to 100 if they were present anyways
+ "ban": 100,
+ "kick": 100,
+ "invite": 100,
+ "redact": 100,
+ }
+ )
+
+ await self.module_api.create_and_send_event_into_room(
+ {
+ "room_id": event.room_id,
+ "sender": user_id,
+ "type": EventTypes.PowerLevels,
+ "content": new_content,
+ "state_key": "",
+ }
+ )
+
+ def _on_membership_or_invite_restricted(self, event: EventBase) -> bool:
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ "restricted" currently means that users can only invite users if their server is
+ included in a limited list of domains.
+
+ Args:
+ event: The event to check.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ # We're not applying the rules on m.room.third_party_invite events here because
+ # the filtering on threepids is done in check_threepid_can_be_invited, which is
+ # called before check_event_allowed.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return True
+
+ # We only need to process "join" and "invite" memberships, in order to be backward
+ # compatible, e.g. if a user from a blacklisted server joined a restricted room
+ # before the rules started being enforced on the server, that user must be able to
+ # leave it.
+ if event.membership not in [Membership.JOIN, Membership.INVITE]:
+ return True
+
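+ # e.g. with domains_forbidden_when_restricted = ["evil.example"]
+ # (hypothetical config), a join or invite targeting @x:evil.example is
+ # denied, while one targeting @x:example.com is allowed.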
+ invitee_domain = get_domain_from_id(event.state_key)
+ return invitee_domain not in self.domains_forbidden_when_restricted
+
+ def _on_membership_or_invite_unrestricted(
+ self, event: EventBase, state_events: StateMap[EventBase]
+ ) -> bool:
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ "unrestricted" currently means that forbidden users cannot join without an invite.
+
+ Args:
+ event: The event to check.
+ state_events: A dict mapping (event type, state key) to state event.
+ The state of the room before the event was sent.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ # If this is a join from a forbidden user and they don't have an invite to the
+ # room, then deny it
+ if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+ # Check if this user is from a forbidden server
+ target_domain = get_domain_from_id(event.state_key)
+ if target_domain in self.domains_forbidden_when_restricted:
+ # If so, they'll need an invite to join this room. Check if one exists
+ if not self._user_is_invited_to_room(event.state_key, state_events):
+ return False
+
+ return True
+
+ def _on_membership_or_invite_direct(
+ self, event: EventBase, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ "direct" currently means that no member is allowed apart from the two initial
+ members the room was created for (i.e. the room's creator and their first invitee).
+
+ Args:
+ event: The event to check.
+ state_events: A dict mapping (event type, state key) to state event.
+ The state of the room before the event was sent.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ # Get the room memberships and 3PID invite tokens from the room's state.
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ # There should never be more than one 3PID invite in the room state: if the
+ # second original user came and left, and we're inviting them again using
+ # their email address, then, since we know they have a Matrix account bound
+ # to the address (they were able to join the first time), Synapse will
+ # successfully look it up before attempting to store an invite on the IS.
+ if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite:
+ # If we already have a 3PID invite in flight, don't accept another one, unless
+ # the new one has the same invite token as its state key. This is because 3PID
+ # invite revocations must be allowed, and a revocation is basically a new 3PID
+ # invite event with empty content and the same token as the invite it
+ # revokes.
+ return event.state_key in threepid_tokens
+
+ if len(existing_members) == 2:
+ # If the target were one of the two initial users of the room, Synapse would
+ # have looked them up successfully and thus sent an m.room.member event here
+ # instead of an m.room.third_party_invite.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return False
+
+ # We can only have m.room.member events here. The rule in this case is to only
+ # allow the event if its target is one of the initial two members in the room,
+ # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members
+
+ # We're alone in the room (and always have been) and there's one 3PID invite in
+ # flight.
+ if len(existing_members) == 1 and len(threepid_tokens) == 1:
+ # We can only have m.room.member events here. In this case, we can only allow
+ # the event if it's either an m.room.member event from the joined user (we can
+ # assume the only m.room.member event in the state is a join, otherwise we
+ # wouldn't be able to send an event to the room) or an invite event whose
+ # target is the invited user.
+ target = event.state_key
+ is_from_threepid_invite = self._is_invite_from_threepid(
+ event, threepid_tokens[0]
+ )
+ return is_from_threepid_invite or target == existing_members[0]
+
+ return True
+
+ def _is_power_level_content_allowed(
+ self,
+ content: Dict,
+ access_rule: str,
+ default_power_levels: Optional[Dict] = None,
+ on_room_creation: bool = True,
+ ) -> bool:
+ """Check if a given power levels event is permitted under the given access rule.
+
+ It shouldn't be allowed if it either changes the default PL to a non-0 value or
+ gives a non-0 PL to a user that would have been forbidden from joining the room
+ under a more restrictive access rule.
+
+ Args:
+ content: The content of the m.room.power_levels event to check.
+ access_rule: The access rule in place in this room.
+ default_power_levels: The default power levels when a room is created with
+ the specified access rule. Required if on_room_creation is True.
+ on_room_creation: True if this call is happening during a room's
+ creation, False otherwise.
+
+ Returns:
+ Whether the content of the power levels event is valid.
+ """
+ # Only enforce these rules during room creation
+ #
+ # We want to allow admins to modify or fix the power levels in a room if they
+ # have a special circumstance, but still want to encourage a certain pattern during
+ # room creation.
+ if on_room_creation:
+ # We specifically don't fail if "invite" or "state_default" are None, as those
+ # values should be replaced with our "default" power level values anyways,
+ # which are compliant
+
+ invite = default_power_levels["invite"]
+ state_default = default_power_levels["state_default"]
+
+ # If invite requirements are less than our required defaults
+ if content.get("invite", invite) < invite:
+ return False
+
+ # If "other" state requirements are less than our required defaults
+ if content.get("state_default", state_default) < state_default:
+ return False
+
+ # Check if we need to apply the restrictions with the current rule.
+ if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
+ return True
+
+ # If users_default is explicitly set to a non-0 value, deny the event.
+ users_default = content.get("users_default", 0)
+ if users_default:
+ return False
+
+ users = content.get("users", {})
+ for user_id, power_level in users.items():
+ server_name = get_domain_from_id(user_id)
+ # Check the domain against the blacklist. If found, and the PL isn't 0, deny
+ # the event.
+ if (
+ server_name in self.domains_forbidden_when_restricted
+ and power_level != 0
+ ):
+ return False
+
+ return True
+
+ def _on_join_rule_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a join rule change is allowed.
+
+ A join rule change is always allowed unless the new join rule is "public" and
+ the current access rule is "direct".
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ Whether the change is allowed.
+ """
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
+ return rule != AccessRules.DIRECT
+
+ return True
+
+ def _on_room_avatar_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room avatar is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ def _on_room_name_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room name is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ def _on_room_topic_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room topic is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ @staticmethod
+ def _get_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]:
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events: A dict mapping (event type, state key) to state event.
+
+ Returns:
+ The name of the rule (either "direct", "restricted" or "unrestricted") if found,
+ else None.
+ """
+ access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
+ if access_rules is None:
+ return AccessRules.RESTRICTED
+
+ return access_rules.content.get("rule")
+
+ @staticmethod
+ def _get_join_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]:
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+ state_events: A dict mapping (event type, state key) to state event.
+
+ Returns:
+ The name of the join rule (either "public", or "invite") if found, else None.
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return None
+
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
+ def _get_members_and_tokens_from_state(
+ state_events: StateMap[EventBase],
+ ) -> Tuple[List[str], List[str]]:
+ """Retrieves the list of users that have a m.room.member event in the room,
+ as well as 3PID invites tokens in the room.
+
+ Args:
+ state_events: A dict mapping (event type, state key) to state event.
+
+ Returns:
+ A tuple containing the:
+ * targets of the m.room.member events in the state.
+ * 3PID invite tokens in the state.
+ """
+ existing_members = []
+ threepid_invite_tokens = []
+ for key, state_event in state_events.items():
+ if key[0] == EventTypes.Member and state_event.content:
+ existing_members.append(state_event.state_key)
+ if key[0] == EventTypes.ThirdPartyInvite and state_event.content:
+ # Don't include revoked invites.
+ threepid_invite_tokens.append(state_event.state_key)
+
+ return existing_members, threepid_invite_tokens
+
+ @staticmethod
+ def _is_invite_from_threepid(invite: EventBase, threepid_invite_token: str) -> bool:
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite: The m.room.member event with "invite" membership.
+ threepid_invite_token: The state key from the 3PID invite.
+
+ Returns:
+ Whether the invite is due to the given 3PID invite.
+ """
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
+
+ return token == threepid_invite_token
+
+ def _is_local_user(self, user_id: str) -> bool:
+ """Checks whether a given user ID belongs to this homeserver, or a remote
+
+ Args:
+ user_id: A user ID to check.
+
+ Returns:
+ True if the user belongs to this homeserver, False otherwise.
+ """
+ user = UserID.from_string(user_id)
+
+ # Extract the localpart and ask the module API for a user ID from the localpart
+ # The module API will append the local homeserver's server_name
+ local_user_id = self.module_api.get_qualified_user_id(user.localpart)
+
+ # If the user ID we get based on the localpart is the same as the original user ID,
+ # then they were a local user
+ return user_id == local_user_id
+
+ def _user_is_invited_to_room(
+ self, user_id: str, state_events: StateMap[EventBase]
+ ) -> bool:
+ """Checks whether a given user has been invited to a room
+
+ A user has an invite to a room if the room's state contains an `m.room.member`
+ event with membership "invite" and the user's ID as the state key.
+
+ Args:
+ user_id: The user to check.
+ state_events: The state events from the room.
+
+ Returns:
+ True if the user has been invited to the room, or False if they haven't.
+ """
+ for (event_type, state_key), state_event in state_events.items():
+ if (
+ event_type == EventTypes.Member
+ and state_key == user_id
+ and state_event.membership == Membership.INVITE
+ ):
+ return True
+
+ return False
diff --git a/synapse/types.py b/synapse/types.py
index eafe729dfe..9629b26b01 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import (
import attr
from signedjson.key import decode_verify_key_bytes
+from six.moves import filter
from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
@@ -335,6 +336,19 @@ def contains_invalid_mxid_characters(localpart: str) -> bool:
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+def strip_invalid_mxid_characters(localpart: str) -> str:
+ """Removes any invalid characters from an mxid localpart.
+
+ For example (illustrative): "bob jones!" becomes "bobjones", since spaces
+ and "!" are not in the allowed localpart character set.
+
+ Args:
+ localpart: The localpart to be stripped.
+
+ Returns:
+ The localpart with any invalid characters removed.
+ """
+ filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
+ return "".join(filtered)
+
+
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 1ee61851e4..09b094ded7 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -49,7 +49,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
module = importlib.import_module(module)
provider_class = getattr(module, clz)
- module_config = provider.get("config")
+ # Load the module config. If None, pass an empty dictionary instead
+ module_config = provider.get("config") or {}
try:
provider_config = provider_class.parse_config(module_config)
except jsonschema.ValidationError as e:
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 43c2e0ac23..63f955acff 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -19,8 +19,8 @@ import re
logger = logging.getLogger(__name__)
-def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+async def check_3pid_allowed(hs, medium, address):
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -31,6 +31,33 @@ def check_3pid_allowed(hs, medium, address):
bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = await hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
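+ # The response is expected to include some of the keys "hs", "shadow_hs",
+ # "requires_invite" and "invited", e.g. {"hs": "example.com",
+ # "requires_invite": True, "invited": False}.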
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ return False
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ return False
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ return False
+
+ return True
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
diff --git a/sytest-blacklist b/sytest-blacklist
index de9986357b..9b7161fec1 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -41,3 +41,29 @@ We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility
Local users can peek by room alias
Peeked rooms only turn up in the sync for the device who peeked them
+
+# Blacklisted due to https://github.com/matrix-org/synapse-dinsic/issues/43
+Inviting an AS-hosted user asks the AS server
+Accesing an AS-hosted room alias asks the AS server
+Events in rooms with AS-hosted room aliases are sent to AS server
+
+# flaky test
+If remote user leaves room we no longer receive device updates
+
+# flaky test
+Can re-join room if re-invited
+
+# flaky test
+Forgotten room messages cannot be paginated
+
+# flaky test
+Local device key changes get to remote servers
+
+# flaky test
+Old leaves are present in gapped incremental syncs
+
+# flaky test on workers
+Old members are included in gappy incr LL sync if they start speaking
+
+# flaky test on workers
+Presence changes to UNAVAILABLE are reported to remote room members
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
new file mode 100644
index 0000000000..5b928480e7
--- /dev/null
+++ b/tests/federation/transport/test_knocking.py
@@ -0,0 +1,288 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Federation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+from typing import Dict, List
+
+from twisted.internet.defer import succeed
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import builder
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.server import HomeServer
+from synapse.types import RoomAlias
+
+from tests.test_utils import event_injection
+from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+
+# An identifier to use while MSC2403 is not in a stable release of the spec
+KNOCK_UNSTABLE_IDENTIFIER = "xyz.amorgan.knock"
+
+
+class KnockingStrippedStateEventHelperMixin(TestCase):
+ def send_example_state_events_to_room(
+ self, hs: "HomeServer", room_id: str, sender: str,
+ ) -> OrderedDict:
+ """Adds some state to a room. State events are those that should be sent to a knocking
+ user after they knock on the room, as well as some state that *shouldn't* be sent
+ to the knocking user.
+
+ Args:
+ hs: The homeserver of the sender.
+ room_id: The ID of the room to send state into.
+ sender: The ID of the user to send state as. Must be in the room.
+
+ Returns:
+ The OrderedDict of event types and content that a user is expected to see
+ after knocking on a room.
+ """
+ # To set a canonical alias, we'll need to point an alias at the room first.
+ canonical_alias = "#fancy_alias:test"
+ self.get_success(
+ self.store.create_room_alias_association(
+ RoomAlias.from_string(canonical_alias), room_id, ["test"]
+ )
+ )
+
+ # Send some state that we *don't* expect to be given to knocking users
+ self.get_success(
+ event_injection.inject_event(
+ hs,
+ room_version=RoomVersions.V7.identifier,
+ room_id=room_id,
+ sender=sender,
+ type="com.example.secret",
+ state_key="",
+ content={"secret": "password"},
+ )
+ )
+
+ # We use an OrderedDict here to ensure that the knock membership appears last.
+ # Note that order only matters when sending stripped state to clients, not federated
+ # homeservers.
+ room_state = OrderedDict(
+ [
+ # We need to set the room's join rules to allow knocking
+ (
+ EventTypes.JoinRules,
+ {"content": {"join_rule": JoinRules.KNOCK}, "state_key": ""},
+ ),
+ # Below are state events that are to be stripped and sent to clients
+ (
+ EventTypes.Name,
+ {"content": {"name": "A cool room"}, "state_key": ""},
+ ),
+ (
+ EventTypes.RoomAvatar,
+ {
+ "content": {
+ "info": {
+ "h": 398,
+ "mimetype": "image/jpeg",
+ "size": 31037,
+ "w": 394,
+ },
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ },
+ "state_key": "",
+ },
+ ),
+ (
+ EventTypes.RoomEncryption,
+ {"content": {"algorithm": "m.megolm.v1.aes-sha2"}, "state_key": ""},
+ ),
+ (
+ EventTypes.CanonicalAlias,
+ {
+ "content": {"alias": canonical_alias, "alt_aliases": []},
+ "state_key": "",
+ },
+ ),
+ ]
+ )
+
+ for event_type, event_dict in room_state.items():
+ event_content = event_dict["content"]
+ state_key = event_dict["state_key"]
+
+ self.get_success(
+ event_injection.inject_event(
+ hs,
+ room_version=RoomVersions.V7.identifier,
+ room_id=room_id,
+ sender=sender,
+ type=event_type,
+ state_key=state_key,
+ content=event_content,
+ )
+ )
+
+ return room_state
+
+ def check_knock_room_state_against_room_state(
+ self, knock_room_state: List[Dict], expected_room_state: Dict,
+ ) -> None:
+ """Test a list of stripped room state events received over federation against a
+ dict of expected state events.
+
+ Args:
+ knock_room_state: The list of room state that was received over federation.
+ expected_room_state: A dict containing the room state we expect to see in
+ `knock_room_state`.
+ """
+ for event in knock_room_state:
+ event_type = event["type"]
+
+ # Check that this event type is one of those that we expected.
+ # Note: This will also check that no excess state was included
+ self.assertIn(event_type, expected_room_state)
+
+ # Check the state content matches
+ self.assertEquals(
+ expected_room_state[event_type]["content"], event["content"]
+ )
+
+ # Check the state key is correct
+ self.assertEqual(
+ expected_room_state[event_type]["state_key"], event["state_key"]
+ )
+
+ # Ensure the event has been stripped
+ self.assertNotIn("signatures", event)
+
+ # Pop once we've found and processed a state event
+ expected_room_state.pop(event_type)
+
+ # Check that all expected state events were accounted for
+ self.assertEqual(len(expected_room_state), 0)
+
+
+class FederationKnockingTestCase(
+ FederatingHomeserverTestCase, KnockingStrippedStateEventHelperMixin
+):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+
+ # We're not going to be properly signing events as our remote homeserver is fake,
+ # therefore disable event signature checks.
+ # Note that these checks are not relevant to this test case.
+
+ # Have this homeserver auto-approve all event signature checking.
+ def approve_all_signature_checking(_, ev):
+ return [succeed(ev[0])]
+
+ homeserver.get_federation_server()._check_sigs_and_hashes = (
+ approve_all_signature_checking
+ )
+
+ # Have this homeserver skip event auth checks. This is necessary due to
+ # event auth checks ensuring that events were signed by the sender's homeserver.
+ async def do_auth(origin, event, context, auth_events):
+ return context
+
+ homeserver.get_federation_handler().do_auth = do_auth
+
+ return super().prepare(reactor, clock, homeserver)
+
+ @override_config({"experimental_features": {"msc2403_enabled": True}})
+ def test_room_state_returned_when_knocking(self):
+ """
+ Tests that specific, stripped state events from a room are returned after
+ a remote homeserver successfully knocks on a local room.
+ """
+ user_id = self.register_user("u1", "you the one")
+ user_token = self.login("u1", "you the one")
+
+ fake_knocking_user_id = "@user:other.example.com"
+
+ # Create a room with a room version that includes knocking
+ room_id = self.helper.create_room_as(
+ "u1",
+ is_public=False,
+ room_version=RoomVersions.V7.identifier,
+ tok=user_token,
+ )
+
+ # Update the join rules and add additional state to the room to check for later
+ expected_room_state = self.send_example_state_events_to_room(
+ self.hs, room_id, user_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/federation/unstable/%s/make_knock/%s/%s?ver=%s"
+ % (
+ KNOCK_UNSTABLE_IDENTIFIER,
+ room_id,
+ fake_knocking_user_id,
+ # Inform the homeserver we're knocking on that the knocking server
+ # supports the version of the room being knocked on
+ RoomVersions.V7.identifier,
+ ),
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Note: We don't expect the knock membership event to be sent over federation as
+ # part of the stripped room state, as the knocking homeserver already has that
+ # event. It is only done for clients during /sync
+
+ # Extract the generated knock event json
+ knock_event = channel.json_body["event"]
+
+ # Check that the event has things we expect in it
+ self.assertEquals(knock_event["room_id"], room_id)
+ self.assertEquals(knock_event["sender"], fake_knocking_user_id)
+ self.assertEquals(knock_event["state_key"], fake_knocking_user_id)
+ self.assertEquals(knock_event["type"], EventTypes.Member)
+ self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK)
+
+ # Turn the event json dict into a proper event.
+ # We won't sign it properly, but that's OK as we stub out event auth in `prepare`
+ signed_knock_event = builder.create_local_event_from_event_dict(
+ self.clock,
+ self.hs.hostname,
+ self.hs.signing_key,
+ room_version=RoomVersions.V7,
+ event_dict=knock_event,
+ )
+
+ # Convert our proper event back to json dict format
+ signed_knock_event_json = signed_knock_event.get_pdu_json(
+ self.clock.time_msec()
+ )
+
+ # Send the signed knock event into the room
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/federation/unstable/%s/send_knock/%s/%s"
+ % (KNOCK_UNSTABLE_IDENTIFIER, room_id, signed_knock_event.event_id),
+ signed_knock_event_json,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check that we got the stripped room state in return
+ room_state_events = channel.json_body["knock_state_events"]
+
+ # Validate the stripped room state events
+ self.check_knock_room_state_against_room_state(
+ room_state_events, expected_room_state
+ )
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index a39f898608..ebc6a0866a 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,6 +42,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.mock_registry.register_query_handler = register_query_handler
hs = self.setup_test_homeserver(
+ federation_http_client=None,
+ resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 0b24b89a2e..74503112f5 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
from unittest import TestCase
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
@@ -191,6 +191,97 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invite_by_room_ratelimit(self):
+ """Tests that invites from federation in a room are actually rate-limited.
+ """
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ def create_invite_for(local_user):
+ return event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": other_user,
+ "state_key": local_user,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ for i in range(3):
+ self.get_success(
+ self.handler.on_invite_request(
+ other_server,
+ create_invite_for("@user-%d:test" % (i,)),
+ room_version,
+ )
+ )
+
+ self.get_failure(
+ self.handler.on_invite_request(
+ other_server, create_invite_for("@user-4:test"), room_version,
+ ),
+ exc=LimitExceededError,
+ )
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invite_by_user_ratelimit(self):
+ """Tests that invites from federation to a particular user are
+ actually rate-limited.
+ """
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ def create_invite():
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+ return event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": other_user,
+ "state_key": "@user:test",
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ for i in range(3):
+ event = create_invite()
+ self.get_success(
+ self.handler.on_invite_request(other_server, event, event.room_version,)
+ )
+
+ event = create_invite()
+ self.get_failure(
+ self.handler.on_invite_request(other_server, event, event.room_version,),
+ exc=LimitExceededError,
+ )
+
def _build_and_send_join_event(self, other_server, other_user, room_id):
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..b7d340bcb8
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.address = "test@test"
+ self.is_server_name = "testis"
+ self.is_server_url = "https://testis"
+ self.rewritten_is_url = "https://int.testis"
+
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
+ config["rewrite_identity_server_urls"] = {
+ self.is_server_url: self.rewritten_is_url
+ }
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed({})
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_blacklisting_http_client.get_json.return_value = defer.succeed({})
+ mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ # TODO: This class does not use a singleton to get its http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().blacklisting_http_client = (
+ mock_blacklisting_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+
+ def test_rewritten_id_server(self):
+ """
+ Tests that, when validating a 3PID association while rewriting the IS's server
+ name:
+ * the bind request is done against the rewritten hostname
+ * the original, non-rewritten, server name is stored in the database
+ """
+ handler = self.hs.get_identity_handler()
+ post_json_get_json = handler.blacklisting_http_client.post_json_get_json
+ store = self.hs.get_datastore()
+
+ creds = {"sid": "123", "client_secret": "some_secret"}
+
+ # Make sure processing the mocked response goes through.
+ data = self.get_success(
+ handler.bind_threepid(
+ client_secret=creds["client_secret"],
+ sid=creds["sid"],
+ mxid=self.user_id,
+ id_server=self.is_server_name,
+ use_v2=False,
+ )
+ )
+ self.assertEqual(data.get("address"), self.address)
+
+ # Check that the request was done against the rewritten server name.
+ post_json_get_json.assert_called_once_with(
+ "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,),
+ {
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ headers={},
+ )
+
+ # Check that the original server name is saved in the database instead of the
+ # rewritten one.
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
+ self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 022943a10a..75275f0e4f 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -63,7 +63,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
displayname = yield defer.ensureDeferred(
@@ -126,7 +126,7 @@ class ProfileTestCase(unittest.TestCase):
# Setting displayname for the first time is allowed
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
self.assertEquals(
@@ -179,7 +179,7 @@ class ProfileTestCase(unittest.TestCase):
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield defer.ensureDeferred(
- self.store.set_profile_displayname("caroline", "Caroline")
+ self.store.set_profile_displayname("caroline", "Caroline", 1)
)
response = yield defer.ensureDeferred(
@@ -194,7 +194,7 @@ class ProfileTestCase(unittest.TestCase):
def test_get_my_avatar(self):
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -260,7 +260,7 @@ class ProfileTestCase(unittest.TestCase):
# Setting displayname for the first time is allowed
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..aaf5b92f41 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,9 +18,14 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
+from synapse.rest.client.v2_alpha.register import (
+ _map_email_to_displayname,
+ register_servlets,
+)
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
+from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
@@ -31,6 +36,10 @@ from .. import unittest
class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
+ servlets = [
+ register_servlets,
+ ]
+
def make_homeserver(self, reactor, clock):
hs_config = self.default_config()
@@ -517,6 +526,103 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ def test_email_to_displayname_mapping(self):
+ """Test that custom emails are mapped to new user displaynames correctly"""
+ self._check_mapping(
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+ )
+
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+ # Multibyte unicode characters
+ self._check_mapping(
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
+ )
+
+ def _check_mapping(self, i, expected):
+ result = _map_email_to_displayname(i)
+ self.assertEqual(result, expected)
+
+ @override_config(
+ {
+ "bind_new_user_emails_to_sydent": "https://is.example.com",
+ "registrations_require_3pid": ["email"],
+ "account_threepid_delegates": {},
+ "email": {
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "http://localhost",
+ }
+ )
+ def test_user_email_bound_via_sydent_internal_api(self):
+ """Tests that emails are bound after registration if this option is set"""
+ # Register user with an email address
+ email = "alice@example.com"
+
+ # Mock Synapse's threepid validator
+ get_threepid_validation_session = Mock(
+ return_value=make_awaitable(
+ {"medium": "email", "address": email, "validated_at": 0}
+ )
+ )
+ self.store.get_threepid_validation_session = get_threepid_validation_session
+ delete_threepid_session = Mock(return_value=make_awaitable(None))
+ self.store.delete_threepid_session = delete_threepid_session
+
+ # Mock Synapse's http json post method to check for the internal bind call
+ post_json_get_json = Mock(return_value=make_awaitable(None))
+ self.hs.get_identity_handler().http_client.post_json_get_json = (
+ post_json_get_json
+ )
+
+ # Retrieve a UIA session ID
+ channel = self.uia_register(
+ 401, {"username": "alice", "password": "nobodywillguessthis"}
+ )
+ session_id = channel.json_body["session"]
+
+ # Register our email address using the fake validation session above
+ channel = self.uia_register(
+ 200,
+ {
+ "username": "alice",
+ "password": "nobodywillguessthis",
+ "auth": {
+ "session": session_id,
+ "type": "m.login.email.identity",
+ "threepid_creds": {"sid": "blabla", "client_secret": "blablabla"},
+ },
+ },
+ )
+ self.assertEqual(channel.json_body["user_id"], "@alice:test")
+
+ # Check that a bind attempt was made to our fake identity server
+ post_json_get_json.assert_called_with(
+ "https://is.example.com/_matrix/identity/internal/bind",
+ {"address": "alice@example.com", "medium": "email", "mxid": "@alice:test"},
+ )
+
+ # Check that we stored a mapping of this bind
+ bound_threepids = self.get_success(
+ self.store.user_get_bound_threepids("@alice:test")
+ )
+ self.assertListEqual(bound_threepids, [{"medium": "email", "address": email}])
+
+ def uia_register(self, expected_response: int, body: dict) -> FakeChannel:
+ """Make a register request."""
+ channel = self.make_request("POST", "register", body)
+
+ self.assertEqual(channel.code, expected_response)
+ return channel
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 312c0a0d41..0229f58315 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -21,8 +21,14 @@ from tests import unittest
# The expected number of state events in a fresh public room.
EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+
# The expected number of state events in a fresh private room.
-EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
+#
+# Note: this is higher on the dinsic branch, as we send an
+# "im.vector.room.access_rules" state event into new private rooms, and an
+# encryption state event, since all private rooms are encrypted by default.
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
class StatsRoomTests(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9c886d671a..2afd1970e6 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -18,8 +18,9 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
+from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client.v2_alpha import account, account_validity, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
@@ -46,6 +47,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
+ self.event_builder_factory = self.hs.get_event_builder_factory()
+ self.event_creation_handler = self.hs.get_event_creation_handler()
def test_handle_local_profile_change_with_support_user(self):
support_user_id = "@support:test"
@@ -541,6 +544,97 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, u4, 10))
self.assertEqual(len(s["results"]), 1)
+ @override_config(
+ {
+ "user_directory": {
+ "enabled": True,
+ "search_all_users": True,
+ "prefer_local_users": True,
+ }
+ }
+ )
+ def test_prefer_local_users(self):
+ """Tests that local users are shown higher in search results when
+ user_directory.prefer_local_users is True.
+ """
+ # Create a room and a few users to test the directory with
+ searching_user = self.register_user("searcher", "password")
+ searching_user_tok = self.login("searcher", "password")
+
+ room_id = self.helper.create_room_as(
+ searching_user,
+ room_version=RoomVersions.V1.identifier,
+ tok=searching_user_tok,
+ )
+
+ # Create a few local users and join them to the room
+ local_user_1 = self.register_user("user_xxxxx", "password")
+ local_user_2 = self.register_user("user_bbbbb", "password")
+ local_user_3 = self.register_user("user_zzzzz", "password")
+
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_3)
+
+ # Create a few "remote" users and join them to the room
+ remote_user_1 = "@user_aaaaa:remote_server"
+ remote_user_2 = "@user_yyyyy:remote_server"
+ remote_user_3 = "@user_ccccc:remote_server"
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3)
+
+ local_users = [local_user_1, local_user_2, local_user_3]
+ remote_users = [remote_user_1, remote_user_2, remote_user_3]
+
+ # Populate the user directory via background update
+ self._add_background_updates()
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # The local searching user searches for the term "user", which other users have
+ # in their user id
+ results = self.get_success(
+ self.handler.search_users(searching_user, "user", 20)
+ )["results"]
+ received_user_id_ordering = [result["user_id"] for result in results]
+
+ # Typically we'd expect Synapse to return users in lexicographical order,
+ # assuming they have similar User IDs/display names and profile information.
+
+ # Check that the order of returned results using our module is as we expect,
+ # i.e. our local users show up first, despite all users having
+ # lexicographically mixed user IDs.
+ for user in received_user_id_ordering[:3]:
+ self.assertIn(user, local_users)
+ for user in received_user_id_ordering[3:]:
+ self.assertIn(user, remote_users)
+
+ def _add_user_to_room(
+ self, room_id: str, room_version: RoomVersion, user_id: str,
+ ):
+ # Add a user to the room.
+ builder = self.event_builder_factory.for_room_version(
+ room_version,
+ {
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "room_id": room_id,
+ "content": {"membership": "join"},
+ },
+ )
+
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(
+ self.hs.get_storage().persistence.persist_event(event, context)
+ )
+
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
user_id = "@test:test"
@@ -585,3 +679,130 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ user_directory.register_servlets,
+ account.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+
+ # Set accounts to expire after a week
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ return config
+
+ def prepare(self, reactor, clock, hs):
+ super().prepare(reactor, clock, hs)
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def test_user_info(self):
+ """Test /users/info for local users from the Client-Server API"""
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request info about each user from user_three
+ channel = self.make_request(
+ "POST",
+ path="/_matrix/client/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ access_token=user_three_token,
+ shorthand=False,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def test_user_info_federation(self):
+ """Test that /users/info can be called from the Federation API, and
+ that we can query remote users from the Client-Server API
+ """
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request information about our local users from the perspective of a remote server
+ channel = self.make_request(
+ "POST",
+ path="/_matrix/federation/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ )
+ self.assertEquals(200, channel.code)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def setup_test_users(self):
+ """Create an admin user and three test users, each with a different state"""
+
+ # Create an admin user to expire other users with
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_token = self.login("admin", "adminpassword")
+
+ # Create three users
+ user_one = self.register_user("alice", "pass")
+ user_one_token = self.login("alice", "pass")
+ user_two = self.register_user("bob", "pass")
+ user_three = self.register_user("carl", "pass")
+ user_three_token = self.login("carl", "pass")
+
+ # Deactivate user_one
+ self.deactivate(user_one, user_one_token)
+
+ # Expire user_two
+ self.expire(user_two, admin_token)
+
+ # Do nothing to user_three
+
+ return user_one, user_two, user_three, user_three_token
+
+ def expire(self, user_id_to_expire, admin_tok):
+ url = "/_synapse/admin/v1/account_validity/validity"
+ request_data = {
+ "user_id": user_id_to_expire,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ channel = self.make_request("POST", url, request_data, access_token=admin_tok)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def deactivate(self, user_id, tok):
+ request_data = {
+ "auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
+ "erase": False,
+ }
+ channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.assertEqual(channel.code, 200)
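For reference, the assertions in the two tests above imply a response body keyed by user ID; a minimal sketch of the shape (only the "deactivated" and "expired" fields are guaranteed by these tests, everything else is assumption):

    # Hypothetical /users/info response for the three test users above:
    # alice was deactivated, bob was expired, carl was left untouched.
    expected = {
        "@alice:test": {"deactivated": True, "expired": False},
        "@bob:test": {"deactivated": False, "expired": True},
        "@carl:test": {"deactivated": False, "expired": False},
    }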
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 686012dd25..b758b29b2a 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -103,7 +103,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=self.tls_factory,
+ tls_client_options_factory=FederationPolicyForHTTPS(config),
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 9a56e1c14a..4e1a5a5138 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -12,7 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import logging
+import os
+from typing import Optional
+from unittest.mock import patch
import treq
from netaddr import IPSet
@@ -100,22 +104,36 @@ class MatrixFederationAgentTests(TestCase):
return http_protocol
- def test_http_request(self):
- agent = ProxyAgent(self.reactor)
+ def _test_request_direct_connection(self, agent, scheme, hostname, path):
+ """Runs a test case for a direct connection not going through a proxy.
- self.reactor.lookups["test.com"] = "1.2.3.4"
- d = agent.request(b"GET", b"http://test.com")
+ Args:
+ agent (ProxyAgent): the proxy agent being tested
+ scheme (bytes): expected to be either "http" or "https"
+ hostname (bytes): the hostname to connect to in the test
+ path (bytes): the path to connect to in the test
+ """
+ is_https = scheme == b"https"
+
+ self.reactor.lookups[hostname.decode()] = "1.2.3.4"
+ d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)
# there should be a pending TCP connection
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 80)
+ self.assertEqual(port, 443 if is_https else 80)
# make a test server, and wire up the client
http_server = self._make_connection(
- client_factory, _get_test_protocol_factory()
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=is_https,
+ expected_sni=hostname if is_https else None,
)
# the FakeTransport is async, so we need to pump the reactor
@@ -126,8 +144,8 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
- self.assertEqual(request.path, b"/")
- self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ self.assertEqual(request.path, b"/" + path)
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
request.write(b"result")
request.finish()
@@ -137,48 +155,54 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
def test_https_request(self):
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
- self.reactor.lookups["test.com"] = "1.2.3.4"
- d = agent.request(b"GET", b"https://test.com/abc")
-
- # there should be a pending TCP connection
- clients = self.reactor.tcpClients
- self.assertEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients[0]
- self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 443)
+ def test_http_request_use_proxy_empty_environment(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- # make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- _get_test_protocol_factory(),
- ssl=True,
- expected_sni=b"test.com",
- )
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
+ def test_http_request_via_uppercase_no_proxy(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- # the FakeTransport is async, so we need to pump the reactor
- self.reactor.advance(0)
-
- # now there should be a pending request
- self.assertEqual(len(http_server.requests), 1)
+ @patch.dict(
+ os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
+ )
+ def test_http_request_via_no_proxy(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- request = http_server.requests[0]
- self.assertEqual(request.method, b"GET")
- self.assertEqual(request.path, b"/abc")
- self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
- request.write(b"result")
- request.finish()
+ @patch.dict(
+ os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
+ )
+ def test_https_request_via_no_proxy(self):
+ agent = ProxyAgent(
+ self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
+ )
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
- self.reactor.advance(0)
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
+ def test_http_request_via_no_proxy_star(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- resp = self.successResultOf(d)
- body = self.successResultOf(treq.content(resp))
- self.assertEqual(body, b"result")
+ @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
+ def test_https_request_via_no_proxy_star(self):
+ agent = ProxyAgent(
+ self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
+ )
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
def test_http_request_via_proxy(self):
- agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+ agent = ProxyAgent(self.reactor, use_proxy=True)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
d = agent.request(b"GET", b"http://test.com")
@@ -214,11 +238,24 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
+ @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
def test_https_request_via_proxy(self):
+ """Tests that TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials=None)
+
+ @patch.dict(
+ os.environ,
+ {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+ )
+ def test_https_request_via_proxy_with_auth(self):
+ """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+
+ def _do_https_request_via_proxy(
+ self, auth_credentials: Optional[str] = None,
+ ):
agent = ProxyAgent(
- self.reactor,
- contextFactory=get_test_https_policy(),
- https_proxy=b"proxy.com",
+ self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
@@ -251,6 +288,22 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b"CONNECT")
self.assertEqual(request.path, b"test.com:443")
+ # Check whether auth credentials have been supplied to the proxy
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+
+ if auth_credentials is not None:
+ # Compute the correct header value for Proxy-Authorization
+ encoded_credentials = base64.b64encode(b"bob:pinkponies")
+ expected_header_value = b"Basic " + encoded_credentials
+
+ # Validate the header's value
+ self.assertIn(expected_header_value, proxy_auth_header_values)
+ else:
+ # Check that the Proxy-Authorization header has not been supplied to the proxy
+ self.assertIsNone(proxy_auth_header_values)
+
# tell the proxy server not to close the connection
proxy_server.persistent = True
@@ -285,6 +338,13 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(request.path, b"/abc")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+
+ # Check that the destination server DID NOT receive proxy credentials
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+ self.assertIsNone(proxy_auth_header_values)
+
request.write(b"result")
request.finish()
@@ -294,40 +354,81 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
- def test_http_request_via_proxy_with_blacklist(self):
+ @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
+ def test_https_request_via_uppercase_proxy_with_blacklist(self):
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
),
self.reactor,
- http_proxy=b"proxy.com:8888",
+ contextFactory=get_test_https_policy(),
+ use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
- d = agent.request(b"GET", b"http://test.com")
+ d = agent.request(b"GET", b"https://test.com/abc")
# there should be a pending TCP connection
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, "1.2.3.5")
- self.assertEqual(port, 8888)
+ self.assertEqual(port, 1080)
- # make a test server, and wire up the client
- http_server = self._make_connection(
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
client_factory, _get_test_protocol_factory()
)
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
# now there should be a pending request
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
- self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.path, b"/abc")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
request.write(b"result")
request.finish()
@@ -338,6 +439,7 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
+ @patch.dict(os.environ, {"https_proxy": "proxy.com"})
def test_https_request_via_proxy_with_blacklist(self):
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
@@ -346,7 +448,7 @@ class MatrixFederationAgentTests(TestCase):
),
self.reactor,
contextFactory=get_test_https_policy(),
- https_proxy=b"proxy.com",
+ use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 60f0820cff..a3b304d316 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -401,8 +401,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_mention(self):
"""
@@ -477,8 +477,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_atroom(self):
"""
@@ -560,8 +560,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_push_unread_count_group_by_room(self):
"""
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
new file mode 100644
index 0000000000..aff563919d
--- /dev/null
+++ b/tests/push/test_presentable_names.py
@@ -0,0 +1,229 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.presentable_names import calculate_room_name
+from synapse.types import StateKey, StateMap
+
+from tests import unittest
+
+
+class MockDataStore:
+ """
+ A fake data store which stores a mapping of state key to event content.
+ (I.e. the state key is used as the event ID.)
+ """
+
+ def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
+ """
+ Args:
+ events: An iterable of (state key, event content) pairs.
+ """
+ self._events = {}
+
+ for i, (event_id, content) in enumerate(events):
+ self._events[event_id] = FrozenEvent(
+ {
+ "event_id": "$event_id",
+ "type": event_id[0],
+ "sender": "@user:test",
+ "state_key": event_id[1],
+ "room_id": "#room:test",
+ "content": content,
+ "origin_server_ts": i,
+ },
+ RoomVersions.V1,
+ )
+
+ async def get_event(
+ self, event_id: StateKey, allow_none: bool = False
+ ) -> Optional[FrozenEvent]:
+ assert allow_none, "Mock not configured for allow_none = False"
+
+ return self._events.get(event_id)
+
+ async def get_events(self, event_ids: Iterable[StateKey]):
+ # This is cheating since it just returns all events.
+ return self._events
+
+
+class PresentableNamesTestCase(unittest.HomeserverTestCase):
+ USER_ID = "@test:test"
+ OTHER_USER_ID = "@user:test"
+
+ def _calculate_room_name(
+ self,
+ events: StateMap[dict],
+ user_id: str = "",
+ fallback_to_members: bool = True,
+ fallback_to_single_member: bool = True,
+ ):
+ # This isn't 100% accurate, but works with MockDataStore.
+ room_state_ids = {k[0]: k[0] for k in events}
+
+ return self.get_success(
+ calculate_room_name(
+ MockDataStore(events),
+ room_state_ids,
+ user_id or self.USER_ID,
+ fallback_to_members,
+ fallback_to_single_member,
+ )
+ )
+
+ def test_name(self):
+ """A room name event should be used."""
+ events = [
+ ((EventTypes.Name, ""), {"name": "test-name"}),
+ ]
+ self.assertEqual("test-name", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.Name, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.Name, ""), {"name": 1})]
+ self.assertEqual(1, self._calculate_room_name(events))
+
+ def test_canonical_alias(self):
+ """An canonical alias should be used."""
+ events = [
+ ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
+ ]
+ self.assertEqual("#test-name:test", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_invite(self):
+ """An invite has special behaviour."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
+ ]
+ self.assertEqual("Invite from Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+ # Ensure this logic is skipped if we don't fallback to members.
+ self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
+
+ # Check if the event content has garbage.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ]
+ self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
+
+ # No member event for sender.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ]
+ self.assertEqual("Room Invite", self._calculate_room_name(events))
+
+ def test_no_members(self):
+ """Behaviour of an empty room."""
+ events = []
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ # Note that events with invalid (or missing) membership are ignored.
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
+ ]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_no_other_members(self):
+ """Behaviour of a room with no other members in it."""
+ events = [
+ (
+ (EventTypes.Member, self.USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Me"},
+ ),
+ ]
+ self.assertEqual("Me", self._calculate_room_name(events))
+
+ # Check if the event content has no displayname.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("@test:test", self._calculate_room_name(events))
+
+ # 3pid invite, use the other user (who is set as the sender).
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual(
+ "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
+ )
+
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
+ ]
+ self.assertEqual(
+ "Inviting email address",
+ self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
+ )
+
+ def test_one_other_member(self):
+ """Behaviour of a room with a single other member."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ]
+ self.assertEqual("Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+
+ # Check if the event content has no displayname and is an invite.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.INVITE},
+ ),
+ ]
+ self.assertEqual("@user:test", self._calculate_room_name(events))
+
+ def test_other_members(self):
+ """Behaviour of a room with multiple other members."""
+ # Two other members.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
+
+ # Three or more other members.
+ events.append(
+ ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
+ )
+ self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1f4b5ca2ac..4a841f5bb8 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"type": "m.room.history_visibility",
"sender": "@user:test",
"state_key": "",
- "room_id": "@room:test",
+ "room_id": "#room:test",
"content": content,
},
RoomVersions.V1,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 3379189785..d5dce1f83f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
+ # We may have an attempt to connect to redis for the external cache already.
+ self.connect_any_redis_attempts()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
@@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
- self.assertEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
- self.assertEqual(port, 6379)
+ while clients:
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
- client_to_server_transport = FakeTransport(
- server_protocol, self.reactor, client_protocol
- )
- client_protocol.makeConnection(client_to_server_transport)
-
- server_to_client_transport = FakeTransport(
- client_protocol, self.reactor, server_protocol
- )
- server_protocol.makeConnection(server_to_client_transport)
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
- return client_to_server_transport, server_to_client_transport
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
+
+ # Since we use SET/GET to cache things we can safely no-op them.
+ elif command == b"SET":
+ self.send("OK")
+ elif command == b"GET":
+ self.send(None)
else:
raise Exception("Unknown command")
@@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
+ if obj is None:
+ return "$-1\r\n"
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
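The encoding added above follows the Redis serialization protocol (RESP), in which a null bulk string is "$-1\r\n" and a bulk string is "$<length>\r\n<data>\r\n"; a minimal sketch:

    def encode_bulk_string(obj):
        # RESP null bulk string (what the fake server sends for a GET miss).
        if obj is None:
            return "$-1\r\n"
        # RESP bulk string: "$<length>\r\n<data>\r\n".
        return "${}\r\n{}\r\n".format(len(obj), obj)

    assert encode_bulk_string(None) == "$-1\r\n"
    assert encode_bulk_string("OK") == "$2\r\nOK\r\n"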
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 04599c2fcf..59e58a38f7 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import make_awaitable
@@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.user1 = self.register_user(
- "user1", "pass1", admin=False, displayname="Name 1"
- )
- self.user2 = self.register_user(
- "user2", "pass2", admin=False, displayname="Name 2"
- )
-
def test_no_auth(self):
"""
Try to list users without authentication.
@@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
+ self._create_users(1)
other_user_token = self.login("user1", "pass1")
channel = self.make_request("GET", self.url, access_token=other_user_token)
@@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
List all users, including deactivated users.
"""
+ self._create_users(2)
+
channel = self.make_request(
"GET",
self.url + "?deactivated=true",
@@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, channel.json_body["total"])
# Check that all fields are available
- for u in channel.json_body["users"]:
- self.assertIn("name", u)
- self.assertIn("is_guest", u)
- self.assertIn("admin", u)
- self.assertIn("user_type", u)
- self.assertIn("deactivated", u)
- self.assertIn("displayname", u)
- self.assertIn("avatar_url", u)
+ self._check_fields(channel.json_body["users"])
def test_search_term(self):
"""Test that searching for a users works correctly"""
@@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Check that users were returned
self.assertTrue("users" in channel.json_body)
+ self._check_fields(channel.json_body["users"])
users = channel.json_body["users"]
# Check that the expected number of users were returned
@@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase):
u = users[0]
self.assertEqual(expected_user_id, u["name"])
+ self._create_users(2)
+
+ user1 = "@user1:test"
+ user2 = "@user2:test"
+
# Perform search tests
- _search_test(self.user1, "er1")
- _search_test(self.user1, "me 1")
+ _search_test(user1, "er1")
+ _search_test(user1, "me 1")
- _search_test(self.user2, "er2")
- _search_test(self.user2, "me 2")
+ _search_test(user2, "er2")
+ _search_test(user2, "me 2")
- _search_test(self.user1, "er1", "user_id")
- _search_test(self.user2, "er2", "user_id")
+ _search_test(user1, "er1", "user_id")
+ _search_test(user2, "er2", "user_id")
# Test case insensitive
- _search_test(self.user1, "ER1")
- _search_test(self.user1, "NAME 1")
+ _search_test(user1, "ER1")
+ _search_test(user1, "NAME 1")
- _search_test(self.user2, "ER2")
- _search_test(self.user2, "NAME 2")
+ _search_test(user2, "ER2")
+ _search_test(user2, "NAME 2")
- _search_test(self.user1, "ER1", "user_id")
- _search_test(self.user2, "ER2", "user_id")
+ _search_test(user1, "ER1", "user_id")
+ _search_test(user2, "ER2", "user_id")
_search_test(None, "foo")
_search_test(None, "bar")
@@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+
+ # negative limit
+ channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative from
+ channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid guests
+ channel = self.make_request(
+ "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid deactivated
+ channel = self.make_request(
+ "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ def test_limit(self):
+ """
+ Testing list of users with limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 5)
+ self.assertEqual(channel.json_body["next_token"], "5")
+ self._check_fields(channel.json_body["users"])
+
+ def test_from(self):
+ """
+ Testing list of users with a defined starting point (from)
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["users"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of users with a defined starting point and limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(channel.json_body["next_token"], "15")
+ self.assertEqual(len(channel.json_body["users"]), 10)
+ self._check_fields(channel.json_body["users"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # The maximum number of results is larger than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # The maximum number of results is smaller than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 19)
+ self.assertEqual(channel.json_body["next_token"], "19")
+
+ # Check that when `from` is set to the value of `next_token`, the
+ # remaining entries are returned and `next_token` does not appear
+ channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _check_fields(self, content: JsonDict):
+ """Checks that the expected user attributes are present in content
+ Args:
+ content: List of user dictionaries to check
+ """
+ for u in content:
+ self.assertIn("name", u)
+ self.assertIn("is_guest", u)
+ self.assertIn("admin", u)
+ self.assertIn("user_type", u)
+ self.assertIn("deactivated", u)
+ self.assertIn("displayname", u)
+ self.assertIn("avatar_url", u)
+
+ def _create_users(self, number_users: int):
+ """
+ Create a number of users
+ Args:
+ number_users: Number of users to be created
+ """
+ for i in range(1, number_users + 1):
+ self.register_user(
+ "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+ )
+
class DeactivateAccountTestCase(unittest.HomeserverTestCase):
@@ -612,7 +781,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
# set attributes for user
self.get_success(
- self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+ self.store.set_profile_avatar_url("user", "mxc://servername/mediaid", 1)
)
self.get_success(
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
@@ -767,8 +936,13 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
- self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
- self.assertEqual("User1", channel.json_body["displayname"])
+
+ # On DINUM's deployment we clear the profile information during a deactivation regardless,
+ # whereas on mainline we decided to only do this if the deactivation was performed with erase: True.
+ # The discrepancy is due to profile replication.
+ # See synapse.storage.databases.main.profile.ProfileWorkerStore.set_profiles_active
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
self._is_erased("@user:test", False)
@@ -1183,7 +1357,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# set attributes for user
self.get_success(
- self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+ self.store.set_profile_avatar_url("user", "mxc://servername/mediaid", 1)
)
self.get_success(
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
@@ -1198,6 +1372,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+
+ # On DINUM's deployment we clear the profile information during a deactivation regardless,
+ # whereas on mainline we decided to only do this if the deactivation was performed with erase: True.
+ # The discrepancy is due to profile replication.
+ # See synapse.storage.databases.main.profile.ProfileWorkerStore.set_profiles_active
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -1215,8 +1394,13 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
- self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
- self.assertEqual("User", channel.json_body["displayname"])
+
+ # On DINUM's deployment we clear the profile information during a deactivation regardless,
+ # whereas on mainline we decided to only do this if the deactivation was performed with erase: True.
+ # The discrepancy is due to profile replication.
+ # See synapse.storage.databases.main.profile.ProfileWorkerStore.set_profiles_active
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
# the user is deactivated, the threepid will be deleted
# Get user
@@ -1228,8 +1412,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
- self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
- self.assertEqual("User", channel.json_body["displayname"])
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_change_name_deactivate_user_user_directory(self):
@@ -2211,3 +2395,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
+
+
+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to shadow-ban a user without authentication.
+ """
+ channel = self.make_request("POST", self.url)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request("POST", self.url, access_token=other_user_token)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that shadow-banning a user that is not local returns a 400
+ """
+ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+
+ def test_success(self):
+ """
+ Shadow-banning should succeed for an admin.
+ """
+ # The user starts off as not shadow-banned.
+ other_user_token = self.login("user", "pass")
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertFalse(result.shadow_banned)
+
+ channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ # Ensure the user is shadow-banned (and the cache was cleared).
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertTrue(result.shadow_banned)
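The limit/from/next_token tests above define cursor-style pagination; a sketch of how a caller could walk the full users list under those semantics (reusing self.url and the admin token from UsersListTestCase):

    # Page through the admin users list, following next_token until it
    # disappears from the response.
    users = []
    from_token = "0"
    while True:
        channel = self.make_request(
            "GET",
            self.url + "?from=%s&limit=5" % from_token,
            access_token=self.admin_user_tok,
        )
        users.extend(channel.json_body["users"])
        if "next_token" not in channel.json_body:
            break  # the last page omits next_token
        from_token = channel.json_body["next_token"]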
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c0a9fc6925..61bdae0879 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -13,17 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+from mock import Mock
+
+from twisted.internet import defer
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,29 +37,128 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ channel = self.make_request(b"POST", "/createRoom", {}, access_token=self.tok)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ data = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ channel = self.make_request(b"POST", request_url, data, access_token=self.tok)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ channel = self.make_request("GET", url, access_token=self.tok)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ channel = self.make_request("POST", url, data, access_token=self.tok)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
+
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
- channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed({"mxid": "@f:test"})
+ mock_http_client.post_json_get_json.return_value = defer.succeed({})
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ # TODO: This class does not use a singleton to get its HTTP client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().http_client = mock_http_client
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_enabled(self):
+ channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
- params = {
+ data = {
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
}
- request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
- channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ channel = self.make_request(b"POST", request_url, data, access_token=self.tok)
+
+ get_json = self.hs.get_identity_handler().http_client.get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ self.make_request("GET", url, access_token=self.tok)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ self.make_request("POST", url, data, access_token=self.tok)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
)
- self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 31dc832fd5..10b1fbac69 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
"default_policy": {
@@ -243,6 +244,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
}
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..97f379b5a5
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,1069 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import random
+import string
+from typing import Optional
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.rest import admin
+from synapse.rest.client.v1 import directory, login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULES_TYPE,
+ AccessRules,
+ RoomAccessRules,
+)
+from synapse.types import JsonDict, create_requester
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ "freeze_room_with_no_admin": "true",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ # TODO: This class does not use a singleton to get its HTTP client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().blacklisting_http_client = mock_http_client
+
+ self.third_party_event_rules = self.hs.get_third_party_event_rules()
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=AccessRules.UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right."""
+ room_id = self.create_room(rule=AccessRules.UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=AccessRules.DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400)
+
+ def test_create_room_default_power_level_rules(self):
+ """Tests that a room created with no power level overrides instead uses the dinum
+ defaults
+ """
+ room_id = self.create_room(direct=True, rule=AccessRules.DIRECT)
+ power_levels = self.helper.get_state(room_id, "m.room.power_levels", self.tok)
+
+ # Inviting another user should require PL50, even in private rooms
+ self.assertEqual(power_levels["invite"], 50)
+ # Sending arbitrary state events should require PL100
+ self.assertEqual(power_levels["state_default"], 100)
+
+ def test_create_room_fails_on_incorrect_power_level_rules(self):
+ """Tests that a room created with power levels lower than that required are rejected"""
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ modified_power_levels["invite"] = 0
+ modified_power_levels["state_default"] = 50
+
+ self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ initial_state=[
+ {"type": "m.room.power_levels", "content": modified_power_levels}
+ ],
+ expected_code=400,
+ )
+
+ def test_create_room_with_missing_power_levels_use_default_values(self):
+ """
+ Tests that creating a room with custom power levels, but without defining invite
+ or state_default, succeeds, with the missing values replaced by the defaults.
+ """
+
+ # Attempt to create a room without defining "invite" or "state_default"
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ del modified_power_levels["invite"]
+ del modified_power_levels["state_default"]
+ room_id = self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ initial_state=[
+ {"type": "m.room.power_levels", "content": modified_power_levels}
+ ],
+ )
+
+ # This should succeed, but the defaults should be put in place instead
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ self.assertEqual(room_power_levels["invite"], 50)
+ self.assertEqual(room_power_levels["state_default"], 100)
+
+ # And now the same test, but using power_levels_content_override instead
+ # of initial_state (which takes a slightly different codepath)
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ del modified_power_levels["invite"]
+ del modified_power_levels["state_default"]
+ room_id = self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ power_levels_content_override=modified_power_levels,
+ )
+
+ # This should succeed, but the defaults should be put in place instead
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ self.assertEqual(room_power_levels["invite"], 50)
+ self.assertEqual(room_power_levels["state_default"], 100)
+
+ def test_existing_room_can_change_power_levels(self):
+ """Tests that a room created with default power levels can have their power levels
+ dropped after room creation
+ """
+ # Creates a room with the default power levels
+ room_id = self.create_room(
+ direct=True, rule=AccessRules.DIRECT, expected_code=200,
+ )
+
+ # Attempt to drop invite and state_default power levels after the fact
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ room_power_levels["invite"] = 0
+ room_power_levels["state_default"] = 50
+ self.helper.send_state(
+ room_id, "m.room.power_levels", room_power_levels, self.tok
+ )
+
+ def test_public_room(self):
+ """Tests that it's only possible to have a room listed in the public room list
+ if the access rule is restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), AccessRules.RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), AccessRules.RESTRICTED
+ )
+
+ # List preset_room_id in the public room list
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # List init_state_room_id in the public room list
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, AccessRules.DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.DIRECT, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule of direct
+ # should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=AccessRules.DIRECT,
+ expected_code=400,
+ )
+
+ # Changing the join rule to public in a direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+
+ Also tests that the room can be published to, and removed from, the public room
+ list.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can invite a user whose HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ # We are allowed to publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "public"}
+
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # We are allowed to remove the room from the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "private"}
+
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+
+ Tests that a user from an HS that's in the list of forbidden domains (used in
+ restricted mode) can be invited.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # Relax the 3PID invite ratelimiter so it doesn't interfere with this test
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+ # We can't send another 3PID invite to a room that already has a pending 3PID invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0]
+ data = {"visibility": "public"}
+
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ # We can invite
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a power level event that redefines the default PL, even if it
+ # doesn't set a non-default PL for a user that would be forbidden in restricted
+ # mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't send a power level event that doesn't redefine the default PL but sets
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room
+ data = {"visibility": "public"}
+
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ # We can change the rule from restricted to unrestricted.
+ # Note that this changes self.restricted_room to an unrestricted room
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't publish a room to the public room list and then change its rule to
+ # unrestricted
+
+ # Create a restricted room
+ test_room_id = self.create_room(rule=AccessRules.RESTRICTED)
+
+ # Publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id
+ data = {"visibility": "public"}
+
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Attempt to switch the room to "unrestricted"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Attempt to switch the room to "direct"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ confuse the revokation as a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def test_check_event_allowed(self):
+ """Tests that RoomAccessRules.check_event_allowed behaves accordingly.
+
+ It tests that:
+ * forbidden users cannot join restricted rooms.
+ * forbidden users can only join unrestricted rooms if they have an invite.
+ """
+ event_creator = self.hs.get_event_creation_handler()
+
+ # Test that forbidden users cannot join restricted rooms
+ requester = create_requester(self.user_id)
+ allowed_requester = create_requester("@user:allowed_domain")
+ forbidden_requester = create_requester("@user:forbidden_domain")
+
+ # Assert a join event from a forbidden user to a restricted room is rejected
+ self.get_failure(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ ),
+ SynapseError,
+ )
+
+ # A join event from an non-forbidden user to a restricted room is allowed
+ self.get_success(
+ event_creator.create_event(
+ allowed_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": allowed_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": allowed_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Test that forbidden users can only join unrestricted rooms if they have an invite
+
+ # A forbidden user without an invite should not be able to join an unrestricted room
+ self.get_failure(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ ),
+ SynapseError,
+ )
+
+ # However, if we then invite this user...
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=requester.user.to_string(),
+ targ=forbidden_requester.user.to_string(),
+ tok=self.tok,
+ )
+
+ # And create another join event, making sure that its context states it's coming
+ # in after the above invite was made...
+ # Then the forbidden user should be able to join!
+ self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ def test_freezing_a_room(self):
+ """Tests that the power levels in a room change to prevent new events from
+ non-admin users when the last admin of a room leaves.
+ """
+
+ def freeze_room_with_id_and_power_levels(
+ room_id: str, custom_power_levels_content: Optional[JsonDict] = None,
+ ):
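+ """Has a second user join and leave the room, checks that this doesn't
+ change the power levels, then has the last admin leave and asserts that
+ the room is frozen, i.e. that all actions now require PL 100.
+ """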
+ # Invite a user to the room; they will join with PL 0
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ if not custom_power_levels_content:
+ # Retrieve the room's current power levels event content
+ power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+ else:
+ power_levels = custom_power_levels_content
+
+ # Send the chosen power levels into the room (the custom content if given,
+ # otherwise the current content re-sent as-is)
+ self.helper.send_state(
+ room_id=room_id,
+ event_type="m.room.power_levels",
+ body=power_levels,
+ tok=self.tok,
+ )
+
+ # Ensure that the invitee leaving the room does not change the power levels
+ self.helper.leave(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Retrieve the new power levels of the room
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+
+ # Ensure they have not changed
+ self.assertDictEqual(power_levels, new_power_levels)
+
+ # Invite the user back again
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Now the admin leaves the room
+ self.helper.leave(
+ room=room_id, user=self.user_id, tok=self.tok,
+ )
+
+ # Check the power levels again
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok,
+ )
+
+ # Ensure that the new power levels prevent anyone but admins from sending
+ # certain events
+ self.assertEquals(new_power_levels["state_default"], 100)
+ self.assertEquals(new_power_levels["events_default"], 100)
+ self.assertEquals(new_power_levels["kick"], 100)
+ self.assertEquals(new_power_levels["invite"], 100)
+ self.assertEquals(new_power_levels["ban"], 100)
+ self.assertEquals(new_power_levels["redact"], 100)
+ self.assertDictEqual(new_power_levels["events"], {})
+ self.assertDictEqual(new_power_levels["users"], {self.user_id: 100})
+
+ # Ensure new users entering the room aren't going to immediately become admins
+ self.assertEquals(new_power_levels["users_default"], 0)
+
+ # Test that freezing a room with the default power level state event content works
+ room1 = self.create_room()
+ freeze_room_with_id_and_power_levels(room1)
+
+ # Test that freezing a room with a power level state event that is missing
+ # `state_default` and `events_default` keys behaves as expected
+ room2 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room2,
+ {
+ "ban": 50,
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "invite": 0,
+ "kick": 50,
+ "redact": 50,
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `events_default` keys
+ },
+ )
+
+ # Test that freezing a room with a power level state event that is *additionally*
+ # missing `ban`, `invite`, `kick` and `redact` keys behaves as expected
+ room3 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room3,
+ {
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `events_default` keys
+ # Explicitly remove `ban`, `invite`, `kick` and `redact` keys
+ },
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ power_levels_content_override=None,
+ expected_code=200,
+ ):
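+ """Creates a room with the given settings and, on success, returns its ID.
+
+ Args:
+ direct: whether to create the room as a direct chat.
+ rule: an optional access rule to set via the room's initial state.
+ preset: the room creation preset to use.
+ initial_state: extra state events to set at creation.
+ power_levels_content_override: custom power levels content for the room.
+ expected_code: the expected HTTP status code of the /createRoom request.
+ """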
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ if power_levels_content_override:
+ content["power_levels_content_override"] = power_levels_content_override
+
+ channel = self.make_request(
+ "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
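+ """Returns the access rule currently set in the given room's state."""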
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
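+ """Sends a state event to change the room's access rule and checks the
+ response against the expected status code.
+ """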
+ data = {"rule": new_rule}
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
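+ """Sends a state event to change the room's join rule and checks the
+ response against the expected status code.
+ """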
+ data = {"join_rule": new_join_rule}
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
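+ """Sends a 3PID invite to the given address in the given room."""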
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
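+ """Sends a state event with the given type, state key and body to the room,
+ checks the response against the expected status code, and returns the
+ response body.
+ """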
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ channel = self.make_request("PUT", path, json.dumps(body), access_token=tok)
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index e689c3fbea..0ebdf1415b 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -18,6 +18,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.types import UserID
from tests import unittest
@@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
self.store = self.hs.get_datastore()
self.get_success(
- self.store.db_pool.simple_update(
- table="users",
- keyvalues={"name": self.banned_user_id},
- updatevalues={"shadow_banned": True},
- desc="shadow_ban",
- )
+ self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
)
self.other_user_id = self.register_user("otheruser", "pass")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d4e3165436..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomInviteRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_rooms_ratelimit(self):
+ """Tests that invites in a room are actually rate-limited."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ for i in range(3):
+ self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+ self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_users_ratelimit(self):
+ """Tests that invites to a specific user are actually rate-limited."""
+
+ for i in range(3):
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index cb87b80e33..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,7 +24,7 @@ import pkg_resources
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_email(self):
+ """Test that we ratelimit /requestToken for the same email.
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test1@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ def reset(ip):
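+ """Performs a full password reset flow, with requests coming from the given IP."""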
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret, ip)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ self.email_attempts.clear()
+
+ # We expect to be able to make three requests before getting rate
+ # limited.
+ #
+ # We change IPs to ensure that we're not being ratelimited due to the
+ # same IP
+ reset("127.0.0.1")
+ reset("127.0.0.2")
+ reset("127.0.0.3")
+
+ with self.assertRaises(HttpResponseException) as cm:
+ reset("127.0.0.4")
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow
Request password reset with different spelling
@@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
- def _request_token(self, email, client_secret):
+ def _request_token(self, email, client_secret, ip="127.0.0.1"):
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
+ client_ip=ip,
)
- self.assertEquals(200, channel.code, channel.result)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body["sid"]
@@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_ip(self):
+ """Tests that adding emails is ratelimited by IP
+ """
+
+ # We expect to be able to set three emails before getting ratelimited.
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+ self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+ self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+ with self.assertRaises(HttpResponseException) as cm:
+ self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
"""
@@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
body["next_link"] = next_link
channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
- self.assertEquals(expect_code, channel.code, channel.result)
+
+ if channel.code != expect_code:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body.get("sid")
@@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _add_email(self, request_email, expected_email):
"""Test adding an email to profile
"""
+ previous_email_attempts = len(self.email_attempts)
+
client_secret = "foobar"
session_id = self._request_token(request_email, client_secret)
- self.assertEquals(len(self.email_attempts), 1)
+ self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email()
self._validate_token(link)
@@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+ threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+ self.assertIn(expected_email, threepids)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 27db4f551e..67f7dc43c3 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -18,9 +18,15 @@
import datetime
import json
import os
+import os.path
+import tempfile
+
+from mock import Mock
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -85,13 +91,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
- def test_POST_bad_username(self):
- request_data = json.dumps({"username": 777, "password": "monkey"})
- channel = self.make_request(b"POST", self.url, request_data)
-
- self.assertEquals(channel.result["code"], b"400", channel.result)
- self.assertEquals(channel.json_body["error"], "Invalid username")
-
def test_POST_user_valid(self):
user_id = "@kermit:test"
device_id = "frogfone"
@@ -289,6 +288,96 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(channel.json_body.get("sid"))
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
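+ """Tests that registering a user on a homeserver that hides profiles from
+ the user directory replicates a None profile for that user.
+ """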
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+ self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
+class AccountValidityTemplateDirectoryTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Create a custom template directory and a template inside to read
+ temp_dir = tempfile.mkdtemp()
+ self.account_renewed_fd, account_renewed_path = tempfile.mkstemp(dir=temp_dir)
+ self.invalid_token_fd, invalid_token_path = tempfile.mkstemp(dir=temp_dir)
+
+ self.account_renewed_template_contents = "Yay, your account has been renewed"
+ self.invalid_token_template_contents = "Boo, you used an invalid token. Booo"
+
+ # Add some content to the custom templates
+ with open(account_renewed_path, "w") as f:
+ f.write(self.account_renewed_template_contents)
+
+ with open(invalid_token_path, "w") as f:
+ f.write(self.invalid_token_template_contents)
+
+ # Write the config, specifying the custom template directory and name of the custom
+ # template files. These must differ from the files that exist in the default
+ # template directory in order to properly test everything.
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ "template_dir": temp_dir,
+ "account_renewed_html_path": os.path.basename(account_renewed_path),
+ "invalid_token_html_path": os.path.basename(invalid_token_path),
+ }
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def test_template_contents(self):
+ """Tests that the contents of the custom templates as specified in the config are
+ correct.
+ """
+ self.assertEquals(
+ self.hs.config.account_validity.account_validity_account_renewed_template.render(),
+ self.account_renewed_template_contents,
+ )
+
+ self.assertEquals(
+ self.hs.config.account_validity.account_validity_invalid_token_template.render(),
+ self.invalid_token_template_contents,
+ )
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -298,6 +387,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
logout.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -408,6 +498,146 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Create a user
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+ self.get_success(
+ self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+ )
+
+ # Check that a full profile for this user is replicated
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+ # Expire the user
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.reactor.advance(60 * 60 * 1000)
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -470,8 +700,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
- # Move 6 days forward. This should trigger a renewal email to be sent.
- self.reactor.advance(datetime.timedelta(days=6).total_seconds())
+ # Move 5 days forward. This should trigger a renewal email to be sent.
+ self.reactor.advance(datetime.timedelta(days=5).total_seconds())
self.assertEqual(len(self.email_attempts), 1)
# Retrieving the URL from the email is too much pain for now, so we
@@ -482,14 +712,32 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
- content_type = None
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type = header[1]
- self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
# Check that the HTML we're getting is the one we expect on a successful renewal.
- expected_html = self.hs.config.account_validity.account_renewed_html_content
+ expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+ expected_html = self.hs.config.account_validity_account_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
+ # Move 1 day forward. Try to renew with the same token again.
+ url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
+ channel = self.make_request(b"GET", url)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Check that we're getting HTML back.
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
+
+ # Check that the HTML we're getting is the one we expect when reusing a
+ # token. The account expiration date should not have changed.
+ expected_html = self.hs.config.account_validity_account_previously_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
@@ -509,15 +757,12 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
- content_type = None
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type = header[1]
- self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
# Check that the HTML we're getting is the one we expect when using an
# invalid/unknown token.
- expected_html = self.hs.config.account_validity.invalid_token_html_content
+ expected_html = self.hs.config.account_validity_invalid_token_template.render()
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
@@ -625,7 +870,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
config["account_validity"] = {"enabled": False}
self.hs = self.setup_test_homeserver(config=config)
- self.hs.config.account_validity.period = self.validity_period
+
+ # We need to set these directly, instead of in the homeserver config dict above.
+ # This is due to account validity-related config options not being read by
+ # Synapse when account_validity.enabled is False.
+ self.hs.get_datastore()._account_validity_period = self.validity_period
+ self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
self.store = self.hs.get_datastore()
@@ -639,8 +889,6 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
"""
user_id = self.register_user("kermit_delta", "user")
- self.hs.config.account_validity.startup_job_max_delta = self.max_delta
-
now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 512e36c236..7f68032d9d 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,12 +16,22 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+)
+from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import read_marker, sync
+from synapse.rest.client.v2_alpha import knock, read_marker, sync
from tests import unittest
+from tests.federation.transport.test_knocking import (
+ KnockingStrippedStateEventHelperMixin,
+)
from tests.server import TimedOutException
+from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase):
@@ -306,6 +316,89 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.make_request("GET", sync_url % (access_token, next_batch))
+class SyncKnockTestCase(
+ unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin
+):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ knock.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to create the room to knock on).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll knock on.
+ self.room_id = self.helper.create_room_as(
+ self.user_id,
+ is_public=False,
+ room_version=RoomVersions.V7.identifier,
+ tok=self.tok,
+ )
+
+ # Register the second user (used to knock on the room).
+ self.knocker = self.register_user("knocker", "monkey")
+ self.knocker_tok = self.login("knocker", "monkey")
+
+ # Perform an initial sync for the knocking user.
+ channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
+
+ # Set up some room state to test with.
+ self.expected_room_state = self.send_example_state_events_to_room(
+ hs, self.room_id, self.user_id
+ )
+
+ @override_config({"experimental_features": {"msc2403_enabled": True}})
+ def test_knock_room_state(self):
+ """Tests that /sync returns state from a room after knocking on it."""
+ # Knock on a room
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/xyz.amorgan.knock/%s" % (self.room_id,),
+ b"{}",
+ self.knocker_tok,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # We expect to see the knock event in the stripped room state later
+ self.expected_room_state[EventTypes.Member] = {
+ "content": {"membership": Membership.KNOCK, "displayname": "knocker"},
+ "state_key": "@knocker:test",
+ }
+
+ # Check that /sync includes stripped state from the room
+ channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.knocker_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Extract the stripped room state events from /sync
+ knock_entry = channel.json_body["rooms"][Membership.KNOCK]
+ room_state_events = knock_entry[self.room_id]["knock_state"]["events"]
+
+ # Validate that the knock membership event came last
+ self.assertEqual(room_state_events[-1]["type"], EventTypes.Member)
+
+ # Validate the stripped room state events
+ self.check_knock_room_state_against_room_state(
+ room_state_events, self.expected_room_state
+ )
+
+
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -439,7 +532,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
)
self._check_unread_count(5)
- def _check_unread_count(self, expected_count: True):
+ def _check_unread_count(self, expected_count: int):
"""Syncs and compares the unread count with the expected value."""
channel = self.make_request(
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index ae2b32b131..a6c6985173 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config = self.default_config()
config["media_store_path"] = self.media_store_path
- config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
@@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
+ """Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self):
+ """Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
+ def test_invalid_type(self):
+ """An invalid thumbnail type is never available."""
+ self._test_thumbnail("invalid", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+ )
+ def test_no_thumbnail_crop(self):
+ """
+ Override the config to generate only scaled thumbnails, but request a cropped one.
+ """
+ self._test_thumbnail("crop", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+ )
+ def test_no_thumbnail_scale(self):
+ """
+ Override the config to generate only cropped thumbnails, but request a scaled one.
+ """
+ self._test_thumbnail("scale", None, False)
+
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
channel = make_request(
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..27e59e3891
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,328 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+ def test_allowed(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_one", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_two", "test:target_two", None, "room", False
+ )
+ )
+
+ # User can invite internal user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test1:target_one", None, "room", False, True
+ )
+ )
+
+ # User can invite external user to a non-published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, False
+ )
+ )
+
+ def test_disallowed(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ "source_four": [],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_one", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_one", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_four", "test:target_one", None, "room", False
+ )
+ )
+
+        # With no domains_prevented_from_being_invited_to_published_rooms list
+        # in this config, an invite to a published room is still allowed
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, True
+ )
+ )
+
+ def test_default_allow(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_default_deny(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_config_parse(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+        self.assertEqual(config, DomainRuleChecker.parse_config(config))
+
+ def test_config_parse_failure(self):
+ config = {
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ }
+ }
+ self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["localhost"]
+
+ config["spam_checker"] = {
+ "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+ "config": {
+ "default": True,
+ "domain_mapping": {},
+ "can_only_join_rooms_with_invite": True,
+ "can_only_create_one_to_one_rooms": True,
+ "can_only_invite_during_room_creation": True,
+ "can_invite_by_third_party_id": False,
+ },
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+ self.admin_access_token = self.login("admin_user", "pass")
+
+ self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+ self.normal_access_token = self.login("normal_user", "pass")
+
+ self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+ def test_admin_can_create_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_normal_user_cannot_create_empty_room(self):
+ channel = self._create_room(self.normal_access_token)
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_cannot_create_room_with_multiple_invites(self):
+ channel = self._create_room(
+ self.normal_access_token,
+ content={"invite": [self.other_user_id, self.admin_user_id]},
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly counts both normal and third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [self.other_user_id],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly rejects third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+    def test_normal_user_can_create_room_with_single_invite(self):
+ channel = self._create_room(
+ self.normal_access_token, content={"invite": [self.other_user_id]}
+ )
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_cannot_join_public_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+ )
+
+ def test_can_join_invited_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ def test_cannot_invite(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+    def test_cannot_3pid_invite(self):
+        """Test that unbound 3pid invites get rejected."""
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ channel = self.make_request(
+ "POST",
+ "rooms/%s/invite" % (room_id),
+ {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+ access_token=self.normal_access_token,
+ )
+ self.assertEqual(channel.code, 403, channel.result["body"])
+
+    def _create_room(self, token, content=None):
+ path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+ channel = self.make_request(
+ "POST", path, content=json.dumps(content).encode("utf8"),
+ )
+
+ return channel
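
To make the mapping semantics above easier to follow, here is a standalone sketch of the lookup these unit tests exercise — an illustration of the rule only, not DomainRuleChecker's actual implementation:

def may_invite(config: dict, inviter_domain: str, invitee_domain: str) -> bool:
    # Hypothetical sketch: the inviter's domain selects a list of invitable
    # domains; inviters with no mapping entry fall back to the "default" flag.
    allowed = config["domain_mapping"].get(inviter_domain)
    if allowed is None:
        return config["default"]
    return invitee_domain in allowed


config = {
    "default": False,
    "domain_mapping": {
        "source_one": ["target_one", "target_two"],
        "source_two": ["target_two"],
    },
}
assert may_invite(config, "source_one", "target_one")        # mapped -> allowed
assert not may_invite(config, "source_two", "target_one")    # not mapped -> denied
assert not may_invite(config, "source_three", "target_one")  # unknown -> default

Note that "default" only applies to domains absent from the mapping, which is why test_disallowed denies source_four (mapped to an empty list) even though default is True.
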
diff --git a/tests/server.py b/tests/server.py
index 5a85d5fe7f..6419c445ec 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -47,6 +47,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
+ _ip = attr.ib(type=str, default="127.0.0.1")
_producer = None
@property
@@ -120,7 +121,7 @@ class FakeChannel:
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
- return address.IPv4Address("TCP", "127.0.0.1", 3423)
+ return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self):
return None
@@ -196,6 +197,7 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
+ client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Make a web request using the given method, path and content, and render it
@@ -223,6 +225,9 @@ def make_request(
        will pump the reactor until the renderer tells the channel the request
is finished.
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
channel
"""
@@ -250,7 +255,7 @@ def make_request(
if isinstance(content, str):
content = content.encode("utf8")
- channel = FakeChannel(site, reactor)
+ channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel)
req.content = BytesIO(content)
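
A hypothetical usage sketch of the new parameter from inside a HomeserverTestCase, e.g. to confirm a per-IP ratelimit doesn't leak across source addresses (the endpoint and payload here are placeholders, not part of this patch):

# Each request is attributed to a distinct client IP via FakeChannel.getPeer,
# so per-IP ratelimit buckets stay independent.
first = self.make_request(
    "POST", "/_matrix/client/r0/login", {"type": "m.login.password"},
    client_ip="10.0.0.1",
)
second = self.make_request(
    "POST", "/_matrix/client/r0/login", {"type": "m.login.password"},
    client_ip="10.0.0.2",
)
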
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index e9e3bca3bf..30e46c650d 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -39,7 +39,7 @@ class DataStoreTestCase(unittest.TestCase):
)
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ self.store.set_profile_displayname(self.user.localpart, self.displayname, 1)
)
users, total = yield defer.ensureDeferred(
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index ea63bd56b4..b7dde51224 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -36,7 +36,7 @@ class ProfileStoreTestCase(unittest.TestCase):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
)
self.assertEquals(
@@ -50,7 +50,7 @@ class ProfileStoreTestCase(unittest.TestCase):
# test set to None
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.u_frank.localpart, None)
+ self.store.set_profile_displayname(self.u_frank.localpart, None, 2)
)
self.assertIsNone(
@@ -67,7 +67,7 @@ class ProfileStoreTestCase(unittest.TestCase):
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.localpart, "http://my.site/here", 1
)
)
@@ -82,7 +82,7 @@ class ProfileStoreTestCase(unittest.TestCase):
# test set to None
yield defer.ensureDeferred(
- self.store.set_profile_avatar_url(self.u_frank.localpart, None)
+ self.store.set_profile_avatar_url(self.u_frank.localpart, None, 2)
)
self.assertIsNone(
diff --git a/tests/test_preview.py b/tests/test_preview.py
index c19facc1cb..0c6cbbd921 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -261,3 +261,32 @@ class PreviewUrlTestCase(unittest.TestCase):
html = ""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {})
+
+ def test_invalid_encoding(self):
+ """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
+ html = """
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ og = decode_and_calc_og(
+ html, "http://example.com/test.html", "invalid-encoding"
+ )
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
+
+ def test_invalid_encoding2(self):
+ """A body which doesn't match the sent character encoding."""
+ # Note that this contains an invalid UTF-8 sequence in the title.
+ html = b"""
+ <html>
+ <head><title>\xff\xff Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
diff --git a/tests/test_types.py b/tests/test_types.py
index acdeea7a09..67ceea6e43 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ GroupID,
+ RoomAlias,
+ UserID,
+ map_username_to_mxid_localpart,
+ strip_invalid_mxid_characters,
+)
from tests import unittest
@@ -107,3 +114,16 @@ class MapUsernameTestCase(unittest.TestCase):
self.assertEqual(
map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
)
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+ def test_return_type(self):
+ unstripped = strip_invalid_mxid_characters("test")
+ stripped = strip_invalid_mxid_characters("test@")
+
+ self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+ self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+ def test_strip(self):
+ stripped = strip_invalid_mxid_characters("test@")
+ self.assertEqual(stripped, "test", stripped)
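
For context, a hypothetical pure-Python equivalent of the behaviour these tests lock in. The real strip_invalid_mxid_characters lives in synapse.types; the exact allowed character set below is an assumption based on the mxid localpart grammar:

import string

# Assumed set of characters permitted in a Matrix user-id localpart:
# lowercase letters, digits, and ._=-/
_ALLOWED = set(string.ascii_lowercase + string.digits + "._=-/")


def strip_invalid_mxid_characters_sketch(localpart: str) -> str:
    # Drop anything outside the allowed set; the "@" in "test@" is removed.
    return "".join(c for c in localpart if c in _ALLOWED)


assert strip_invalid_mxid_characters_sketch("test") == "test"
assert strip_invalid_mxid_characters_sketch("test@") == "test"
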
diff --git a/tests/unittest.py b/tests/unittest.py
index bbd295687c..767d5d6077 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase):
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
+ client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase):
custom_headers: (name, value) pairs to add as request headers
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
The FakeChannel object which stores the result of the request.
"""
@@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase):
content_is_form,
await_result,
custom_headers,
+ client_ip,
)
def setup_test_homeserver(self, *args, **kwargs):
diff --git a/tests/utils.py b/tests/utils.py
index 022223cf24..7c5f29afdd 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -157,6 +157,7 @@ def default_config(name, parse=False):
"local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
"saml2_enabled": False,
"default_identity_server": None,
"key_refresh_interval": 24 * 60 * 60 * 1000,
@@ -171,6 +172,8 @@ def default_config(name, parse=False):
"update_user_directory": False,
"caches": {"global_factor": 1},
"listeners": [{"port": 0, "type": "http"}],
+ # Enable encryption by default in private rooms
+ "encryption_enabled_by_default_for_room_type": "invite",
}
if parse:
diff --git a/tox.ini b/tox.ini
index 0479186348..c18e89b014 100644
--- a/tox.ini
+++ b/tox.ini
@@ -26,7 +26,8 @@ deps =
pip>=10 ; python_version >= '3.6'
pip>=10,<21.0 ; python_version < '3.6'
-# directories/files we run the linters on
+# directories/files we run the linters on.
+# if you update this list, make sure to do the same in scripts-dev/lint.sh
lint_targets =
setup.py
synapse
@@ -156,7 +157,7 @@ commands = isort -c --df --sp setup.cfg {[base]lint_targets}
skip_install = True
deps = towncrier>=18.6.0rc1
commands =
- python -m towncrier.check --compare-with=origin/develop
+ python -m towncrier.check --compare-with=origin/dinsic
[testenv:check-sampleconfig]
commands = {toxinidir}/scripts-dev/generate_sample_config --check
@@ -188,5 +189,8 @@ commands=
[testenv:mypy]
deps =
{[base]deps}
+ # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513
+ # TODO: Remove after merging in the fixes from mainline
+ twisted==20.3.0
extras = all,mypy
commands = mypy
|