diff --git a/.buildkite/docker-compose.py35.pg95.yaml b/.buildkite/docker-compose.py35.pg95.yaml
deleted file mode 100644
index c6e8280e65..0000000000
--- a/.buildkite/docker-compose.py35.pg95.yaml
+++ /dev/null
@@ -1,23 +0,0 @@
-version: '3.1'
-
-services:
-
- postgres:
- image: postgres:9.5
- environment:
- POSTGRES_PASSWORD: postgres
- POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
- command: -c fsync=off
-
- testenv:
- image: python:3.5
- depends_on:
- - postgres
- env_file: docker-compose-env
- environment:
- SYNAPSE_POSTGRES_HOST: postgres
- SYNAPSE_POSTGRES_USER: postgres
- SYNAPSE_POSTGRES_PASSWORD: postgres
- working_dir: /src
- volumes:
- - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py37.pg11.yaml b/.buildkite/docker-compose.py37.pg11.yaml
deleted file mode 100644
index 411c37f213..0000000000
--- a/.buildkite/docker-compose.py37.pg11.yaml
+++ /dev/null
@@ -1,23 +0,0 @@
-version: '3.1'
-
-services:
-
- postgres:
- image: postgres:11
- environment:
- POSTGRES_PASSWORD: postgres
- POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
- command: -c fsync=off
-
- testenv:
- image: python:3.7
- depends_on:
- - postgres
- env_file: docker-compose-env
- environment:
- SYNAPSE_POSTGRES_HOST: postgres
- SYNAPSE_POSTGRES_USER: postgres
- SYNAPSE_POSTGRES_PASSWORD: postgres
- working_dir: /src
- volumes:
- - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py37.pg95.yaml b/.buildkite/docker-compose.py37.pg95.yaml
deleted file mode 100644
index 54ca794072..0000000000
--- a/.buildkite/docker-compose.py37.pg95.yaml
+++ /dev/null
@@ -1,23 +0,0 @@
-version: '3.1'
-
-services:
-
- postgres:
- image: postgres:9.5
- environment:
- POSTGRES_PASSWORD: postgres
- POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
- command: -c fsync=off
-
- testenv:
- image: python:3.7
- depends_on:
- - postgres
- env_file: docker-compose-env
- environment:
- SYNAPSE_POSTGRES_HOST: postgres
- SYNAPSE_POSTGRES_USER: postgres
- SYNAPSE_POSTGRES_PASSWORD: postgres
- working_dir: /src
- volumes:
- - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.py38.pg12.yaml b/.buildkite/docker-compose.py38.pg12.yaml
deleted file mode 100644
index 934a34cf02..0000000000
--- a/.buildkite/docker-compose.py38.pg12.yaml
+++ /dev/null
@@ -1,23 +0,0 @@
-version: '3.1'
-
-services:
-
- postgres:
- image: postgres:12
- environment:
- POSTGRES_PASSWORD: postgres
- POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8"
- command: -c fsync=off
-
- testenv:
- image: python:3.8
- depends_on:
- - postgres
- env_file: docker-compose-env
- environment:
- SYNAPSE_POSTGRES_HOST: postgres
- SYNAPSE_POSTGRES_USER: postgres
- SYNAPSE_POSTGRES_PASSWORD: postgres
- working_dir: /src
- volumes:
- - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
diff --git a/.buildkite/docker-compose.sytest.py37.redis.yaml b/.buildkite/docker-compose.sytest.py37.redis.yaml
deleted file mode 100644
index b9e80cc557..0000000000
--- a/.buildkite/docker-compose.sytest.py37.redis.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-version: '3.1'
-
-services:
-
- redis:
- image: redis:5.0
-
- sytest:
- image: matrixdotorg/sytest-synapse:py37
- depends_on:
- - redis
- env_file: docker-compose-env
- environment:
- POSTGRES: "1"
- WORKERS: "1"
- BLACKLIST: "synapse-blacklist-with-workers"
- REDIS: "redis"
- working_dir: "/src"
- entrypoint: ""
- volumes:
- - ${BUILDKITE_BUILD_CHECKOUT_PATH}:/src
- - ${BUILDKITE_BUILD_CHECKOUT_PATH}/logs:/logs
diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 8decfe33ed..3712f92ad9 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -10,6 +10,13 @@ x-yaml-aliases:
apt-get update && apt-get install -y xmlsec1
python -m pip install tox
+ retry: &retry_setup
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
env:
COVERALLS_REPO_TOKEN: wsJWOby6j0uCYFiCes3r0XauxO27mx8lD
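
The `&retry_setup` alias added above is consumed further down via `retry: *retry_setup`, replacing the retry stanza that was previously copy-pasted into every step. A minimal sketch of the pattern (the step label and command here are illustrative only, not part of the pipeline):

  - label: "example step"
    command:
      - "tox -e py36,combine"
    retry: *retry_setup  # expands to the automatic retry rules defined under x-yaml-aliases
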
@@ -19,7 +26,7 @@ steps:
- "python -m pip install tox"
- "tox -e check_codestyle"
plugins:
- - docker#v3.0.1:
+ - docker#v3.7.0:
image: "python:3.6"
mount-buildkite-agent: false
@@ -28,7 +35,7 @@ steps:
- "python -m pip install tox"
- "tox -e packaging"
plugins:
- - docker#v3.0.1:
+ - docker#v3.7.0:
image: "python:3.6"
mount-buildkite-agent: false
@@ -37,7 +44,7 @@ steps:
- "python -m pip install tox"
- "tox -e check_isort"
plugins:
- - docker#v3.0.1:
+ - docker#v3.7.0:
image: "python:3.6"
mount-buildkite-agent: false
@@ -46,7 +53,15 @@ steps:
- "python -m pip install tox"
- "tox -e check-sampleconfig"
plugins:
- - docker#v3.0.1:
+ - docker#v3.7.0:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - label: "\U0001F5A5 check unix line-endings"
+ command:
+ - "scripts-dev/check_line_terminators.sh"
+ plugins:
+ - docker#v3.7.0:
image: "python:3.6"
mount-buildkite-agent: false
@@ -55,202 +70,263 @@ steps:
- "python -m pip install tox"
- "tox -e mypy"
plugins:
- - docker#v3.0.1:
+ - docker#v3.7.0:
+ image: "python:3.7"
+ mount-buildkite-agent: false
+
+ - label: ":package: build distribution files"
+ branches: "release-*"
+ agents:
+ queue: ephemeral-small
+ command:
+ - python setup.py sdist bdist_wheel
+ plugins:
+ - docker#v3.7.0:
image: "python:3.7"
mount-buildkite-agent: false
+ - artifacts#v1.3.0:
+ upload:
+ - dist/*
- wait
################################################################################
#
- # `trial` tests
+ # Twisted `trial` tests
+ #
+ # Our intent is to test:
+ # - All supported Python versions (using SQLite) with current dependencies
+ # - The oldest and newest supported pairings of Python and PostgreSQL
+ #
+ # We also test two special cases:
+ # - The newest supported Python, without any optional dependencies
+ # - The oldest supported Python, with its oldest supported dependency versions
#
################################################################################
- - label: ":python: 3.5 / SQLite / Old Deps"
+ # -- Special Case: Oldest Python w/ Oldest Deps
+
+ # anoa: I've commented this out for DINUM as it was breaking on the 1.31.0 merge
+ # and it was taking way too long to solve. DINUM aren't even using Python 3.6
+ # anyway.
+# - label: ":python: 3.6 (Old Deps)"
+# command:
+# - ".buildkite/scripts/test_old_deps.sh"
+# env:
+# TRIAL_FLAGS: "-j 2"
+# plugins:
+# - docker#v3.7.0:
+# # We use bionic to get an old python (3.6.5) and sqlite (3.22)
+# image: "ubuntu:bionic"
+# workdir: "/src"
+# mount-buildkite-agent: false
+# propagate-environment: true
+# - artifacts#v1.3.0:
+# upload: [ "_trial_temp/*/*.log" ]
+# retry: *retry_setup
+
+ # -- Special Case: Newest Python w/o Optional Deps
+
+ - label: ":python: 3.9 (No Extras)"
command:
- - ".buildkite/scripts/test_old_deps.sh"
+ - *trial_setup
+ - "tox -e py39-noextras,combine"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- - docker#v3.0.1:
- image: "ubuntu:xenial" # We use xenial to get an old sqlite and python
+ - docker#v3.7.0:
+ image: "python:3.9"
workdir: "/src"
mount-buildkite-agent: false
propagate-environment: true
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: ":python: 3.5 / SQLite"
+ retry: *retry_setup
+
+ # -- All Supported Python Versions (SQLite)
+
+ - label: ":python: 3.6"
command:
- *trial_setup
- - "tox -e py35,combine"
+ - "tox -e py36,combine"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- - docker#v3.0.1:
- image: "python:3.5"
+ - docker#v3.7.0:
+ image: "python:3.6"
workdir: "/src"
mount-buildkite-agent: false
propagate-environment: true
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: ":python: 3.6 / SQLite"
+ retry: *retry_setup
+
+ - label: ":python: 3.7"
command:
- *trial_setup
- - "tox -e py36,combine"
+ - "tox -e py37,combine"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- - docker#v3.0.1:
- image: "python:3.6"
+ - docker#v3.7.0:
+ image: "python:3.7"
workdir: "/src"
mount-buildkite-agent: false
propagate-environment: true
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: ":python: 3.7 / SQLite"
+ retry: *retry_setup
+
+ - label: ":python: 3.8"
command:
- *trial_setup
- - "tox -e py37,combine"
+ - "tox -e py38,combine"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- - docker#v3.0.1:
- image: "python:3.7"
+ - docker#v3.7.0:
+ image: "python:3.8"
workdir: "/src"
mount-buildkite-agent: false
propagate-environment: true
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: ":python: 3.5 / :postgres: 9.5"
- agents:
- queue: "medium"
- env:
- TRIAL_FLAGS: "-j 8"
- PYTHON_VERSION: "3.5"
- POSTGRES_VERSION: "9.5"
+ retry: *retry_setup
+
+ - label: ":python: 3.9"
command:
- *trial_setup
- - "python -m tox -e py35-postgres,combine"
+ - "tox -e py39,combine"
+ env:
+ TRIAL_FLAGS: "-j 2"
plugins:
- - docker-compose#v3.7.0:
- run: testenv
- config:
- - .buildkite/docker-compose.yaml
+ - docker#v3.7.0:
+ image: "python:3.9"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
- artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: ":python: 3.7 / :postgres: 11"
+ retry: *retry_setup
+
+ # -- Oldest and Newest Supported Python and Postgres Pairings
+
+ - label: ":python: 3.6 :postgres: 9.6"
agents:
queue: "medium"
env:
TRIAL_FLAGS: "-j 8"
- PYTHON_VERSION: "3.7"
- POSTGRES_VERSION: "11"
+ PYTHON_VERSION: "3.6"
+ POSTGRES_VERSION: "9.6"
command:
- *trial_setup
- - "tox -e py37-postgres,combine"
+ - "python -m tox -e py36-postgres,combine"
plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/pipelines/master/synapse/docker-compose.yaml
+ - https://raw.githubusercontent.com/matrix-org/pipelines/master/synapse/docker-compose-env
- docker-compose#v3.7.0:
run: testenv
config:
- - .buildkite/docker-compose.yaml
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.yaml
- artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: ":python: 3.8 / :postgres: 12"
+ retry: *retry_setup
+
+ - label: ":python: 3.9 :postgres: 13"
agents:
queue: "medium"
env:
TRIAL_FLAGS: "-j 8"
- PYTHON_VERSION: "3.8"
- POSTGRES_VERSION: "12"
+ PYTHON_VERSION: "3.9"
+ POSTGRES_VERSION: "13"
command:
- *trial_setup
- - "tox -e py38-postgres,combine"
+ - "python -m tox -e py39-postgres,combine"
plugins:
+ - matrix-org/download#v1.1.0:
+ urls:
+ - https://raw.githubusercontent.com/matrix-org/pipelines/master/synapse/docker-compose.yaml
+ - https://raw.githubusercontent.com/matrix-org/pipelines/master/synapse/docker-compose-env
- docker-compose#v3.7.0:
run: testenv
config:
- - .buildkite/docker-compose.yaml
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.yaml
- artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
+ retry: *retry_setup
+
+ # -- Experimentally test against PyPy
+ # Only runs when the build message or branch name includes the string "pypy"
+ # This step is allowed to fail
+
+ - label: ":python: PyPy3.6"
+ if: "build.message =~ /pypy/i || build.branch =~ /pypy/i"
+ soft_fail: true
+ command:
+ # No *trial_setup due to docker-library/pypy#52
+ - "apt-get update && apt-get install -y xmlsec1"
+ - "pypy -m pip install tox"
+
+ - "tox -e pypy36,combine"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.7.0:
+ image: "pypy:3.6"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - artifacts#v1.3.0:
+ upload: [ "_trial_temp/*/*.log" ]
+ retry: *retry_setup
+
################################################################################
#
# Sytest
#
+ # Our tests have three dimensions:
+ # 1. Topology (Monolith, Workers, Workers w/ Redis)
+ # 2. Database (SQLite, PostgreSQL)
+ # 3. Python Version
+ #
+ # Tests can run against either a single PostgreSQL database or multiple databases.
+ # This is configured by setting `MULTI_POSTGRES=1` in the environment.
+ #
+ # We mostly care about testing each topology.
+ # To vary the Python and Postgres versions, we use Docker images which are based
+ # on an assortment of Linux distributions.
+ # - "bionic" (Ubuntu 18.04) has Python 3.6 and Postgres 10
+ # - "buster" (Debian 10) has Python 3.7 and Postgres 11
+ # - "testing" (Debian 11) (currently) has Python 3.9 and Postgres 13
+ #
+ # TODO: this leaves us without sytests for Postgres 9.6. How much do we care
+ # about that?
+ #
+ # Our intent is to test:
+ # - Monolith:
+ # - Older Distro + SQLite
+ # - Older Python + Older PostgreSQL
+ # - Newer Python + Newer PostgreSQL
+ # - Workers:
+ # - Older Python + Older PostgreSQL (MULTI_POSTGRES)
+ # - Newer Python + Newer PostgreSQL (MULTI_POSTGRES)
+ # - Workers w/ Redis:
+ # - Newer Python + Newer PostgreSQL
+ #
################################################################################
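
For reference, the topology and database dimensions described above are selected per step through environment variables; the steps below choose among these knobs (no single step sets all of them at once):

  env:
    POSTGRES: "1"        # run against PostgreSQL rather than SQLite
    MULTI_POSTGRES: "1"  # split Synapse's databases across separate PostgreSQL databases
    WORKERS: "1"         # run Synapse in worker mode
    REDIS: "1"           # route inter-worker traffic via Redis (the "Workers w/ Redis" topology)
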
- - label: "SyTest - :python: 3.5 / SQLite / Monolith"
+ - label: "SyTest Monolith :sqlite::ubuntu: 18.04"
agents:
queue: "medium"
command:
- "bash .buildkite/merge_base_branch.sh"
- "bash /bootstrap.sh synapse"
plugins:
- - docker#v3.0.1:
- image: "matrixdotorg/sytest-synapse:dinsic"
+ - docker#v3.7.0:
+ image: "matrixdotorg/sytest-synapse:bionic"
propagate-environment: true
always-pull: true
workdir: "/src"
@@ -259,21 +335,14 @@ steps:
shell: ["-x", "-c"]
mount-buildkite-agent: false
volumes: ["./logs:/logs"]
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
- matrix-org/annotate:
path: "logs/annotate.md"
style: "error"
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: "SyTest - :python: 3.5 / :postgres: 9.6 / Monolith"
+ retry: *retry_setup
+
+ - label: "SyTest Monolith :postgres::ubuntu: 18.04"
agents:
queue: "medium"
env:
@@ -282,8 +351,8 @@ steps:
- "bash .buildkite/merge_base_branch.sh"
- "bash /bootstrap.sh synapse"
plugins:
- - docker#v3.0.1:
- image: "matrixdotorg/sytest-synapse:dinsic"
+ - docker#v3.7.0:
+ image: "matrixdotorg/sytest-synapse:bionic"
propagate-environment: true
always-pull: true
workdir: "/src"
@@ -292,24 +361,43 @@ steps:
shell: ["-x", "-c"]
mount-buildkite-agent: false
volumes: ["./logs:/logs"]
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
- matrix-org/annotate:
path: "logs/annotate.md"
style: "error"
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: "SyTest - :python: 3.5 / :postgres: 9.6 / Workers"
+ retry: *retry_setup
+
+ - label: "SyTest Monolith :postgres::debian: testing"
agents:
queue: "medium"
env:
+ POSTGRES: "1"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /bootstrap.sh synapse"
+ plugins:
+ - docker#v3.7.0:
+ image: "matrixdotorg/sytest-synapse:testing"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: "/bin/sh"
+ init: false
+ shell: ["-x", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.3.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ style: "error"
+ retry: *retry_setup
+
+ - label: "SyTest Workers :postgres::ubuntu: 18.04"
+ agents:
+ queue: "xlarge"
+ env:
MULTI_POSTGRES: "1" # Test with split out databases
POSTGRES: "1"
WORKERS: "1"
@@ -319,8 +407,8 @@ steps:
- "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
- "bash /bootstrap.sh synapse"
plugins:
- - docker#v3.0.1:
- image: "matrixdotorg/sytest-synapse:dinsic"
+ - docker#v3.7.0:
+ image: "matrixdotorg/sytest-synapse:bionic"
propagate-environment: true
always-pull: true
workdir: "/src"
@@ -329,32 +417,28 @@ steps:
shell: ["-x", "-c"]
mount-buildkite-agent: false
volumes: ["./logs:/logs"]
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
- matrix-org/annotate:
path: "logs/annotate.md"
style: "error"
- # - matrix-org/coveralls#v1.0:
- # parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
+ retry: *retry_setup
-
- - label: "SyTest - :python: 3.8 / :postgres: 12 / Monolith"
+ - label: "SyTest Workers :postgres::debian: 10"
agents:
- queue: "medium"
+ queue: "xlarge"
env:
+ MULTI_POSTGRES: "1" # Test with split out databases
POSTGRES: "1"
+ WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
command:
- "bash .buildkite/merge_base_branch.sh"
+ - "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
- "bash /bootstrap.sh synapse"
plugins:
- - docker#v3.0.1:
- image: "matrixdotorg/sytest-synapse:dinsic"
+ - docker#v3.7.0:
+ image: "matrixdotorg/sytest-synapse:buster"
propagate-environment: true
always-pull: true
workdir: "/src"
@@ -363,35 +447,29 @@ steps:
shell: ["-x", "-c"]
mount-buildkite-agent: false
volumes: ["./logs:/logs"]
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
- matrix-org/annotate:
path: "logs/annotate.md"
style: "error"
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
- - label: "SyTest - :python: 3.7 / :postgres: 11 / Workers"
+ retry: *retry_setup
+
+ - label: "SyTest Workers :redis::postgres::debian: 10"
agents:
- queue: "medium"
+ # this one seems to need a lot of memory.
+ queue: "xlarge"
env:
- MULTI_POSTGRES: "1" # Test with split out databases
POSTGRES: "1"
WORKERS: "1"
+ REDIS: "1"
BLACKLIST: "synapse-blacklist-with-workers"
command:
- "bash .buildkite/merge_base_branch.sh"
- "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
- "bash /bootstrap.sh synapse"
plugins:
- - docker#v3.0.1:
- image: "matrixdotorg/sytest-synapse:dinsic"
+ - docker#v3.7.0:
+ image: "matrixdotorg/sytest-synapse:buster"
propagate-environment: true
always-pull: true
workdir: "/src"
@@ -400,92 +478,60 @@ steps:
shell: ["-x", "-c"]
mount-buildkite-agent: false
volumes: ["./logs:/logs"]
- - artifacts#v1.2.0:
+ - artifacts#v1.3.0:
upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
- matrix-org/annotate:
path: "logs/annotate.md"
style: "error"
- # - matrix-org/coveralls#v1.0:
- # parallel: "true"
- retry:
- automatic:
- - exit_status: -1
- limit: 2
- - exit_status: 2
- limit: 2
-
-# TODO: Enable once Synapse v1.13.0 is merged in
-# - label: "SyTest - :python: 3.7 / :postgres: 11 / Workers / :redis: Redis"
-# agents:
-# queue: "medium"
-# command:
-# - bash -c "cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers && ./.buildkite/merge_base_branch.sh && /bootstrap.sh synapse --redis-host redis"
-# plugins:
-# - matrix-org/download#v1.1.0:
-# urls:
-# - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.sytest.py37.redis.yaml
-# - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
-# - docker-compose#v2.1.0:
-# run: sytest
-# config:
-# - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.sytest.py37.redis.yaml
-# - artifacts#v1.2.0:
-# upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/results.tap" ]
-# - matrix-org/annotate:
-# path: "logs/annotate.md"
-# style: "error"
-## - matrix-org/coveralls#v1.0:
-## parallel: "true"
-# retry:
-# automatic:
-# - exit_status: -1
-# limit: 2
-# - exit_status: 2
-# limit: 2
+ retry: *retry_setup
################################################################################
#
# synapse_port_db
#
+ # Tests the oldest and newest supported pairings of Python and PostgreSQL
+ #
################################################################################
- - label: "synapse_port_db / :python: 3.5 / :postgres: 9.5"
+ - label: "Port DB :python: 3.6 :postgres: 9.6"
agents:
queue: "medium"
+ env:
+ PYTHON_VERSION: "3.6"
+ POSTGRES_VERSION: "9.6"
command:
- "bash .buildkite/scripts/test_synapse_port_db.sh"
plugins:
- matrix-org/download#v1.1.0:
urls:
- - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py35.pg95.yaml
- - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/anoa/dinsic_release_1_31_0/.buildkite/docker-compose.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/anoa/dinsic_release_1_31_0/.buildkite/docker-compose-env
- docker-compose#v2.1.0:
run: testenv
config:
- - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py35.pg95.yaml
- - artifacts#v1.2.0:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.yaml
+ - artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
- - label: "synapse_port_db / :python: 3.7 / :postgres: 11"
+ - label: "Port DB :python: 3.9 :postgres: 13"
agents:
queue: "medium"
+ env:
+ PYTHON_VERSION: "3.9"
+ POSTGRES_VERSION: "13"
command:
- "bash .buildkite/scripts/test_synapse_port_db.sh"
plugins:
- matrix-org/download#v1.1.0:
urls:
- - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose.py37.pg11.yaml
- - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/dinsic/.buildkite/docker-compose-env
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/anoa/dinsic_release_1_31_0/.buildkite/docker-compose.yaml
+ - https://raw.githubusercontent.com/matrix-org/synapse-dinsic/anoa/dinsic_release_1_31_0/.buildkite/docker-compose-env
- docker-compose#v2.1.0:
run: testenv
config:
- - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.py37.pg11.yaml
- - artifacts#v1.2.0:
+ - /tmp/download-${BUILDKITE_BUILD_ID}/docker-compose.yaml
+ - artifacts#v1.3.0:
upload: [ "_trial_temp/*/*.log" ]
-# - matrix-org/coveralls#v1.0:
-# parallel: "true"
# - wait: ~
# continue_on_failure: true
@@ -504,6 +550,10 @@ steps:
- "docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile ."
# We use the complement:latest image to provide Complement's dependencies, but want
# to actually run against the latest version of Complement, so download it here.
+ # NOTE: We use the `anoa/knock_room_v7` branch here while knocking is still experimental on mainline.
+ # This branch uses the stable identifiers for all knock-related room state, so that things don't
+ # clash if rooms created on dinsic's Synapse later federate with mainline Synapse instances.
- "wget https://github.com/matrix-org/complement/archive/anoa/knock_room_v7.tar.gz"
- "tar -xzf knock_room_v7.tar.gz"
# Build a second docker image on top of the above image. This one sets up Synapse with a generated config file,
diff --git a/.buildkite/scripts/create_postgres_db.py b/.buildkite/scripts/create_postgres_db.py
index df6082b0ac..956339de5c 100755
--- a/.buildkite/scripts/create_postgres_db.py
+++ b/.buildkite/scripts/create_postgres_db.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+
from synapse.storage.engines import create_engine
logger = logging.getLogger("create_postgres_db")
diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh
index 28e6694b5d..3753f41a40 100755
--- a/.buildkite/scripts/test_old_deps.sh
+++ b/.buildkite/scripts/test_old_deps.sh
@@ -1,16 +1,16 @@
-#!/bin/bash
+#!/usr/bin/env bash
# this script is run by buildkite in a plain `xenial` container; it installs the
-# minimal requirements for tox and hands over to the py35-old tox environment.
+# minimal requirements for tox and hands over to the py3-old tox environment.
set -ex
apt-get update
-apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
+apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
export LANG="C.UTF-8"
# Prevent virtualenv from auto-updating pip to an incompatible version
export VIRTUALENV_NO_DOWNLOAD=1
-exec tox -e py35-old,combine
+exec tox -e py3-old,combine
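
To reproduce this old-deps run locally, something along these lines should work, mirroring the (currently commented-out) bionic step in the pipeline; the exact invocation is an illustration, not part of the CI config:

  docker run --rm -v "$PWD:/src" -w /src ubuntu:bionic \
      bash .buildkite/scripts/test_old_deps.sh
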
diff --git a/.buildkite/scripts/test_synapse_port_db.sh b/.buildkite/scripts/test_synapse_port_db.sh
index 9ed2177635..8914319e38 100755
--- a/.buildkite/scripts/test_synapse_port_db.sh
+++ b/.buildkite/scripts/test_synapse_port_db.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Test script for 'synapse_port_db', which creates a virtualenv, installs Synapse along
# with additional dependencies needed for the test (such as coverage or the PostgreSQL
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 088da55735..1ac48a71ba 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -7,12 +7,14 @@ jobs:
- checkout
- docker_prepare
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
+ # for release builds, we want to get the amd64 image out asap, so first
+ # we do an amd64-only build, before following up with a multiarch build.
- docker_build:
tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
platforms: linux/amd64
- docker_build:
tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
- platforms: linux/amd64,linux/arm/v7,linux/arm64
+ platforms: linux/amd64,linux/arm64
dockerhubuploadlatest:
docker:
@@ -21,12 +23,11 @@ jobs:
- checkout
- docker_prepare
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
+ # for `latest`, we don't want the arm images to disappear, so don't update the tag
+ # until all of the platforms are built.
- docker_build:
tag: -t matrixdotorg/synapse:latest
- platforms: linux/amd64
- - docker_build:
- tag: -t matrixdotorg/synapse:latest
- platforms: linux/amd64,linux/arm/v7,linux/arm64
+ platforms: linux/amd64,linux/arm64
workflows:
build:
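
Assuming the `docker_build` command here wraps Docker Buildx (the usual mechanism for multi-platform image builds), the release job's multiarch step corresponds roughly to:

  docker buildx build --platform linux/amd64,linux/arm64 \
      -t matrixdotorg/synapse:${CIRCLE_TAG} -f docker/Dockerfile --push .

with the linux/arm/v7 platform now dropped, as noted in the 1.28.0 changelog entry further below.
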
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 0000000000..83ddd568c2
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,8 @@
+# Black reformatting (#5482).
+32e7c9e7f20b57dd081023ac42d6931a8da9b3a3
+
+# Target Python 3.5 with black (#8664).
+aff1eb7c671b0a3813407321d2702ec46c71fa56
+
+# Update black to 20.8b1 (#9381).
+0a00b7ff14890987f09112a2ae696c61001e6cf1
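
Git only consults this file when told to; to have `git blame` skip the listed reformatting commits, either pass the flag per invocation or set it once per clone:

  git blame --ignore-revs-file .git-blame-ignore-revs <path>
  # or, once per clone:
  git config blame.ignoreRevsFile .git-blame-ignore-revs
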
diff --git a/.gitignore b/.gitignore
index 9bb5bdd647..295a18b539 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,16 +6,19 @@
*.egg
*.egg-info
*.lock
-*.pyc
+*.py[cod]
*.snap
*.tac
_trial_temp/
_trial_temp*/
/out
+.DS_Store
+__pycache__/
# stuff that is likely to exist when you run a server locally
/*.db
/*.log
+/*.log.*
/*.log.config
/*.pid
/.python-version
diff --git a/CHANGES.md b/CHANGES.md
index 0cecb83498..27483532d0 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,720 @@
+Synapse 1.31.0 (2021-04-06)
+===========================
+
+**Note:** As announced in v1.25.0, and in line with the deprecation policy for platform dependencies, this is the last release to support Python 3.5 and PostgreSQL 9.5. Future versions of Synapse will require Python 3.6+ and PostgreSQL 9.6+, as per our [deprecation policy](docs/deprecation_policy.md).
+
+This is also the last release for which the Synapse team will publish packages for Debian Stretch and Ubuntu Xenial.
+
+
+Improved Documentation
+----------------------
+
+- Add a document describing the deprecation policy for platform dependencies. ([\#9723](https://github.com/matrix-org/synapse/issues/9723))
+
+
+Internal Changes
+----------------
+
+- Revert using `dmypy run` in lint script. ([\#9720](https://github.com/matrix-org/synapse/issues/9720))
+- Pin flake8-bugbear's version. ([\#9734](https://github.com/matrix-org/synapse/issues/9734))
+
+
+Synapse 1.31.0rc1 (2021-03-30)
+==============================
+
+Features
+--------
+
+- Add support to OpenID Connect login for requiring attributes on the `userinfo` response. Contributed by Hubbe King. ([\#9609](https://github.com/matrix-org/synapse/issues/9609))
+- Add initial experimental support for a "space summary" API. ([\#9643](https://github.com/matrix-org/synapse/issues/9643), [\#9652](https://github.com/matrix-org/synapse/issues/9652), [\#9653](https://github.com/matrix-org/synapse/issues/9653))
+- Add support for the busy presence state as described in [MSC3026](https://github.com/matrix-org/matrix-doc/pull/3026). ([\#9644](https://github.com/matrix-org/synapse/issues/9644))
+- Add support for credentials for proxy authentication in the `HTTPS_PROXY` environment variable. ([\#9657](https://github.com/matrix-org/synapse/issues/9657))
+
+
+Bugfixes
+--------
+
+- Fix a longstanding bug that could cause issues when editing a reply to a message. ([\#9585](https://github.com/matrix-org/synapse/issues/9585))
+- Fix the `/capabilities` endpoint to return `m.change_password` as disabled if the local password database is not used for authentication. Contributed by @dklimpel. ([\#9588](https://github.com/matrix-org/synapse/issues/9588))
+- Check if local passwords are enabled before setting them for the user. ([\#9636](https://github.com/matrix-org/synapse/issues/9636))
+- Fix a bug where federation sending can stall due to `concurrent access` database exceptions when it falls behind. ([\#9639](https://github.com/matrix-org/synapse/issues/9639))
+- Fix a bug introduced in Synapse 1.30.1 which meant the suggested `pip` incantation to install an updated `cryptography` was incorrect. ([\#9699](https://github.com/matrix-org/synapse/issues/9699))
+
+
+Updates to the Docker image
+---------------------------
+
+- Speed up Docker builds and make it nicer to test against Complement while developing (install all dependencies before copying the project). ([\#9610](https://github.com/matrix-org/synapse/issues/9610))
+- Include [opencontainers labels](https://github.com/opencontainers/image-spec/blob/master/annotations.md#pre-defined-annotation-keys) in the Docker image. ([\#9612](https://github.com/matrix-org/synapse/issues/9612))
+
+
+Improved Documentation
+----------------------
+
+- Clarify that `register_new_matrix_user` is present also when installed via non-pip package. ([\#9074](https://github.com/matrix-org/synapse/issues/9074))
+- Update source install documentation to mention platform prerequisites before the source install steps. ([\#9667](https://github.com/matrix-org/synapse/issues/9667))
+- Improve worker documentation for fallback/web auth endpoints. ([\#9679](https://github.com/matrix-org/synapse/issues/9679))
+- Update the sample configuration for OIDC authentication. ([\#9695](https://github.com/matrix-org/synapse/issues/9695))
+
+
+Internal Changes
+----------------
+
+- Preparatory steps for removing redundant `outlier` data from `event_json.internal_metadata` column. ([\#9411](https://github.com/matrix-org/synapse/issues/9411))
+- Add type hints to the caching module. ([\#9442](https://github.com/matrix-org/synapse/issues/9442))
+- Introduce flake8-bugbear to the test suite and fix some of its lint violations. ([\#9499](https://github.com/matrix-org/synapse/issues/9499), [\#9659](https://github.com/matrix-org/synapse/issues/9659))
+- Add additional type hints to the Homeserver object. ([\#9631](https://github.com/matrix-org/synapse/issues/9631), [\#9638](https://github.com/matrix-org/synapse/issues/9638), [\#9675](https://github.com/matrix-org/synapse/issues/9675), [\#9681](https://github.com/matrix-org/synapse/issues/9681))
+- Only save remote cross-signing and device keys if they're different from the current ones. ([\#9634](https://github.com/matrix-org/synapse/issues/9634))
+- Rename storage function to fix spelling and not conflict with another function's name. ([\#9637](https://github.com/matrix-org/synapse/issues/9637))
+- Improve performance of federation catch up by sending the latest events in the room to the remote, rather than just the last event sent by the local server. ([\#9640](https://github.com/matrix-org/synapse/issues/9640), [\#9664](https://github.com/matrix-org/synapse/issues/9664))
+- In the `federation_client` commandline client, stop automatically adding the URL prefix, so that servlets on other prefixes can be tested. ([\#9645](https://github.com/matrix-org/synapse/issues/9645))
+- In the `federation_client` commandline client, handle inline `signing_key`s in `homeserver.yaml`. ([\#9647](https://github.com/matrix-org/synapse/issues/9647))
+- Fixed some antipattern issues to improve code quality. ([\#9649](https://github.com/matrix-org/synapse/issues/9649))
+- Add a storage method for pulling all current user presence state from the database. ([\#9650](https://github.com/matrix-org/synapse/issues/9650))
+- Import `HomeServer` from the proper module. ([\#9665](https://github.com/matrix-org/synapse/issues/9665))
+- Increase default join ratelimiting burst rate. ([\#9674](https://github.com/matrix-org/synapse/issues/9674))
+- Add type hints to third party event rules and visibility modules. ([\#9676](https://github.com/matrix-org/synapse/issues/9676))
+- Bump mypy-zope to 0.2.13 to fix "Cannot determine consistent method resolution order (MRO)" errors when running mypy a second time. ([\#9678](https://github.com/matrix-org/synapse/issues/9678))
+- Use interpreter from `$PATH` via `/usr/bin/env` instead of absolute paths in various scripts. ([\#9689](https://github.com/matrix-org/synapse/issues/9689))
+- Make it possible to use `dmypy`. ([\#9692](https://github.com/matrix-org/synapse/issues/9692))
+- Suppress "CryptographyDeprecationWarning: int_from_bytes is deprecated". ([\#9698](https://github.com/matrix-org/synapse/issues/9698))
+- Use `dmypy run` in lint script for improved performance in type-checking while developing. ([\#9701](https://github.com/matrix-org/synapse/issues/9701))
+- Fix undetected mypy error when using Python 3.6. ([\#9703](https://github.com/matrix-org/synapse/issues/9703))
+- Fix type-checking CI on develop. ([\#9709](https://github.com/matrix-org/synapse/issues/9709))
+
+
+Synapse 1.30.1 (2021-03-26)
+===========================
+
+This release is identical to Synapse 1.30.0, with the exception of explicitly
+setting a minimum version of Python's Cryptography library to ensure that users
+of Synapse are protected from the recent [OpenSSL security advisories](https://mta.openssl.org/pipermail/openssl-announce/2021-March/000198.html),
+especially CVE-2021-3449.
+
+Note that Cryptography defaults to bundling its own statically linked copy of
+OpenSSL, which means that you may not be protected by your operating system's
+security updates.
+
+It's also worth noting that Cryptography no longer supports Python 3.5, so
+admins deploying to older environments may not be protected against this or
+future vulnerabilities. Synapse will be dropping support for Python 3.5 at the
+end of March.
+
+
+Updates to the Docker image
+---------------------------
+
+- Ensure that the docker container has up to date versions of openssl. ([\#9697](https://github.com/matrix-org/synapse/issues/9697))
+
+
+Internal Changes
+----------------
+
+- Enforce that `cryptography` dependency is up to date to ensure it has the most recent openssl patches. ([\#9697](https://github.com/matrix-org/synapse/issues/9697))
+
+
+Synapse 1.30.0 (2021-03-22)
+===========================
+
+Note that this release deprecates the ability for appservices to
+call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice
+developers should use a `type` value of `m.login.application_service` as
+per [the spec](https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions).
+In future releases, calling this endpoint with an access token - but without a `m.login.application_service`
+type - will fail.
+
+
+No significant changes.
+
+
+Synapse 1.30.0rc1 (2021-03-16)
+==============================
+
+Features
+--------
+
+- Add prometheus metrics for number of users successfully registering and logging in. ([\#9510](https://github.com/matrix-org/synapse/issues/9510), [\#9511](https://github.com/matrix-org/synapse/issues/9511), [\#9573](https://github.com/matrix-org/synapse/issues/9573))
+- Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
+- Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. ([\#9549](https://github.com/matrix-org/synapse/issues/9549))
+- Optimise handling of incomplete room history for incoming federation. ([\#9601](https://github.com/matrix-org/synapse/issues/9601))
+- Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9617](https://github.com/matrix-org/synapse/issues/9617))
+- Tell spam checker modules about the SSO IdP a user registered through if one was used. ([\#9626](https://github.com/matrix-org/synapse/issues/9626))
+
+
+Bugfixes
+--------
+
+- Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. ([\#9473](https://github.com/matrix-org/synapse/issues/9473))
+- Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. ([\#9542](https://github.com/matrix-org/synapse/issues/9542), [\#9583](https://github.com/matrix-org/synapse/issues/9583))
+- Fix bug where federation requests were not correctly retried on 5xx responses. ([\#9567](https://github.com/matrix-org/synapse/issues/9567))
+- Fix re-activating an account via the admin API when local passwords are disabled. ([\#9587](https://github.com/matrix-org/synapse/issues/9587))
+- Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. ([\#9597](https://github.com/matrix-org/synapse/issues/9597))
+- Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. ([\#9620](https://github.com/matrix-org/synapse/issues/9620))
+- Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. ([\#9623](https://github.com/matrix-org/synapse/issues/9623))
+
+
+Updates to the Docker image
+---------------------------
+
+- Make use of an improved malloc implementation (`jemalloc`) in the docker image. ([\#8553](https://github.com/matrix-org/synapse/issues/8553))
+
+
+Improved Documentation
+----------------------
+
+- Add relayd entry to reverse proxy example configurations. ([\#9508](https://github.com/matrix-org/synapse/issues/9508))
+- Improve the SAML2 upgrade notes for 1.27.0. ([\#9550](https://github.com/matrix-org/synapse/issues/9550))
+- Link to the "List user's media" admin API from the media admin API docs. ([\#9571](https://github.com/matrix-org/synapse/issues/9571))
+- Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. ([\#9580](https://github.com/matrix-org/synapse/issues/9580))
+- Clarify the sample configuration for `stats` settings. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
+
+
+Deprecations and Removals
+-------------------------
+
+- The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
+- Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. ([\#9559](https://github.com/matrix-org/synapse/issues/9559))
+
+
+Internal Changes
+----------------
+
+- Add tests to ResponseCache. ([\#9458](https://github.com/matrix-org/synapse/issues/9458))
+- Add type hints to purge room and server notice admin API. ([\#9520](https://github.com/matrix-org/synapse/issues/9520))
+- Add extra logging to ObservableDeferred when callbacks throw exceptions. ([\#9523](https://github.com/matrix-org/synapse/issues/9523))
+- Fix incorrect type hints. ([\#9528](https://github.com/matrix-org/synapse/issues/9528), [\#9543](https://github.com/matrix-org/synapse/issues/9543), [\#9591](https://github.com/matrix-org/synapse/issues/9591), [\#9608](https://github.com/matrix-org/synapse/issues/9608), [\#9618](https://github.com/matrix-org/synapse/issues/9618))
+- Add an additional test for purging a room. ([\#9541](https://github.com/matrix-org/synapse/issues/9541))
+- Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. ([\#9560](https://github.com/matrix-org/synapse/issues/9560))
+- Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. ([\#9561](https://github.com/matrix-org/synapse/issues/9561))
+- Fix spurious errors reported by the `config-lint.sh` script. ([\#9562](https://github.com/matrix-org/synapse/issues/9562))
+- Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. ([\#9563](https://github.com/matrix-org/synapse/issues/9563))
+- Do not have mypy ignore type hints from unpaddedbase64. ([\#9568](https://github.com/matrix-org/synapse/issues/9568))
+- Improve efficiency of calculating the auth chain in large rooms. ([\#9576](https://github.com/matrix-org/synapse/issues/9576))
+- Convert `synapse.types.Requester` to an `attrs` class. ([\#9586](https://github.com/matrix-org/synapse/issues/9586))
+- Add logging for redis connection setup. ([\#9590](https://github.com/matrix-org/synapse/issues/9590))
+- Improve logging when processing incoming transactions. ([\#9596](https://github.com/matrix-org/synapse/issues/9596))
+- Remove unused `stats.retention` setting, and emit a warning if stats are disabled. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
+- Prevent attempting to bundle aggregations for state events in /context APIs. ([\#9619](https://github.com/matrix-org/synapse/issues/9619))
+
+
+Synapse 1.29.0 (2021-03-08)
+===========================
+
+Note that synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change.
+
+
+No significant changes.
+
+
+Synapse 1.29.0rc1 (2021-03-04)
+==============================
+
+Features
+--------
+
+- Add rate limiters to cross-user key sharing requests. ([\#8957](https://github.com/matrix-org/synapse/issues/8957))
+- Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel. ([\#8978](https://github.com/matrix-org/synapse/issues/8978))
+- Add some configuration settings to make users' profile data more private. ([\#9203](https://github.com/matrix-org/synapse/issues/9203))
+- The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients on top of existing support for the uppercase `HTTPS_PROXY` form and takes precedence if both are present. Contributed by Timothy Leung. ([\#9372](https://github.com/matrix-org/synapse/issues/9372))
+- Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. ([\#9383](https://github.com/matrix-org/synapse/issues/9383), [\#9385](https://github.com/matrix-org/synapse/issues/9385))
+- Add support for regenerating thumbnails if they have been deleted but the original image is still stored. ([\#9438](https://github.com/matrix-org/synapse/issues/9438))
+- Add support for `X-Forwarded-Proto` header when using a reverse proxy. ([\#9472](https://github.com/matrix-org/synapse/issues/9472), [\#9501](https://github.com/matrix-org/synapse/issues/9501), [\#9512](https://github.com/matrix-org/synapse/issues/9512), [\#9539](https://github.com/matrix-org/synapse/issues/9539))
+
+
+Bugfixes
+--------
+
+- Fix a bug where users' pushers were not all deleted when they deactivated their account. ([\#9285](https://github.com/matrix-org/synapse/issues/9285), [\#9516](https://github.com/matrix-org/synapse/issues/9516))
+- Fix a bug where a lot of unnecessary presence updates were sent when joining a room. ([\#9402](https://github.com/matrix-org/synapse/issues/9402))
+- Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results. ([\#9416](https://github.com/matrix-org/synapse/issues/9416))
+- Fix a bug in single sign-on which could cause a "No session cookie found" error. ([\#9436](https://github.com/matrix-org/synapse/issues/9436))
+- Fix bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined. ([\#9440](https://github.com/matrix-org/synapse/issues/9440))
+- Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`. ([\#9449](https://github.com/matrix-org/synapse/issues/9449))
+- Fix deleting pushers when using sharded pushers. ([\#9465](https://github.com/matrix-org/synapse/issues/9465), [\#9466](https://github.com/matrix-org/synapse/issues/9466), [\#9479](https://github.com/matrix-org/synapse/issues/9479), [\#9536](https://github.com/matrix-org/synapse/issues/9536))
+- Fix missing startup checks for the consistency of certain PostgreSQL sequences. ([\#9470](https://github.com/matrix-org/synapse/issues/9470))
+- Fix a long-standing bug where the media repository could leak file descriptors while previewing media. ([\#9497](https://github.com/matrix-org/synapse/issues/9497))
+- Properly purge the event chain cover index when purging history. ([\#9498](https://github.com/matrix-org/synapse/issues/9498))
+- Fix missing chain cover index due to a schema delta not being applied correctly. Only affected servers that ran development versions. ([\#9503](https://github.com/matrix-org/synapse/issues/9503))
+- Fix a bug introduced in v1.25.0 where `/_synapse/admin/join/` would fail when given a room alias. ([\#9506](https://github.com/matrix-org/synapse/issues/9506))
+- Prevent presence background jobs from running when presence is disabled. ([\#9530](https://github.com/matrix-org/synapse/issues/9530))
+- Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events. ([\#9537](https://github.com/matrix-org/synapse/issues/9537))
+
+
+Improved Documentation
+----------------------
+
+- Update the example systemd config to propagate reloads to individual units. ([\#9463](https://github.com/matrix-org/synapse/issues/9463))
+
+
+Internal Changes
+----------------
+
+- Add documentation and type hints to `parse_duration`. ([\#9432](https://github.com/matrix-org/synapse/issues/9432))
+- Remove vestiges of `uploads_path` configuration setting. ([\#9462](https://github.com/matrix-org/synapse/issues/9462))
+- Add a comment about systemd-python. ([\#9464](https://github.com/matrix-org/synapse/issues/9464))
+- Test that we require validated email for email pushers. ([\#9496](https://github.com/matrix-org/synapse/issues/9496))
+- Allow python to generate bytecode for synapse. ([\#9502](https://github.com/matrix-org/synapse/issues/9502))
+- Fix incorrect type hints. ([\#9515](https://github.com/matrix-org/synapse/issues/9515), [\#9518](https://github.com/matrix-org/synapse/issues/9518))
+- Add type hints to device and event report admin API. ([\#9519](https://github.com/matrix-org/synapse/issues/9519))
+- Add type hints to user admin API. ([\#9521](https://github.com/matrix-org/synapse/issues/9521))
+- Bump the versions of mypy and mypy-zope used for static type checking. ([\#9529](https://github.com/matrix-org/synapse/issues/9529))
+
+
+Synapse 1.28.0 (2021-02-25)
+===========================
+
+Note that this release drops support for ARMv7 in the official Docker images, due to repeated problems building for ARMv7 (and the associated maintenance burden this entails).
+
+This release also fixes the documentation included in v1.27.0 around the callback URI for SAML2 identity providers. If your server is configured to use single sign-on via a SAML2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
+
+
+Internal Changes
+----------------
+
+- Revert change in v1.28.0rc1 to remove the deprecated SAML endpoint. ([\#9474](https://github.com/matrix-org/synapse/issues/9474))
+
+
+Synapse 1.28.0rc1 (2021-02-19)
+==============================
+
+Removal warning
+---------------
+
+The v1 list accounts API is deprecated and will be removed in a future release.
+This API was undocumented and misleading. It can be replaced by the
+[v2 list accounts API](https://github.com/matrix-org/synapse/blob/release-v1.28.0/docs/admin_api/user_admin_api.rst#list-accounts),
+which has been available since Synapse 1.7.0 (2019-12-13).
+
+Please check if you're using any scripts which use the admin API and replace
+`GET /_synapse/admin/v1/users/<user_id>` with `GET /_synapse/admin/v2/users`.
+
+
+Features
+--------
+
+- New admin API to get the context of an event: `/_synapse/admin/rooms/{roomId}/context/{eventId}`. ([\#9150](https://github.com/matrix-org/synapse/issues/9150))
+- Further improvements to the user experience of registration via single sign-on. ([\#9300](https://github.com/matrix-org/synapse/issues/9300), [\#9301](https://github.com/matrix-org/synapse/issues/9301))
+- Add hook to spam checker modules that allow checking file uploads and remote downloads. ([\#9311](https://github.com/matrix-org/synapse/issues/9311))
+- Add support for receiving OpenID Connect authentication responses via form `POST`s rather than `GET`s. ([\#9376](https://github.com/matrix-org/synapse/issues/9376))
+- Add the shadow-banning status to the admin API for user info. ([\#9400](https://github.com/matrix-org/synapse/issues/9400))
+
+
+Bugfixes
+--------
+
+- Fix long-standing bug where sending email notifications would fail for rooms that the server had since left. ([\#9257](https://github.com/matrix-org/synapse/issues/9257))
+- Fix bug introduced in Synapse 1.27.0rc1 which meant the "session expired" error page during SSO registration was badly formatted. ([\#9296](https://github.com/matrix-org/synapse/issues/9296))
+- Assert a maximum length for some parameters for spec compliance. ([\#9321](https://github.com/matrix-org/synapse/issues/9321), [\#9393](https://github.com/matrix-org/synapse/issues/9393))
+- Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.". ([\#9333](https://github.com/matrix-org/synapse/issues/9333))
+- Fix a bug causing Synapse to impose the wrong type constraints on fields when processing responses from appservices to `/_matrix/app/v1/thirdparty/user/{protocol}`. ([\#9361](https://github.com/matrix-org/synapse/issues/9361))
+- Fix bug where Synapse would occasionally stop reconnecting to Redis after the connection was lost. ([\#9391](https://github.com/matrix-org/synapse/issues/9391))
+- Fix a long-standing bug when upgrading a room: "TypeError: '>' not supported between instances of 'NoneType' and 'int'". ([\#9395](https://github.com/matrix-org/synapse/issues/9395))
+- Reduce the amount of memory used when generating the URL preview of a file that is larger than the `max_spider_size`. ([\#9421](https://github.com/matrix-org/synapse/issues/9421))
+- Fix a long-standing bug in the deduplication of old presence, resulting in no deduplication. ([\#9425](https://github.com/matrix-org/synapse/issues/9425))
+- The `ui_auth.session_timeout` config option can now be specified in terms of a number of seconds/minutes/etc. Contributed by Rishabh Arya. ([\#9426](https://github.com/matrix-org/synapse/issues/9426))
+- Fix a bug introduced in v1.27.0: "TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'." related to the user directory. ([\#9428](https://github.com/matrix-org/synapse/issues/9428))
+
+
+Updates to the Docker image
+---------------------------
+
+- Drop support for ARMv7 in Docker images. ([\#9433](https://github.com/matrix-org/synapse/issues/9433))
+
+
+Improved Documentation
+----------------------
+
+- Reorganize CHANGELOG.md. ([\#9281](https://github.com/matrix-org/synapse/issues/9281))
+- Add note to `auto_join_rooms` config option explaining existing rooms must be publicly joinable. ([\#9291](https://github.com/matrix-org/synapse/issues/9291))
+- Correct name of Synapse's service file in TURN howto. ([\#9308](https://github.com/matrix-org/synapse/issues/9308))
+- Fix the braces in the `oidc_providers` section of the sample config. ([\#9317](https://github.com/matrix-org/synapse/issues/9317))
+- Update installation instructions on Fedora. ([\#9322](https://github.com/matrix-org/synapse/issues/9322))
+- Add HTTP/2 support to the nginx example configuration. Contributed by David Vo. ([\#9390](https://github.com/matrix-org/synapse/issues/9390))
+- Update docs for using Gitea as OpenID provider. ([\#9404](https://github.com/matrix-org/synapse/issues/9404))
+- Document that pusher instances are shardable. ([\#9407](https://github.com/matrix-org/synapse/issues/9407))
+- Fix erroneous documentation from v1.27.0 about updating the SAML2 callback URL. ([\#9434](https://github.com/matrix-org/synapse/issues/9434))
+
+
+Deprecations and Removals
+-------------------------
+
+- Deprecate old admin API `GET /_synapse/admin/v1/users/<user_id>`. ([\#9429](https://github.com/matrix-org/synapse/issues/9429))
+
+
+Internal Changes
+----------------
+
+- Fix 'object name reserved for internal use' errors with recent versions of SQLite. ([\#9003](https://github.com/matrix-org/synapse/issues/9003))
+- Add experimental support for running Synapse with PyPy. ([\#9123](https://github.com/matrix-org/synapse/issues/9123))
+- Deny access to additional IP addresses by default. ([\#9240](https://github.com/matrix-org/synapse/issues/9240))
+- Update the `Cursor` type hints to better match PEP 249. ([\#9299](https://github.com/matrix-org/synapse/issues/9299))
+- Add debug logging for SRV lookups. Contributed by @Bubu. ([\#9305](https://github.com/matrix-org/synapse/issues/9305))
+- Improve logging for OIDC login flow. ([\#9307](https://github.com/matrix-org/synapse/issues/9307))
+- Share the code for handling required attributes between the CAS and SAML handlers. ([\#9326](https://github.com/matrix-org/synapse/issues/9326))
+- Clean up the code to load the metadata for OpenID Connect identity providers. ([\#9362](https://github.com/matrix-org/synapse/issues/9362))
+- Convert tests to use `HomeserverTestCase`. ([\#9377](https://github.com/matrix-org/synapse/issues/9377), [\#9396](https://github.com/matrix-org/synapse/issues/9396))
+- Update the version of black used to 20.8b1. ([\#9381](https://github.com/matrix-org/synapse/issues/9381))
+- Allow OIDC config to override discovered values. ([\#9384](https://github.com/matrix-org/synapse/issues/9384))
+- Remove some dead code from the acceptance of room invites path. ([\#9394](https://github.com/matrix-org/synapse/issues/9394))
+- Clean up an unused method in the presence handler code. ([\#9408](https://github.com/matrix-org/synapse/issues/9408))
+
+
+Synapse 1.27.0 (2021-02-16)
+===========================
+
+Note that this release includes a change in Synapse to use Redis as a cache ─ as well as a pub/sub mechanism ─ if Redis support is enabled for workers. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically.
+
+This release also changes the callback URI for OpenID Connect (OIDC) and SAML2 identity providers. If your server is configured to use single sign-on via an OIDC/OAuth2 or SAML2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
+
+This release also changes escaping of variables in the HTML templates for SSO or email notifications. If you have customised these templates, please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
+
+
+Bugfixes
+--------
+
+- Fix building Docker images for armv7. ([\#9405](https://github.com/matrix-org/synapse/issues/9405))
+
+
+Synapse 1.27.0rc2 (2021-02-11)
+==============================
+
+Features
+--------
+
+- Further improvements to the user experience of registration via single sign-on. ([\#9297](https://github.com/matrix-org/synapse/issues/9297))
+
+
+Bugfixes
+--------
+
+- Fix ratelimiting introduced in v1.27.0rc1 for invites to respect the `ratelimit` flag on application services. ([\#9302](https://github.com/matrix-org/synapse/issues/9302))
+- Do not automatically calculate `public_baseurl` since it can be wrong in some situations. Reverts behaviour introduced in v1.26.0. ([\#9313](https://github.com/matrix-org/synapse/issues/9313))
+
+
+Improved Documentation
+----------------------
+
+- Clarify the sample configuration for changes made to the template loading code. ([\#9310](https://github.com/matrix-org/synapse/issues/9310))
+
+
+Synapse 1.27.0rc1 (2021-02-02)
+==============================
+
+Features
+--------
+
+- Add an admin API for getting and deleting forward extremities for a room. ([\#9062](https://github.com/matrix-org/synapse/issues/9062))
+- Add an admin API for retrieving the current room state of a room. ([\#9168](https://github.com/matrix-org/synapse/issues/9168))
+- Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9183](https://github.com/matrix-org/synapse/issues/9183), [\#9242](https://github.com/matrix-org/synapse/issues/9242))
+- Add an admin API endpoint for shadow-banning users. ([\#9209](https://github.com/matrix-org/synapse/issues/9209))
+- Add ratelimits to the 3PID `/requestToken` APIs. ([\#9238](https://github.com/matrix-org/synapse/issues/9238))
+- Add support to the OpenID Connect integration for adding the user's email address. ([\#9245](https://github.com/matrix-org/synapse/issues/9245))
+- Add ratelimits to invites in rooms and to specific users. ([\#9258](https://github.com/matrix-org/synapse/issues/9258))
+- Improve the user experience of setting up an account via single-sign on. ([\#9262](https://github.com/matrix-org/synapse/issues/9262), [\#9272](https://github.com/matrix-org/synapse/issues/9272), [\#9275](https://github.com/matrix-org/synapse/issues/9275), [\#9276](https://github.com/matrix-org/synapse/issues/9276), [\#9277](https://github.com/matrix-org/synapse/issues/9277), [\#9286](https://github.com/matrix-org/synapse/issues/9286), [\#9287](https://github.com/matrix-org/synapse/issues/9287))
+- Add phone home stats for encrypted messages. ([\#9283](https://github.com/matrix-org/synapse/issues/9283))
+- Update the redirect URI for OIDC authentication. ([\#9288](https://github.com/matrix-org/synapse/issues/9288))
+
+
+Bugfixes
+--------
+
+- Fix spurious errors in logs when deleting a non-existent pusher. ([\#9121](https://github.com/matrix-org/synapse/issues/9121))
+- Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled). ([\#9163](https://github.com/matrix-org/synapse/issues/9163))
+- Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding. ([\#9164](https://github.com/matrix-org/synapse/issues/9164))
+- Fix a long-standing bug where invalid data could cause errors when calculating the presentable room name for push. ([\#9165](https://github.com/matrix-org/synapse/issues/9165))
+- Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data. ([\#9218](https://github.com/matrix-org/synapse/issues/9218))
+- Fix a bug where `None` was passed to Synapse modules instead of an empty dictionary if an empty module `config` block was provided in the homeserver config. ([\#9229](https://github.com/matrix-org/synapse/issues/9229))
+- Fix a bug in the `make_room_admin` admin API where it failed if the admin with the greatest power level was not in the room. Contributed by Pankaj Yadav. ([\#9235](https://github.com/matrix-org/synapse/issues/9235))
+- Prevent password hashes from getting dropped if a client failed threepid validation during a User Interactive Auth stage. Removes a workaround for an ancient bug in Riot Web <v0.7.4. ([\#9265](https://github.com/matrix-org/synapse/issues/9265))
+- Fix single-sign-on when the endpoints are routed to synapse workers. ([\#9271](https://github.com/matrix-org/synapse/issues/9271))
+
+
+Improved Documentation
+----------------------
+
+- Add docs for using Gitea as OpenID provider. ([\#9134](https://github.com/matrix-org/synapse/issues/9134))
+- Add link to Matrix VoIP tester for turn-howto. ([\#9135](https://github.com/matrix-org/synapse/issues/9135))
+- Add notes on integrating with Facebook for SSO login. ([\#9244](https://github.com/matrix-org/synapse/issues/9244))
+
+
+Deprecations and Removals
+-------------------------
+
+- The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`. ([\#9199](https://github.com/matrix-org/synapse/issues/9199))
+- Add new endpoint `/_synapse/client/saml2` for SAML2 authentication callbacks, and deprecate the old endpoint `/_matrix/saml2`. ([\#9289](https://github.com/matrix-org/synapse/issues/9289))
+
+
+Internal Changes
+----------------
+
+- Add tests to `test_user.UsersListTestCase` for List Users Admin API. ([\#9045](https://github.com/matrix-org/synapse/issues/9045))
+- Various improvements to the federation client. ([\#9129](https://github.com/matrix-org/synapse/issues/9129))
+- Speed up chain cover calculation when persisting a batch of state events at once. ([\#9176](https://github.com/matrix-org/synapse/issues/9176))
+- Add a `long_description_type` to the package metadata. ([\#9180](https://github.com/matrix-org/synapse/issues/9180))
+- Speed up batch insertion when using PostgreSQL. ([\#9181](https://github.com/matrix-org/synapse/issues/9181), [\#9188](https://github.com/matrix-org/synapse/issues/9188))
+- Emit an error at startup if different Identity Providers are configured with the same `idp_id`. ([\#9184](https://github.com/matrix-org/synapse/issues/9184))
+- Improve performance of concurrent use of `StreamIDGenerators`. ([\#9190](https://github.com/matrix-org/synapse/issues/9190))
+- Add some missing source directories to the automatic linting script. ([\#9191](https://github.com/matrix-org/synapse/issues/9191))
+- Precompute joined hosts and store in Redis. ([\#9198](https://github.com/matrix-org/synapse/issues/9198), [\#9227](https://github.com/matrix-org/synapse/issues/9227))
+- Clean-up template loading code. ([\#9200](https://github.com/matrix-org/synapse/issues/9200))
+- Fix the Python 3.5 old dependencies build. ([\#9217](https://github.com/matrix-org/synapse/issues/9217))
+- Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting. ([\#9222](https://github.com/matrix-org/synapse/issues/9222))
+- Add type hints to handlers code. ([\#9223](https://github.com/matrix-org/synapse/issues/9223), [\#9232](https://github.com/matrix-org/synapse/issues/9232))
+- Fix Debian package building on Ubuntu 16.04 LTS (Xenial). ([\#9254](https://github.com/matrix-org/synapse/issues/9254))
+- Minor performance improvement during TLS handshake. ([\#9255](https://github.com/matrix-org/synapse/issues/9255))
+- Refactor the generation of summary text for email notifications. ([\#9260](https://github.com/matrix-org/synapse/issues/9260))
+- Restore PyPy compatibility by not calling CPython-specific GC methods when under PyPy. ([\#9270](https://github.com/matrix-org/synapse/issues/9270))
+
+
+Synapse 1.26.0 (2021-01-27)
+===========================
+
+This release brings a new schema version for Synapse, and rolling back to a previous
+version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
+on these changes and for general upgrade guidance.
+
+No significant changes since 1.26.0rc2.
+
+
+Synapse 1.26.0rc2 (2021-01-25)
+==============================
+
+Bugfixes
+--------
+
+- Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
+- Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))
+
+
+Internal Changes
+----------------
+
+- Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
+- Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))
+
+
+Synapse 1.26.0rc1 (2021-01-20)
+==============================
+
+This release brings a new schema version for Synapse, and rolling back to a previous
+version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
+on these changes and for general upgrade guidance.
+
+Features
+--------
+
+- Add support for multiple SSO Identity Providers. ([\#9015](https://github.com/matrix-org/synapse/issues/9015), [\#9017](https://github.com/matrix-org/synapse/issues/9017), [\#9036](https://github.com/matrix-org/synapse/issues/9036), [\#9067](https://github.com/matrix-org/synapse/issues/9067), [\#9081](https://github.com/matrix-org/synapse/issues/9081), [\#9082](https://github.com/matrix-org/synapse/issues/9082), [\#9105](https://github.com/matrix-org/synapse/issues/9105), [\#9107](https://github.com/matrix-org/synapse/issues/9107), [\#9109](https://github.com/matrix-org/synapse/issues/9109), [\#9110](https://github.com/matrix-org/synapse/issues/9110), [\#9127](https://github.com/matrix-org/synapse/issues/9127), [\#9153](https://github.com/matrix-org/synapse/issues/9153), [\#9154](https://github.com/matrix-org/synapse/issues/9154), [\#9177](https://github.com/matrix-org/synapse/issues/9177))
+- During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP. ([\#9091](https://github.com/matrix-org/synapse/issues/9091))
+- Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. ([\#9159](https://github.com/matrix-org/synapse/issues/9159))
+- Improve performance when calculating ignored users in large rooms. ([\#9024](https://github.com/matrix-org/synapse/issues/9024))
+- Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. ([\#8984](https://github.com/matrix-org/synapse/issues/8984))
+- Add an admin API for protecting local media from quarantine. ([\#9086](https://github.com/matrix-org/synapse/issues/9086))
+- Remove a user's avatar URL and display name when deactivated with the Admin API. ([\#8932](https://github.com/matrix-org/synapse/issues/8932))
+- Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users. ([\#8948](https://github.com/matrix-org/synapse/issues/8948))
+- Add experimental support for handling to-device messages on worker processes. ([\#9042](https://github.com/matrix-org/synapse/issues/9042), [\#9043](https://github.com/matrix-org/synapse/issues/9043), [\#9044](https://github.com/matrix-org/synapse/issues/9044), [\#9130](https://github.com/matrix-org/synapse/issues/9130))
+- Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes. ([\#9068](https://github.com/matrix-org/synapse/issues/9068))
+- Add experimental support for handling `/devices` API on worker processes. ([\#9092](https://github.com/matrix-org/synapse/issues/9092))
+- Add experimental support for moving receipts and account data persistence off master. ([\#9104](https://github.com/matrix-org/synapse/issues/9104), [\#9166](https://github.com/matrix-org/synapse/issues/9166))
+
+
+Bugfixes
+--------
+
+- Fix a long-standing issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL. ([\#9023](https://github.com/matrix-org/synapse/issues/9023))
+- Fix a long-standing bug where some caches could grow larger than configured. ([\#9028](https://github.com/matrix-org/synapse/issues/9028))
+- Fix error handling during insertion of client IPs into the database. ([\#9051](https://github.com/matrix-org/synapse/issues/9051))
+- Fix bug where we didn't correctly record CPU time spent in `on_new_event` block. ([\#9053](https://github.com/matrix-org/synapse/issues/9053))
+- Fix a minor bug which could cause confusing error messages from invalid configurations. ([\#9054](https://github.com/matrix-org/synapse/issues/9054))
+- Fix incorrect exit code when there is an error at startup. ([\#9059](https://github.com/matrix-org/synapse/issues/9059))
+- Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers. ([\#9070](https://github.com/matrix-org/synapse/issues/9070))
+- Fix "Failed to send request" errors when a client provides an invalid room alias. ([\#9071](https://github.com/matrix-org/synapse/issues/9071))
+- Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. ([\#9114](https://github.com/matrix-org/synapse/issues/9114), [\#9116](https://github.com/matrix-org/synapse/issues/9116))
+- Fix corruption of `pushers` data when a postgres bouncer is used. ([\#9117](https://github.com/matrix-org/synapse/issues/9117))
+- Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. ([\#9128](https://github.com/matrix-org/synapse/issues/9128))
+- Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large. ([\#9108](https://github.com/matrix-org/synapse/issues/9108))
+- Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0. ([\#9145](https://github.com/matrix-org/synapse/issues/9145))
+- Fix a long-standing bug "ValueError: invalid literal for int() with base 10" when `/publicRooms` is requested with an invalid `server` parameter. ([\#9161](https://github.com/matrix-org/synapse/issues/9161))
+
+
+Improved Documentation
+----------------------
+
+- Add some extra docs for getting Synapse running on macOS. ([\#8997](https://github.com/matrix-org/synapse/issues/8997))
+- Correct a typo in the `systemd-with-workers` documentation. ([\#9035](https://github.com/matrix-org/synapse/issues/9035))
+- Correct a typo in `INSTALL.md`. ([\#9040](https://github.com/matrix-org/synapse/issues/9040))
+- Add missing `user_mapping_provider` configuration to the Keycloak OIDC example. Contributed by @chris-ruecker. ([\#9057](https://github.com/matrix-org/synapse/issues/9057))
+- Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters. ([\#9151](https://github.com/matrix-org/synapse/issues/9151))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove broken and unmaintained `demo/webserver.py` script. ([\#9039](https://github.com/matrix-org/synapse/issues/9039))
+
+
+Internal Changes
+----------------
+
+- Improve efficiency of large state resolutions. ([\#8868](https://github.com/matrix-org/synapse/issues/8868), [\#9029](https://github.com/matrix-org/synapse/issues/9029), [\#9115](https://github.com/matrix-org/synapse/issues/9115), [\#9118](https://github.com/matrix-org/synapse/issues/9118), [\#9124](https://github.com/matrix-org/synapse/issues/9124))
+- Various clean-ups to the structured logging and logging context code. ([\#8939](https://github.com/matrix-org/synapse/issues/8939))
+- Ensure rejected events get added to some metadata tables. ([\#9016](https://github.com/matrix-org/synapse/issues/9016))
+- Ignore date-rotated homeserver logs saved to disk. ([\#9018](https://github.com/matrix-org/synapse/issues/9018))
+- Remove an unused column from `access_tokens` table. ([\#9025](https://github.com/matrix-org/synapse/issues/9025))
+- Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies. ([\#9030](https://github.com/matrix-org/synapse/issues/9030))
+- Fix running unit tests when optional dependencies are not installed. ([\#9031](https://github.com/matrix-org/synapse/issues/9031))
+- Allow bumping schema version when using split out state database. ([\#9033](https://github.com/matrix-org/synapse/issues/9033))
+- Configure the linters to run on a consistent set of files. ([\#9038](https://github.com/matrix-org/synapse/issues/9038))
+- Various cleanups to device inbox store. ([\#9041](https://github.com/matrix-org/synapse/issues/9041))
+- Drop unused database tables. ([\#9055](https://github.com/matrix-org/synapse/issues/9055))
+- Remove unused `SynapseService` class. ([\#9058](https://github.com/matrix-org/synapse/issues/9058))
+- Remove unnecessary declarations in the tests for the admin API. ([\#9063](https://github.com/matrix-org/synapse/issues/9063))
+- Remove `SynapseRequest.get_user_agent`. ([\#9069](https://github.com/matrix-org/synapse/issues/9069))
+- Remove redundant `Homeserver.get_ip_from_request` method. ([\#9080](https://github.com/matrix-org/synapse/issues/9080))
+- Add type hints to media repository. ([\#9093](https://github.com/matrix-org/synapse/issues/9093))
+- Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung. ([\#9098](https://github.com/matrix-org/synapse/issues/9098))
+- Reduce the scope of caught exceptions in `BlacklistingAgentWrapper`. ([\#9106](https://github.com/matrix-org/synapse/issues/9106))
+- Improve `UsernamePickerTestCase`. ([\#9112](https://github.com/matrix-org/synapse/issues/9112))
+- Remove dependency on `distutils`. ([\#9125](https://github.com/matrix-org/synapse/issues/9125))
+- Enforce that replication HTTP clients are called with keyword arguments only. ([\#9144](https://github.com/matrix-org/synapse/issues/9144))
+- Fix the Python 3.5 / old dependencies build in CI. ([\#9146](https://github.com/matrix-org/synapse/issues/9146))
+- Replace the old `perspectives` option in the Synapse docker config file template with `trusted_key_servers`. ([\#9157](https://github.com/matrix-org/synapse/issues/9157))
+
+
+Synapse 1.25.0 (2021-01-13)
+===========================
+
+Ending Support for Python 3.5 and Postgres 9.5
+----------------------------------------------
+
+With this release, the Synapse team is announcing a formal deprecation policy for our platform dependencies, like Python and PostgreSQL:
+
+All future releases of Synapse will follow the upstream end-of-life schedules.
+
+In practice, this means:
+
+* This is the last release which guarantees support for Python 3.5.
+* We will end support for PostgreSQL 9.5 early next month.
+* We will end support for Python 3.6 and PostgreSQL 9.6 near the end of the year.
+
+Crucially, this means __we will not produce .deb packages for Debian 9 (Stretch) or Ubuntu 16.04 (Xenial)__ beyond the transition period described below.
+
+The website https://endoflife.date/ has convenient summaries of the support schedules for projects like [Python](https://endoflife.date/python) and [PostgreSQL](https://endoflife.date/postgresql).
+
+If you are unable to upgrade your environment to a supported version of Python or Postgres, we encourage you to consider using the [Synapse Docker images](./INSTALL.md#docker-images-and-ansible-playbooks) instead.
+
+### Transition Period
+
+We will make a good faith attempt to avoid breaking compatibility in all releases through the end of March 2021. However, critical security vulnerabilities in dependencies or other unanticipated circumstances may arise which necessitate breaking compatibility earlier.
+
+We intend to continue producing .deb packages for Debian 9 (Stretch) and Ubuntu 16.04 (Xenial) through the transition period.
+
+Removal warning
+---------------
+
+The old [Purge Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/purge_room.md)
+and [Shutdown Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/shutdown_room.md)
+are deprecated and will be removed in a future release. They will be replaced by the
+[Delete Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/rooms.md#delete-room-api).
+
+`POST /_synapse/admin/v1/rooms/<room_id>/delete` replaces `POST /_synapse/admin/v1/purge_room` and
+`POST /_synapse/admin/v1/shutdown_room/<room_id>`.
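+
+As a rough, hypothetical example (the `purge` flag and the `localhost:8008` address are
+illustrative only; see the Delete Room API documentation above for the full set of
+parameters), the new endpoint can be called like so:
+
+```sh
+# Shut down and purge a room via the new Delete Room API.
+# $TOKEN is an admin access token and $ROOM_ID a room ID (both placeholders).
+curl -X POST \
+  -H "Authorization: Bearer $TOKEN" \
+  -H "Content-Type: application/json" \
+  -d '{"purge": true}' \
+  "http://localhost:8008/_synapse/admin/v1/rooms/$ROOM_ID/delete"
+```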
+
+Bugfixes
+--------
+
+- Fix HTTP proxy support when using a proxy that is on a blacklisted IP. Introduced in v1.25.0rc1. Contributed by @Bubu. ([\#9084](https://github.com/matrix-org/synapse/issues/9084))
+
+
+Synapse 1.25.0rc1 (2021-01-06)
+==============================
+
+Features
+--------
+
+- Add an admin API that lets server admins get power in rooms in which local users have power. ([\#8756](https://github.com/matrix-org/synapse/issues/8756))
+- Add optional HTTP authentication to replication endpoints. ([\#8853](https://github.com/matrix-org/synapse/issues/8853))
+- Improve the error messages printed as a result of configuration problems for extension modules. ([\#8874](https://github.com/matrix-org/synapse/issues/8874))
+- Add the number of local devices to Room Details Admin API. Contributed by @dklimpel. ([\#8886](https://github.com/matrix-org/synapse/issues/8886))
+- Add `X-Robots-Tag` header to stop web crawlers from indexing media. Contributed by Aaron Raimist. ([\#8887](https://github.com/matrix-org/synapse/issues/8887))
+- Spam-checkers may now define their methods as `async`. ([\#8890](https://github.com/matrix-org/synapse/issues/8890))
+- Add support for allowing users to pick their own user ID during a single-sign-on login. ([\#8897](https://github.com/matrix-org/synapse/issues/8897), [\#8900](https://github.com/matrix-org/synapse/issues/8900), [\#8911](https://github.com/matrix-org/synapse/issues/8911), [\#8938](https://github.com/matrix-org/synapse/issues/8938), [\#8941](https://github.com/matrix-org/synapse/issues/8941), [\#8942](https://github.com/matrix-org/synapse/issues/8942), [\#8951](https://github.com/matrix-org/synapse/issues/8951))
+- Add an `email.invite_client_location` configuration option to send a web client location to the invite endpoint on the identity server which allows customisation of the email template. ([\#8930](https://github.com/matrix-org/synapse/issues/8930))
+- The search term in the list room and list user Admin APIs is now treated as case-insensitive. ([\#8931](https://github.com/matrix-org/synapse/issues/8931))
+- Apply an IP range blacklist to push and key revocation requests. ([\#8821](https://github.com/matrix-org/synapse/issues/8821), [\#8870](https://github.com/matrix-org/synapse/issues/8870), [\#8954](https://github.com/matrix-org/synapse/issues/8954))
+- Add an option to allow re-use of user-interactive authentication sessions for a period of time. ([\#8970](https://github.com/matrix-org/synapse/issues/8970))
+- Allow running the redact endpoint on workers. ([\#8994](https://github.com/matrix-org/synapse/issues/8994))
+
+
+Bugfixes
+--------
+
+- Fix bug where we might not correctly calculate the current state for rooms with multiple extremities. ([\#8827](https://github.com/matrix-org/synapse/issues/8827))
+- Fix a long-standing bug in the register admin endpoint (`/_synapse/admin/v1/register`) when the `mac` field was not provided. The endpoint now properly returns a 400 error. Contributed by @edwargix. ([\#8837](https://github.com/matrix-org/synapse/issues/8837))
+- Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password. ([\#8858](https://github.com/matrix-org/synapse/issues/8858))
+- Fix a long-standing bug where a 500 error would be returned if the `Content-Length` header was not provided to the upload media resource. ([\#8862](https://github.com/matrix-org/synapse/issues/8862))
+- Add additional validation to pusher URLs to be compliant with the specification. ([\#8865](https://github.com/matrix-org/synapse/issues/8865))
+- Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled. ([\#8867](https://github.com/matrix-org/synapse/issues/8867))
+- Fix a bug where `PUT /_synapse/admin/v2/users/<user_id>` failed to create a new user when `avatar_url` is specified. Bug introduced in Synapse v1.9.0. ([\#8872](https://github.com/matrix-org/synapse/issues/8872))
+- Fix a 500 error when attempting to preview an empty HTML file. ([\#8883](https://github.com/matrix-org/synapse/issues/8883))
+- Fix occasional deadlock when handling SIGHUP. ([\#8918](https://github.com/matrix-org/synapse/issues/8918))
+- Fix login API to not ratelimit application services that have ratelimiting disabled. ([\#8920](https://github.com/matrix-org/synapse/issues/8920))
+- Fix bug where we ratelimited auto joining of rooms on registration (using `auto_join_rooms` config). ([\#8921](https://github.com/matrix-org/synapse/issues/8921))
+- Fix a bug where deactivated users appeared in the user directory when their profile information was updated. ([\#8933](https://github.com/matrix-org/synapse/issues/8933), [\#8964](https://github.com/matrix-org/synapse/issues/8964))
+- Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file. ([\#8937](https://github.com/matrix-org/synapse/issues/8937))
+- Fix a bug where 500 errors would be returned if the `m.room_history_visibility` event had invalid content. ([\#8945](https://github.com/matrix-org/synapse/issues/8945))
+- Fix a bug causing common English words to not be considered for a user directory search. ([\#8959](https://github.com/matrix-org/synapse/issues/8959))
+- Fix bug where application services couldn't register new ghost users if the server had reached its MAU limit. ([\#8962](https://github.com/matrix-org/synapse/issues/8962))
+- Fix a long-standing bug where a `m.image` event without a `url` would cause errors on push. ([\#8965](https://github.com/matrix-org/synapse/issues/8965))
+- Fix a small bug in v2 state resolution algorithm, which could also cause performance issues for rooms with large numbers of power levels. ([\#8971](https://github.com/matrix-org/synapse/issues/8971))
+- Add validation to the `sendToDevice` API to raise a missing parameters error instead of a 500 error. ([\#8975](https://github.com/matrix-org/synapse/issues/8975))
+- Add validation of group IDs to raise a 400 error instead of a 500 error. ([\#8977](https://github.com/matrix-org/synapse/issues/8977))
+
+
+Improved Documentation
+----------------------
+
+- Fix the "Event persist rate" section of the included grafana dashboard by adding missing prometheus rules. ([\#8802](https://github.com/matrix-org/synapse/issues/8802))
+- Combine related media admin API docs. ([\#8839](https://github.com/matrix-org/synapse/issues/8839))
+- Fix an error in the documentation for the SAML username mapping provider. ([\#8873](https://github.com/matrix-org/synapse/issues/8873))
+- Clarify comments around template directories in `sample_config.yaml`. ([\#8891](https://github.com/matrix-org/synapse/issues/8891))
+- Move instructions for database setup, adjust heading levels and improve syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by @fossterer. ([\#8987](https://github.com/matrix-org/synapse/issues/8987))
+- Update the example value of `group_creation_prefix` in the sample configuration. ([\#8992](https://github.com/matrix-org/synapse/issues/8992))
+- Link the Synapse developer room to the development section in the docs. ([\#9002](https://github.com/matrix-org/synapse/issues/9002))
+
+
+Deprecations and Removals
+-------------------------
+
+- Deprecate Shutdown Room and Purge Room Admin APIs. ([\#8829](https://github.com/matrix-org/synapse/issues/8829))
+
+
+Internal Changes
+----------------
+
+- Properly store the mapping of external ID to Matrix ID for CAS users. ([\#8856](https://github.com/matrix-org/synapse/issues/8856), [\#8958](https://github.com/matrix-org/synapse/issues/8958))
+- Remove some unnecessary stubbing from unit tests. ([\#8861](https://github.com/matrix-org/synapse/issues/8861))
+- Remove unused `FakeResponse` class from unit tests. ([\#8864](https://github.com/matrix-org/synapse/issues/8864))
+- Pass `room_id` to `get_auth_chain_difference`. ([\#8879](https://github.com/matrix-org/synapse/issues/8879))
+- Add type hints to push module. ([\#8880](https://github.com/matrix-org/synapse/issues/8880), [\#8882](https://github.com/matrix-org/synapse/issues/8882), [\#8901](https://github.com/matrix-org/synapse/issues/8901), [\#8940](https://github.com/matrix-org/synapse/issues/8940), [\#8943](https://github.com/matrix-org/synapse/issues/8943), [\#9020](https://github.com/matrix-org/synapse/issues/9020))
+- Simplify logic for handling user-interactive-auth via single-sign-on servers. ([\#8881](https://github.com/matrix-org/synapse/issues/8881))
+- Skip the SAML tests if the requirements (`pysaml2` and `xmlsec1`) aren't available. ([\#8905](https://github.com/matrix-org/synapse/issues/8905))
+- Fix multiarch docker image builds. ([\#8906](https://github.com/matrix-org/synapse/issues/8906))
+- Don't publish `latest` docker image until all archs are built. ([\#8909](https://github.com/matrix-org/synapse/issues/8909))
+- Various clean-ups to the structured logging and logging context code. ([\#8916](https://github.com/matrix-org/synapse/issues/8916), [\#8935](https://github.com/matrix-org/synapse/issues/8935))
+- Automatically drop stale forward-extremities under some specific conditions. ([\#8929](https://github.com/matrix-org/synapse/issues/8929))
+- Refactor test utilities for injecting HTTP requests. ([\#8946](https://github.com/matrix-org/synapse/issues/8946))
+- Add a maximum size of 50 kilobytes to .well-known lookups. ([\#8950](https://github.com/matrix-org/synapse/issues/8950))
+- Fix bug in `generate_log_config` script which made it write empty files. ([\#8952](https://github.com/matrix-org/synapse/issues/8952))
+- Clean up tox.ini file; disable coverage checking for non-test runs. ([\#8963](https://github.com/matrix-org/synapse/issues/8963))
+- Add type hints to the admin and room list handlers. ([\#8973](https://github.com/matrix-org/synapse/issues/8973))
+- Add type hints to the receipts and user directory handlers. ([\#8976](https://github.com/matrix-org/synapse/issues/8976))
+- Drop the unused `local_invites` table. ([\#8979](https://github.com/matrix-org/synapse/issues/8979))
+- Add type hints to the base storage code. ([\#8980](https://github.com/matrix-org/synapse/issues/8980))
+- Support using PyJWT v2.0.0 in the test suite. ([\#8986](https://github.com/matrix-org/synapse/issues/8986))
+- Fix `tests.federation.transport.RoomDirectoryFederationTests` and ensure it runs in CI. ([\#8998](https://github.com/matrix-org/synapse/issues/8998))
+- Add type hints to the crypto module. ([\#8999](https://github.com/matrix-org/synapse/issues/8999))
+
+
Synapse 1.24.0 (2020-12-09)
===========================
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1d7bb8f969..b6a70f7ffe 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,4 +1,31 @@
-# Contributing code to Synapse
+# Welcome to Synapse
+
+This document aims to get you started with contributing to this repo!
+
+- [1. Who can contribute to Synapse?](#1-who-can-contribute-to-synapse)
+- [2. What do I need?](#2-what-do-i-need)
+- [3. Get the source.](#3-get-the-source)
+- [4. Install the dependencies](#4-install-the-dependencies)
+ * [Under Unix (macOS, Linux, BSD, ...)](#under-unix-macos-linux-bsd-)
+ * [Under Windows](#under-windows)
+- [5. Get in touch.](#5-get-in-touch)
+- [6. Pick an issue.](#6-pick-an-issue)
+- [7. Turn coffee and documentation into code and documentation!](#7-turn-coffee-and-documentation-into-code-and-documentation)
+- [8. Test, test, test!](#8-test-test-test)
+ * [Run the linters.](#run-the-linters)
+ * [Run the unit tests.](#run-the-unit-tests)
+ * [Run the integration tests.](#run-the-integration-tests)
+- [9. Submit your patch.](#9-submit-your-patch)
+ * [Changelog](#changelog)
+ + [How do I know what to call the changelog file before I create the PR?](#how-do-i-know-what-to-call-the-changelog-file-before-i-create-the-pr)
+ + [Debian changelog](#debian-changelog)
+ * [Sign off](#sign-off)
+- [10. Turn feedback into better code.](#10-turn-feedback-into-better-code)
+- [11. Find a new issue.](#11-find-a-new-issue)
+- [Notes for maintainers on merging PRs etc](#notes-for-maintainers-on-merging-prs-etc)
+- [Conclusion](#conclusion)
+
+# 1. Who can contribute to Synapse?
Everyone is welcome to contribute code to [matrix.org
projects](https://github.com/matrix-org), provided that they are willing to
@@ -9,70 +36,179 @@ license the code under the same terms as the project's overall 'outbound'
license - in our case, this is almost always Apache Software License v2 (see
[LICENSE](LICENSE)).
-## How to contribute
+# 2. What do I need?
+
+The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://wiki.python.org/moin/BeginnersGuide/Download).
+
+The source code of Synapse is hosted on GitHub. You will also need [a recent version of git](https://github.com/git-guides/install-git).
+
+For some tests, you will need [a recent version of Docker](https://docs.docker.com/get-docker/).
+
+
+# 3. Get the source.
The preferred and easiest way to contribute changes is to fork the relevant
-project on github, and then [create a pull request](
+project on GitHub, and then [create a pull request](
https://help.github.com/articles/using-pull-requests/) to ask us to pull your
changes into our repo.
-Some other points to follow:
+Please base your changes on the `develop` branch.
+
+```sh
+git clone git@github.com:YOUR_GITHUB_USER_NAME/synapse.git
+cd synapse
+git checkout develop
+```
+
+If you need help getting started with git, that is beyond the scope of this document, but you
+can find many good git tutorials on the web.
+
+# 4. Install the dependencies
- * Please base your changes on the `develop` branch.
+## Under Unix (macOS, Linux, BSD, ...)
- * Please follow the [code style requirements](#code-style).
+Once you have installed Python 3 and cloned the source, please open a terminal and
+set up a *virtualenv*, as follows:
+
+```sh
+cd path/where/you/have/cloned/the/repository
+python3 -m venv ./env
+source ./env/bin/activate
+pip install -e ".[all,lint,mypy,test]"
+pip install tox
+```
+
+This will install the developer dependencies for the project.
+
+## Under Windows
+
+TBD
- * Please include a [changelog entry](#changelog) with each PR.
- * Please [sign off](#sign-off) your contribution.
+# 5. Get in touch.
- * Please keep an eye on the pull request for feedback from the [continuous
- integration system](#continuous-integration-and-testing) and try to fix any
- errors that come up.
+Join our developer community on Matrix: #synapse-dev:matrix.org!
- * If you need to [update your PR](#updating-your-pull-request), just add new
- commits to your branch rather than rebasing.
-## Code style
+# 6. Pick an issue.
+
+Fix your favorite problem or perhaps find a [Good First Issue](https://github.com/matrix-org/synapse/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+First+Issue%22)
+to work on.
+
+
+# 7. Turn coffee and documentation into code and documentation!
Synapse's code style is documented [here](docs/code_style.md). Please follow
it, including the conventions for the [sample configuration
file](docs/code_style.md#configuration-file-format).
-Many of the conventions are enforced by scripts which are run as part of the
-[continuous integration system](#continuous-integration-and-testing). To help
-check if you have followed the code style, you can run `scripts-dev/lint.sh`
-locally. You'll need python 3.6 or later, and to install a number of tools:
+There is a growing amount of documentation located in the [docs](docs)
+directory. This documentation is intended primarily for sysadmins running their
+own Synapse instance, as well as developers interacting externally with
+Synapse. [docs/dev](docs/dev) exists primarily to house documentation for
+Synapse developers. [docs/admin_api](docs/admin_api) houses documentation
+regarding Synapse's Admin API, which is used mostly by sysadmins and external
+service developers.
-```
-# Install the dependencies
-pip install -e ".[lint,mypy]"
+If you add new files to either of these folders, please use [GitHub-Flavoured
+Markdown](https://guides.github.com/features/mastering-markdown/).
+
+Some documentation also exists in [Synapse's GitHub
+Wiki](https://github.com/matrix-org/synapse/wiki), although this is primarily
+contributed to by community authors.
+
+
+# 8. Test, test, test!
+<a name="test-test-test"></a>
+
+While you're developing and before submitting a patch, you'll
+want to test your code.
+
+## Run the linters.
+
+The linters look at your code and do two things:
+
+- ensure that your code follows the coding style adopted by the project;
+- catch a number of errors in your code.
+
+They're pretty fast, so don't hesitate to run them!
-# Run the linter script
+```sh
+source ./env/bin/activate
./scripts-dev/lint.sh
```
-**Note that the script does not just test/check, but also reformats code, so you
-may wish to ensure any new code is committed first**.
+Note that this script *will modify your files* to fix styling errors.
+Make sure that you have saved all your files.
-By default, this script checks all files and can take some time; if you alter
-only certain files, you might wish to specify paths as arguments to reduce the
-run-time:
+If you wish to restrict the linters to only the files changed since the last commit
+(much faster!), you can instead run:
+```sh
+source ./env/bin/activate
+./scripts-dev/lint.sh -d
```
+
+Or if you know exactly which files you wish to lint, you can instead run:
+
+```sh
+source ./env/bin/activate
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
```
-You can also provide the `-d` option, which will lint the files that have been
-changed since the last git commit. This will often be significantly faster than
-linting the whole codebase.
+## Run the unit tests.
+
+The unit tests run parts of Synapse, including your changes, to see if anything
+was broken. They are slower than the linters but will typically catch more errors.
+
+```sh
+source ./env/bin/activate
+trial tests
+```
+
+If you wish to only run *some* unit tests, you may specify
+another module instead of `tests` - or a test class or a method:
+
+```sh
+source ./env/bin/activate
+trial tests.rest.admin.test_room tests.handlers.test_admin.ExfiltrateData.test_invite
+```
+
+If your tests fail, you may wish to look at the logs:
+
+```sh
+less _trial_temp/test.log
+```
+
+## Run the integration tests.
+
+The integration tests are a more comprehensive suite of tests. They
+run a full version of Synapse, including your changes, to check if
+anything was broken. They are slower than the unit tests but will
+typically catch more errors.
+
+The following command will let you run the integration tests with the most common
+configuration:
+
+```sh
+docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:py37
+```
+
+This configuration should generally cover your needs. For more details about other configurations, see [documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md).
+
-Before pushing new changes, ensure they don't produce linting errors. Commit any
-files that were corrected.
+# 9. Submit your patch.
+
+Once you're happy with your patch, it's time to prepare a Pull Request.
+
+To prepare a Pull Request, please:
+
+1. verify that [all the tests pass](#test-test-test), including the coding style;
+2. [sign off](#sign-off) your contribution;
+3. `git push` your commit to your fork of Synapse;
+4. on GitHub, [create the Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request);
+5. add a [changelog entry](#changelog) and push it to your Pull Request (a rough sketch of steps 3 and 5 follows this list);
+6. for most contributors, that's all - however, if you are a member of the `matrix-org` organization on GitHub, please request a review from `matrix.org / Synapse Core`.
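+
+A minimal sketch of steps 3 and 5, assuming your change is a bug fix and your Pull
+Request turns out to be number 1234 (both the number and the `.bugfix` suffix are
+made-up examples; the [Changelog](#changelog) section below has the real naming rules):
+
+```sh
+# Step 3: push your commits to your fork, then open the Pull Request on GitHub.
+git push
+# Step 5: once you know the PR number (1234 here), add a changelog entry named after it.
+echo "Fix the thing." > changelog.d/1234.bugfix
+git add changelog.d/1234.bugfix
+git commit -m "Add changelog"
+git push
+```
+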
-Please ensure your changes match the cosmetic style of the existing project,
-and **never** mix cosmetic and functional changes in the same commit, as it
-makes it horribly hard to review otherwise.
## Changelog
@@ -156,24 +292,6 @@ directory, you will need both a regular newsfragment *and* an entry in the
debian changelog. (Though typically such changes should be submitted as two
separate pull requests.)
-## Documentation
-
-There is a growing amount of documentation located in the [docs](docs)
-directory. This documentation is intended primarily for sysadmins running their
-own Synapse instance, as well as developers interacting externally with
-Synapse. [docs/dev](docs/dev) exists primarily to house documentation for
-Synapse developers. [docs/admin_api](docs/admin_api) houses documentation
-regarding Synapse's Admin API, which is used mostly by sysadmins and external
-service developers.
-
-New files added to both folders should be written in [Github-Flavoured
-Markdown](https://guides.github.com/features/mastering-markdown/), and attempts
-should be made to migrate existing documents to markdown where possible.
-
-Some documentation also exists in [Synapse's Github
-Wiki](https://github.com/matrix-org/synapse/wiki), although this is primarily
-contributed to by community authors.
-
## Sign off
In order to have a concrete record that your contribution is intentional
@@ -240,47 +358,36 @@ Git allows you to add this signoff automatically when using the `-s`
flag to `git commit`, which uses the name and email set in your
`user.name` and `user.email` git configs.
-## Continuous integration and testing
-[Buildkite](https://buildkite.com/matrix-dot-org/synapse) will automatically
-run a series of checks and tests against any PR which is opened against the
-project; if your change breaks the build, this will be shown in GitHub, with
-links to the build results. If your build fails, please try to fix the errors
-and update your branch.
+# 10. Turn feedback into better code.
+
+Once the Pull Request is opened, you will see a few things:
-To run unit tests in a local development environment, you can use:
+1. our automated CI (Continuous Integration) pipeline will run (again) the linters, the unit tests, the integration tests and more;
+2. one or more of the developers will take a look at your Pull Request and offer feedback.
-- ``tox -e py35`` (requires tox to be installed by ``pip install tox``)
- for SQLite-backed Synapse on Python 3.5.
-- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6.
-- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6
- (requires a running local PostgreSQL with access to create databases).
-- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5
- (requires Docker). Entirely self-contained, recommended if you don't want to
- set up PostgreSQL yourself.
+From this point, you should:
-Docker images are available for running the integration tests (SyTest) locally,
-see the [documentation in the SyTest repo](
-https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more
-information.
+1. Look at the results of the CI pipeline.
+ - If there is any error, fix the error.
+2. If a developer has requested changes, make these changes and let us know when it is ready for a developer to review again.
+3. Create a new commit with the changes.
+ - Please do NOT overwrite the history. New commits make the reviewer's life easier.
+ - Push these commits to your Pull Request.
+4. Back to 1.
-## Updating your pull request
+Once both the CI and the developers are happy, the patch will be merged into Synapse and released shortly!
-If you decide to make changes to your pull request - perhaps to address issues
-raised in a review, or to fix problems highlighted by [continuous
-integration](#continuous-integration-and-testing) - just add new commits to your
-branch, and push to GitHub. The pull request will automatically be updated.
+# 11. Find a new issue.
-Please **avoid** rebasing your branch, especially once the PR has been
-reviewed: doing so makes it very difficult for a reviewer to see what has
-changed since a previous review.
+By now, you know the drill!
-## Notes for maintainers on merging PRs etc
+# Notes for maintainers on merging PRs etc
There are some notes for those with commit access to the project on how we
manage git [here](docs/dev/git.md).
-## Conclusion
+# Conclusion
That's it! Matrix is a very open and collaborative project as you might expect
given our obsession with open communication. If we're going to successfully
diff --git a/INSTALL.md b/INSTALL.md
index eaeb690092..7b40689234 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -1,19 +1,45 @@
-- [Choosing your server name](#choosing-your-server-name)
-- [Picking a database engine](#picking-a-database-engine)
-- [Installing Synapse](#installing-synapse)
- - [Installing from source](#installing-from-source)
- - [Platform-Specific Instructions](#platform-specific-instructions)
- - [Prebuilt packages](#prebuilt-packages)
-- [Setting up Synapse](#setting-up-synapse)
- - [TLS certificates](#tls-certificates)
- - [Client Well-Known URI](#client-well-known-uri)
- - [Email](#email)
- - [Registering a user](#registering-a-user)
- - [Setting up a TURN server](#setting-up-a-turn-server)
- - [URL previews](#url-previews)
-- [Troubleshooting Installation](#troubleshooting-installation)
-
-# Choosing your server name
+# Installation Instructions
+
+There are 3 steps to follow under **Installation Instructions**: choosing your server name, installing Synapse, and setting up Synapse.
+
+- [Installation Instructions](#installation-instructions)
+ - [Choosing your server name](#choosing-your-server-name)
+ - [Installing Synapse](#installing-synapse)
+ - [Installing from source](#installing-from-source)
+ - [Platform-specific prerequisites](#platform-specific-prerequisites)
+ - [Debian/Ubuntu/Raspbian](#debianubunturaspbian)
+ - [ArchLinux](#archlinux)
+ - [CentOS/Fedora](#centosfedora)
+ - [macOS](#macos)
+ - [OpenSUSE](#opensuse)
+ - [OpenBSD](#openbsd)
+ - [Windows](#windows)
+ - [Prebuilt packages](#prebuilt-packages)
+ - [Docker images and Ansible playbooks](#docker-images-and-ansible-playbooks)
+ - [Debian/Ubuntu](#debianubuntu)
+ - [Matrix.org packages](#matrixorg-packages)
+ - [Downstream Debian packages](#downstream-debian-packages)
+ - [Downstream Ubuntu packages](#downstream-ubuntu-packages)
+ - [Fedora](#fedora)
+ - [OpenSUSE](#opensuse-1)
+ - [SUSE Linux Enterprise Server](#suse-linux-enterprise-server)
+ - [ArchLinux](#archlinux-1)
+ - [Void Linux](#void-linux)
+ - [FreeBSD](#freebsd)
+ - [OpenBSD](#openbsd-1)
+ - [NixOS](#nixos)
+ - [Setting up Synapse](#setting-up-synapse)
+ - [Using PostgreSQL](#using-postgresql)
+ - [TLS certificates](#tls-certificates)
+ - [Client Well-Known URI](#client-well-known-uri)
+ - [Email](#email)
+ - [Registering a user](#registering-a-user)
+ - [Setting up a TURN server](#setting-up-a-turn-server)
+ - [URL previews](#url-previews)
+ - [Troubleshooting Installation](#troubleshooting-installation)
+
+
+## Choosing your server name
It is important to choose the name for your server before you install Synapse,
because it cannot be changed later.
@@ -29,46 +55,24 @@ that your email address is probably `user@example.com` rather than
`user@email.example.com`) - but doing so may require more advanced setup: see
[Setting up Federation](docs/federate.md).
-# Picking a database engine
+## Installing Synapse
-Synapse offers two database engines:
- * [PostgreSQL](https://www.postgresql.org)
- * [SQLite](https://sqlite.org/)
-
-Almost all installations should opt to use PostgreSQL. Advantages include:
-
-* significant performance improvements due to the superior threading and
- caching model, smarter query optimiser
-* allowing the DB to be run on separate hardware
-
-For information on how to install and use PostgreSQL, please see
-[docs/postgres.md](docs/postgres.md)
-
-By default Synapse uses SQLite and in doing so trades performance for convenience.
-SQLite is only recommended in Synapse for testing purposes or for servers with
-light workloads.
-
-# Installing Synapse
-
-## Installing from source
+### Installing from source
(Prebuilt packages are available for some platforms - see [Prebuilt packages](#prebuilt-packages).)
+When installing from source, please make sure that the [Platform-specific prerequisites](#platform-specific-prerequisites) are already installed.
+
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- Python 3.5.2 or later, up to Python 3.9.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
-Synapse is written in Python but some of the libraries it uses are written in
-C. So before we can install Synapse itself we need a working C compiler and the
-header files for Python C extensions. See [Platform-Specific
-Instructions](#platform-specific-instructions) for information on installing
-these on various platforms.
To install the Synapse homeserver run:
-```
+```sh
mkdir -p ~/synapse
virtualenv -p python3 ~/synapse/env
source ~/synapse/env/bin/activate
@@ -85,7 +89,7 @@ prefer.
This Synapse installation can then be later upgraded by using pip again with the
update flag:
-```
+```sh
source ~/synapse/env/bin/activate
pip install -U matrix-synapse
```
@@ -93,7 +97,7 @@ pip install -U matrix-synapse
Before you can start Synapse, you will need to generate a configuration
file. To do this, run (in your virtualenv, as before):
-```
+```sh
cd ~/synapse
python -m synapse.app.homeserver \
--server-name my.domain.name \
@@ -111,70 +115,58 @@ wise to back them up somewhere safe. (If, for whatever reason, you do need to
change your homeserver's keys, you may find that other homeserver have the
old key cached. If you update the signing key, you should change the name of the
key in the `<server name>.signing.key` file (the second word) to something
-different. See the
-[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
-for more information on key management).
+different. See the [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) for more information on key management).
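+
+For illustration only (the key ID and key material below are invented), a signing key
+file typically contains a single line of the form `<algorithm> <key id> <key material>`,
+so the "second word" to change is the key ID (`a_AbCd` in this sketch):
+
+```
+ed25519 a_AbCd <base64-encoded key material>
+```
+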
To actually run your new homeserver, pick a working directory for Synapse to
run (e.g. `~/synapse`), and:
-```
+```sh
cd ~/synapse
source env/bin/activate
synctl start
```
-### Platform-Specific Instructions
+#### Platform-specific prerequisites
-#### Debian/Ubuntu/Raspbian
+Synapse is written in Python but some of the libraries it uses are written in
+C. So before we can install Synapse itself we need a working C compiler and the
+header files for Python C extensions.
+
+##### Debian/Ubuntu/Raspbian
Installing prerequisites on Ubuntu or Debian:
-```
-sudo apt-get install build-essential python3-dev libffi-dev \
+```sh
+sudo apt install build-essential python3-dev libffi-dev \
python3-pip python3-setuptools sqlite3 \
libssl-dev virtualenv libjpeg-dev libxslt1-dev
```
-#### ArchLinux
+##### ArchLinux
Installing prerequisites on ArchLinux:
-```
+```sh
sudo pacman -S base-devel python python-pip \
python-setuptools python-virtualenv sqlite3
```
-#### CentOS/Fedora
+##### CentOS/Fedora
-Installing prerequisites on CentOS 8 or Fedora>26:
+Installing prerequisites on CentOS or Fedora Linux:
-```
+```sh
sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
- libwebp-devel tk-devel redhat-rpm-config \
- python3-virtualenv libffi-devel openssl-devel
+ libwebp-devel libxml2-devel libxslt-devel libpq-devel \
+ python3-virtualenv libffi-devel openssl-devel python3-devel
sudo dnf groupinstall "Development Tools"
```
-Installing prerequisites on CentOS 7 or Fedora<=25:
-
-```
-sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
- lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
- python3-virtualenv libffi-devel openssl-devel
-sudo yum groupinstall "Development Tools"
-```
-
-Note that Synapse does not support versions of SQLite before 3.11, and CentOS 7
-uses SQLite 3.7. You may be able to work around this by installing a more
-recent SQLite version, but it is recommended that you instead use a Postgres
-database: see [docs/postgres.md](docs/postgres.md).
-
-#### macOS
+##### macOS
Installing prerequisites on macOS:
-```
+```sh
xcode-select --install
sudo easy_install pip
sudo pip install virtualenv
@@ -184,22 +176,23 @@ brew install pkg-config libffi
On macOS Catalina (10.15) you may need to explicitly install OpenSSL
via brew and inform `pip` about it so that `psycopg2` builds:
-```
+```sh
brew install openssl@1.1
-export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/
+export LDFLAGS="-L/usr/local/opt/openssl/lib"
+export CPPFLAGS="-I/usr/local/opt/openssl/include"
```
-#### OpenSUSE
+##### OpenSUSE
Installing prerequisites on openSUSE:
-```
+```sh
sudo zypper in -t pattern devel_basis
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
python-devel libffi-devel libopenssl-devel libjpeg62-devel
```
-#### OpenBSD
+##### OpenBSD
A port of Synapse is available under `net/synapse`. The filesystem
underlying the homeserver directory (defaults to `/var/synapse`) has to be
@@ -213,73 +206,72 @@ mounted with `wxallowed` (cf. `mount(8)`).
Creating a `WRKOBJDIR` for building python under `/usr/local` (which on a
default OpenBSD installation is mounted with `wxallowed`):
-```
+```sh
doas mkdir /usr/local/pobj_wxallowed
```
Assuming `PORTS_PRIVSEP=Yes` (cf. `bsd.port.mk(5)`) and `SUDO=doas` are
configured in `/etc/mk.conf`:
-```
+```sh
doas chown _pbuild:_pbuild /usr/local/pobj_wxallowed
```
Setting the `WRKOBJDIR` for building python:
-```
+```sh
echo WRKOBJDIR_lang/python/3.7=/usr/local/pobj_wxallowed \\nWRKOBJDIR_lang/python/2.7=/usr/local/pobj_wxallowed >> /etc/mk.conf
```
Building Synapse:
-```
+```sh
cd /usr/ports/net/synapse
make install
```
-#### Windows
+##### Windows
If you wish to run or develop Synapse on Windows, the Windows Subsystem For
Linux provides a Linux environment on Windows 10 which is capable of using the
Debian, Fedora, or source installation methods. More information about WSL can
-be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
-Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
+be found at <https://docs.microsoft.com/en-us/windows/wsl/install-win10> for
+Windows 10 and <https://docs.microsoft.com/en-us/windows/wsl/install-on-server>
for Windows Server.
-## Prebuilt packages
+### Prebuilt packages
As an alternative to installing from source, prebuilt packages are available
for a number of platforms.
-### Docker images and Ansible playbooks
+#### Docker images and Ansible playbooks
-There is an offical synapse image available at
-https://hub.docker.com/r/matrixdotorg/synapse which can be used with
+There is an official synapse image available at
+<https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
the docker-compose file available at [contrib/docker](contrib/docker). Further
information on this including configuration options is available in the README
on hub.docker.com.
Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
Dockerfile to automate a synapse server in a single Docker image, at
-https://hub.docker.com/r/avhost/docker-matrix/tags/
+<https://hub.docker.com/r/avhost/docker-matrix/tags/>
Slavi Pantaleev has created an Ansible playbook,
which installs the offical Docker image of Matrix Synapse
along with many other Matrix-related services (Postgres database, Element, coturn,
ma1sd, SSL support, etc.).
For more details, see
-https://github.com/spantaleev/matrix-docker-ansible-deploy
-
+<https://github.com/spantaleev/matrix-docker-ansible-deploy>
-### Debian/Ubuntu
+#### Debian/Ubuntu
-#### Matrix.org packages
+##### Matrix.org packages
Matrix.org provides Debian/Ubuntu packages of the latest stable version of
-Synapse via https://packages.matrix.org/debian/. They are available for Debian
+Synapse via <https://packages.matrix.org/debian/>. They are available for Debian
9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them:
-```
+```sh
sudo apt install -y lsb-release wget apt-transport-https
sudo wget -O /usr/share/keyrings/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/matrix-org-archive-keyring.gpg] https://packages.matrix.org/debian/ $(lsb_release -cs) main" |
@@ -299,7 +291,7 @@ The fingerprint of the repository signing key (as shown by `gpg
/usr/share/keyrings/matrix-org-archive-keyring.gpg`) is
`AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`.
-#### Downstream Debian packages
+##### Downstream Debian packages
We do not recommend using the packages from the default Debian `buster`
repository at this time, as they are old and suffer from known security
@@ -311,49 +303,49 @@ for information on how to use backports.
If you are using Debian `sid` or testing, Synapse is available in the default
repositories and it should be possible to install it simply with:
-```
+```sh
sudo apt install matrix-synapse
```
-#### Downstream Ubuntu packages
+##### Downstream Ubuntu packages
We do not recommend using the packages in the default Ubuntu repository
at this time, as they are old and suffer from known security vulnerabilities.
The latest version of Synapse can be installed from [our repository](#matrixorg-packages).
-### Fedora
+#### Fedora
Synapse is in the Fedora repositories as `matrix-synapse`:
-```
+```sh
sudo dnf install matrix-synapse
```
Oleg Girko provides Fedora RPMs at
-https://obs.infoserver.lv/project/monitor/matrix-synapse
+<https://obs.infoserver.lv/project/monitor/matrix-synapse>
-### OpenSUSE
+#### OpenSUSE
Synapse is in the OpenSUSE repositories as `matrix-synapse`:
-```
+```sh
sudo zypper install matrix-synapse
```
-### SUSE Linux Enterprise Server
+#### SUSE Linux Enterprise Server
Unofficial packages are built for SLES 15 in the openSUSE:Backports:SLE-15 repository at
-https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/
+<https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/>
-### ArchLinux
+#### ArchLinux
The quickest way to get up and running with ArchLinux is probably with the community package
-https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in most of
+<https://www.archlinux.org/packages/community/any/matrix-synapse/>, which should pull in most of
the necessary dependencies.
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ):
-```
+```sh
sudo pip install --upgrade pip
```
@@ -362,28 +354,28 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if
installing under virtualenv):
-```
+```sh
sudo pip uninstall py-bcrypt
sudo pip install py-bcrypt
```
-### Void Linux
+#### Void Linux
Synapse can be found in the void repositories as 'synapse':
-```
+```sh
xbps-install -Su
xbps-install -S synapse
```
-### FreeBSD
+#### FreeBSD
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
- - Packages: `pkg install py37-matrix-synapse`
+- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
+- Packages: `pkg install py37-matrix-synapse`
-### OpenBSD
+#### OpenBSD
As of OpenBSD 6.7 Synapse is available as a pre-compiled binary. The filesystem
underlying the homeserver directory (defaults to `/var/synapse`) has to be
@@ -392,20 +384,35 @@ and mounting it to `/var/synapse` should be taken into consideration.
Installing Synapse:
-```
+```sh
doas pkg_add synapse
```
-### NixOS
+#### NixOS
Robin Lambertz has packaged Synapse for NixOS at:
-https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix
+<https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix>
-# Setting up Synapse
+## Setting up Synapse
Once you have installed synapse as above, you will need to configure it.
-## TLS certificates
+### Using PostgreSQL
+
+By default Synapse uses [SQLite](https://sqlite.org/) and in doing so trades performance for convenience.
+SQLite is only recommended in Synapse for testing purposes or for servers with
+very light workloads.
+
+Almost all installations should opt to use [PostgreSQL](https://www.postgresql.org). Advantages include:
+
+- significant performance improvements due to the superior threading and
+ caching model, smarter query optimiser
+- allowing the DB to be run on separate hardware
+
+For information on how to install and use PostgreSQL in Synapse, please see
+[docs/postgres.md](docs/postgres.md).
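+
+As a rough sketch (assuming a locally installed PostgreSQL and placeholder
+names for the role and database), the dedicated user and database can be
+created with:
+
+```sh
+sudo -u postgres createuser --pwprompt synapse_user
+sudo -u postgres createdb --encoding=UTF8 --locale=C --template=template0 \
+  --owner=synapse_user synapse
+```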
+
+### TLS certificates
The default configuration exposes a single HTTP port on the local
interface: `http://localhost:8008`. It is suitable for local testing,
@@ -419,19 +426,19 @@ The recommended way to do so is to set up a reverse proxy on port
Alternatively, you can configure Synapse to expose an HTTPS port. To do
so, you will need to edit `homeserver.yaml`, as follows:
-* First, under the `listeners` section, uncomment the configuration for the
+- First, under the `listeners` section, uncomment the configuration for the
TLS-enabled listener. (Remove the hash sign (`#`) at the start of
each line). The relevant lines are like this:
- ```
- - port: 8448
- type: http
- tls: true
- resources:
- - names: [client, federation]
+```yaml
+ - port: 8448
+ type: http
+ tls: true
+ resources:
+ - names: [client, federation]
```
-* You will also need to uncomment the `tls_certificate_path` and
+- You will also need to uncomment the `tls_certificate_path` and
`tls_private_key_path` lines under the `TLS` section. You will need to manage
provisioning of these certificates yourself — Synapse had built-in ACME
support, but the ACMEv1 protocol Synapse implements is deprecated, not
@@ -446,7 +453,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
For a more detailed guide to configuring your server for federation, see
[federate.md](docs/federate.md).
-## Client Well-Known URI
+### Client Well-Known URI
Setting up the client Well-Known URI is optional but if you set it up, it will
allow users to enter their full username (e.g. `@user:<server_name>`) into clients
@@ -457,7 +464,7 @@ about the actual homeserver URL you are using.
The URL `https://<server_name>/.well-known/matrix/client` should return JSON in
the following format.
-```
+```json
{
"m.homeserver": {
"base_url": "https://<matrix.example.com>"
@@ -467,7 +474,7 @@ the following format.
It can optionally contain identity server information as well.
-```
+```json
{
"m.homeserver": {
"base_url": "https://<matrix.example.com>"
@@ -484,7 +491,8 @@ Cross-Origin Resource Sharing (CORS) headers. A recommended value would be
view it.
In nginx this would be something like:
-```
+
+```nginx
location /.well-known/matrix/client {
return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}';
default_type application/json;
@@ -497,11 +505,11 @@ correctly. `public_baseurl` should be set to the URL that clients will use to
connect to your server. This is the same URL you put for the `m.homeserver`
`base_url` above.
-```
+```yaml
public_baseurl: "https://<matrix.example.com>"
```
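+
+A quick way to check that the well-known file is being served correctly is to
+fetch it with `curl` (replace the hostname with your own `server_name`):
+
+```sh
+curl https://<server_name>/.well-known/matrix/client
+```
+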
-## Email
+### Email
It is desirable for Synapse to have the capability to send email. This allows
Synapse to send password reset emails, send verifications when an email address
@@ -516,18 +524,28 @@ and `notif_from` fields filled out. You may also need to set `smtp_user`,
If email is not configured, password reset, registration and notifications via
email will be disabled.
-## Registering a user
+### Registering a user
The easiest way to create a new user is to do so from a client like [Element](https://element.io/).
-Alternatively you can do so from the command line if you have installed via pip.
+Alternatively, you can do so from the command line. This can be done as follows:
-This can be done as follows:
+ 1. If Synapse was installed via pip, activate the virtualenv as follows (if Synapse was
+ installed via a prebuilt package, `register_new_matrix_user` should already be
+ on the search path):
+ ```sh
+ cd ~/synapse
+ source env/bin/activate
+ synctl start # if not already running
+ ```
+ 2. Run the following command:
+ ```sh
+ register_new_matrix_user -c homeserver.yaml http://localhost:8008
+ ```
+This will prompt you to add details for the new user, and will then connect to
+the running Synapse to create the new user. For example:
```
-$ source ~/synapse/env/bin/activate
-$ synctl start # if not already running
-$ register_new_matrix_user -c homeserver.yaml http://localhost:8008
New user localpart: erikj
Password:
Confirm password:
@@ -542,12 +560,12 @@ value is generated by `--generate-config`), but it should be kept secret, as
anyone with knowledge of it can register users, including admin accounts,
on your server even if `enable_registration` is `false`.
-## Setting up a TURN server
+### Setting up a TURN server
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
-## URL previews
+### URL previews
Synapse includes support for previewing URLs, which is disabled by default. To
turn it on you must enable the `url_preview_enabled: True` config parameter
@@ -557,19 +575,18 @@ This is critical from a security perspective to stop arbitrary Matrix users
spidering 'internal' URLs on your network. At the very least we recommend that
your loopback and RFC1918 IP addresses are blacklisted.
-This also requires the optional `lxml` and `netaddr` python dependencies to be
-installed. This in turn requires the `libxml2` library to be available - on
-Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for
-your OS.
+This also requires the optional `lxml` python dependency to be installed. This
+in turn requires the `libxml2` library to be available - on Debian/Ubuntu this
+means `apt-get install libxml2-dev`, or equivalent for your OS.
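+
+For example, on a source install into a virtualenv, something along these
+lines should work (a sketch; adjust paths and package names for your system):
+
+```sh
+sudo apt-get install libxml2-dev libxslt1-dev   # build dependencies for lxml on Debian/Ubuntu
+~/synapse/env/bin/pip install lxml
+```
+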
-# Troubleshooting Installation
+### Troubleshooting Installation
`pip` seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.:
-```
+```sh
pip install twisted
```
diff --git a/MANIFEST.in b/MANIFEST.in
index 0a9cf4f51c..09bef29705 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -21,9 +21,10 @@ recursive-include scripts *
recursive-include scripts-dev *
recursive-include synapse *.pyi
recursive-include tests *.py
-include tests/http/ca.crt
-include tests/http/ca.key
-include tests/http/server.key
+recursive-include tests *.pem
+recursive-include tests *.p8
+recursive-include tests *.crt
+recursive-include tests *.key
recursive-include synapse/res *
recursive-include synapse/static *.css
diff --git a/README.rst b/README.rst
index d724cf97da..655a2bf3be 100644
--- a/README.rst
+++ b/README.rst
@@ -183,8 +183,9 @@ Using a reverse proxy with Synapse
It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
-`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
-`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
+`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_,
+`HAProxy <https://www.haproxy.org/>`_ or
+`relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of
doing so is that it means that you can expose the default https port (443) to
Matrix clients without needing to run Synapse with root privileges.
@@ -243,6 +244,8 @@ Then update the ``users`` table in the database::
Synapse Development
===================
+Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_
+
Before setting up a development environment for synapse, make sure you have the
system dependencies (such as the python header files) installed - see
`Installing from source <INSTALL.md#installing-from-source>`_.
@@ -278,6 +281,27 @@ differ)::
PASSED (skips=15, successes=1322)
+We recommend using the demo which starts 3 federated instances running on ports ``8080`` - ``8082``::
+
+    ./demo/start.sh
+
+(to stop, you can use ``./demo/stop.sh``)
+
+If you just want to start a single instance of the app and run it directly::
+
+ # Create the homeserver.yaml config once
+ python -m synapse.app.homeserver \
+ --server-name my.domain.name \
+ --config-path homeserver.yaml \
+ --generate-config \
+ --report-stats=[yes|no]
+
+ # Start the app
+ python -m synapse.app.homeserver --config-path homeserver.yaml
+
Running the Integration Tests
=============================
@@ -290,6 +314,15 @@ Testing with SyTest is recommended for verifying that changes related to the
Client-Server API are functioning correctly. See the `installation instructions
<https://github.com/matrix-org/sytest#installing>`_ for details.
+
+Platform dependencies
+=====================
+
+Synapse uses a number of platform dependencies such as Python and PostgreSQL,
+and aims to follow supported upstream versions. See the
+`<docs/deprecation_policy.md>`_ document for more details.
+
+
Troubleshooting
===============
@@ -365,7 +398,7 @@ likely cause. The misbehavior can be worked around by setting
People can't accept room invitations from me
--------------------------------------------
-The typical failure mode here is that you send an invitation to someone
+The typical failure mode here is that you send an invitation to someone
to join a room or direct chat, but when they go to accept it, they get an
error (typically along the lines of "Invalid signature"). They might see
something like the following in their logs::
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 6825b567e9..ba488e1041 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -5,6 +5,16 @@ Before upgrading check if any special steps are required to upgrade from the
version you currently have installed to the current version of Synapse. The extra
instructions that may be required are listed later in this document.
+* Check that your versions of Python and PostgreSQL are still supported.
+
+ Synapse follows upstream lifecycles for `Python`_ and `PostgreSQL`_, and
+ removes support for versions which are no longer maintained.
+
+ The website https://endoflife.date also offers convenient summaries.
+
+ .. _Python: https://devguide.python.org/devcycle/#end-of-life-branches
+ .. _PostgreSQL: https://www.postgresql.org/support/versioning/
+
* If Synapse was installed using `prebuilt packages
<INSTALL.md#prebuilt-packages>`_, you will need to follow the normal process
for upgrading those packages.
@@ -75,6 +85,172 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.29.0
+====================
+
+Requirement for X-Forwarded-Proto header
+----------------------------------------
+
+When using Synapse with a reverse proxy (in particular, when using the
+``x_forwarded`` option on an HTTP listener), Synapse now expects to receive an
+``X-Forwarded-Proto`` header on incoming HTTP requests. If it is not set, Synapse
+will log a warning on each received request.
+
+To avoid the warning, administrators using a reverse proxy should ensure that
+the reverse proxy sets the ``X-Forwarded-Proto`` header to ``https`` or ``http`` to
+indicate the protocol used by the client.
+
+Synapse also requires the ``Host`` header to be preserved.
+
+See the `reverse proxy documentation <docs/reverse_proxy.md>`_, where the
+example configurations have been updated to show how to set these headers.
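+
+For example, with nginx the relevant directives inside the proxied ``location``
+block are typically along the following lines (a sketch only; see the reverse
+proxy documentation for complete, tested examples)::
+
+    proxy_set_header X-Forwarded-Proto $scheme;
+    proxy_set_header Host $host;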
+
+(Users of `Caddy <https://caddyserver.com/>`_ are unaffected, since we believe it
+sets ``X-Forwarded-Proto`` by default.)
+
+Upgrading to v1.27.0
+====================
+
+Changes to callback URI for OAuth2 / OpenID Connect and SAML2
+-------------------------------------------------------------
+
+This version changes the URI used for callbacks from OAuth2 and SAML2 identity providers:
+
+* If your server is configured for single sign-on via an OpenID Connect or OAuth2 identity
+ provider, you will need to add ``[synapse public baseurl]/_synapse/client/oidc/callback``
+ to the list of permitted "redirect URIs" at the identity provider.
+
+ See `docs/openid.md <docs/openid.md>`_ for more information on setting up OpenID
+ Connect.
+
+* If your server is configured for single sign-on via a SAML2 identity provider, you will
+ need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
+ "ACS location" (also known as "allowed callback URLs") at the identity provider.
+
+ The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to
+ ``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity
+ provider uses this property to validate or otherwise identify Synapse, its configuration
+ will need to be updated to use the new URL. Alternatively you could create a new, separate
+ "EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in
+ the existing "EntityDescriptor" as they were.
+
+Changes to HTML templates
+-------------------------
+
+The HTML templates for SSO and email notifications now have `Jinja2's autoescape <https://jinja.palletsprojects.com/en/2.11.x/api/#autoescaping>`_
+enabled for files ending in ``.html``, ``.htm``, and ``.xml``. If you have customised
+these templates and see issues when viewing them you might need to update them.
+It is expected that most configurations will need no changes.
+
+If you have customised the *names* of these templates, it is recommended to
+verify that they end in ``.html`` to ensure that autoescape is enabled.
+
+The above applies to the following templates:
+
+* ``add_threepid.html``
+* ``add_threepid_failure.html``
+* ``add_threepid_success.html``
+* ``notice_expiry.html``
+* ``notif_mail.html`` (which, by default, includes ``room.html`` and ``notif.html``)
+* ``password_reset.html``
+* ``password_reset_confirmation.html``
+* ``password_reset_failure.html``
+* ``password_reset_success.html``
+* ``registration.html``
+* ``registration_failure.html``
+* ``registration_success.html``
+* ``sso_account_deactivated.html``
+* ``sso_auth_bad_user.html``
+* ``sso_auth_confirm.html``
+* ``sso_auth_success.html``
+* ``sso_error.html``
+* ``sso_login_idp_picker.html``
+* ``sso_redirect_confirm.html``
+
+Upgrading to v1.26.0
+====================
+
+Rolling back to v1.25.0 after a failed upgrade
+----------------------------------------------
+
+v1.26.0 includes a lot of large changes. If something problematic occurs, you
+may want to roll back to a previous version of Synapse. Because v1.26.0 also
+includes a new database schema version, reverting that version is also required
+alongside the generic rollback instructions mentioned above. In short, to roll
+back to v1.25.0 you need to:
+
+1. Stop the server
+2. Decrease the schema version in the database:
+
+ .. code:: sql
+
+ UPDATE schema_version SET version = 58;
+
+3. Delete the ignored users & chain cover data:
+
+ .. code:: sql
+
+ DROP TABLE IF EXISTS ignored_users;
+ UPDATE rooms SET has_auth_chain_index = false;
+
+ For PostgreSQL run:
+
+ .. code:: sql
+
+ TRUNCATE event_auth_chain_links;
+ TRUNCATE event_auth_chains;
+
+ For SQLite run:
+
+ .. code:: sql
+
+ DELETE FROM event_auth_chain_links;
+ DELETE FROM event_auth_chains;
+
+4. Mark the deltas as not run (so they will re-run on upgrade).
+
+ .. code:: sql
+
+      DELETE FROM applied_schema_deltas WHERE version = 59 AND file = '59/01ignored_user.py';
+      DELETE FROM applied_schema_deltas WHERE version = 59 AND file = '59/06chain_cover_index.sql';
+
+5. Downgrade Synapse by following the instructions for your installation method
+ in the "Rolling back to older versions" section above.
+
+Upgrading to v1.25.0
+====================
+
+Last release supporting Python 3.5
+----------------------------------
+
+This is the last release of Synapse which guarantees support for Python 3.5,
+which passed its upstream End of Life date several months ago.
+
+We will attempt to maintain support through March 2021, but without guarantees.
+
+In the future, Synapse will follow upstream schedules for ending support of
+older versions of Python and PostgreSQL. Please upgrade to at least Python 3.6
+and PostgreSQL 9.6 as soon as possible.
+
+Blacklisting IP ranges
+----------------------
+
+Synapse v1.25.0 includes new settings, ``ip_range_blacklist`` and
+``ip_range_whitelist``, for controlling outgoing requests from Synapse for federation,
+identity servers, push, and for checking key validity for third-party invite events.
+The previous setting, ``federation_ip_range_blacklist``, is deprecated. The new
+``ip_range_blacklist`` defaults to private IP ranges if it is not defined.
+
+If you have never customised ``federation_ip_range_blacklist`` it is recommended
+that you remove that setting.
+
+If you have customised ``federation_ip_range_blacklist`` you should update the
+setting name to ``ip_range_blacklist``.
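+
+As a sketch (the IP range shown is only an example), a configuration that
+previously contained::
+
+    federation_ip_range_blacklist:
+      - '10.0.0.0/8'
+
+would simply be renamed to::
+
+    ip_range_blacklist:
+      - '10.0.0.0/8'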
+
+If you have a custom push server that is reached via private IP space you may
+need to customise ``ip_range_blacklist`` or ``ip_range_whitelist``.
+
Upgrading to v1.24.0
====================
@@ -105,7 +281,7 @@ shown below:
return {"localpart": localpart}
-Removal historical Synapse Admin API
+Removal historical Synapse Admin API
------------------------------------
Historically, the Synapse Admin API has been accessible under:
diff --git a/changelog.d/8821.bugfix b/changelog.d/8821.bugfix
deleted file mode 100644
index 8ddfbf31ce..0000000000
--- a/changelog.d/8821.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Apply the `federation_ip_range_blacklist` to push and key revocation requests.
diff --git a/changelog.d/8906.misc b/changelog.d/8906.misc
deleted file mode 100644
index 8b95e4c553..0000000000
--- a/changelog.d/8906.misc
+++ /dev/null
@@ -1 +0,0 @@
-Fix multiarch docker image builds.
diff --git a/changelog.d/8986.misc b/changelog.d/8986.misc
deleted file mode 100644
index 6aefc78784..0000000000
--- a/changelog.d/8986.misc
+++ /dev/null
@@ -1 +0,0 @@
-Support using PyJWT v2.0.0 in the test suite.
diff --git a/changelog.d/9146.misc b/changelog.d/9146.misc
deleted file mode 100644
index 7af29baa30..0000000000
--- a/changelog.d/9146.misc
+++ /dev/null
@@ -1 +0,0 @@
-Fix the Python 3.5 + old dependencies build in CI.
diff --git a/changelog.d/9217.misc b/changelog.d/9217.misc
deleted file mode 100644
index 72bacc7110..0000000000
--- a/changelog.d/9217.misc
+++ /dev/null
@@ -1 +0,0 @@
-Fix the Python 3.5 old dependencies build.
diff --git a/changelog.d/9372.feature b/changelog.d/9372.feature
deleted file mode 100644
index 3cb01004c9..0000000000
--- a/changelog.d/9372.feature
+++ /dev/null
@@ -1 +0,0 @@
-The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients on top of existing support for the uppercase `HTTPS_PROXY` form and takes precedence if both are present. Contributed by Timothy Leung.
diff --git a/changelog.d/9657.feature b/changelog.d/9657.feature
deleted file mode 100644
index c56a615a8b..0000000000
--- a/changelog.d/9657.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add support for credentials for proxy authentication in the `HTTPS_PROXY` environment variable.
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index ab1e1f1f4c..67e032244e 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -92,7 +92,7 @@ class SynapseCmd(cmd.Cmd):
return self.config["user"].split(":")[1]
def do_config(self, line):
- """ Show the config for this client: "config"
+ """Show the config for this client: "config"
Edit a key value mapping: "config key value" e.g. "config token 1234"
Config variables:
user: The username to auth with.
@@ -360,7 +360,7 @@ class SynapseCmd(cmd.Cmd):
print(e)
def do_topic(self, line):
- """"topic [set|get] <roomid> [<newtopic>]"
+ """ "topic [set|get] <roomid> [<newtopic>]"
Set the topic for a room: topic set <roomid> <newtopic>
Get the topic for a room: topic get <roomid>
"""
@@ -690,7 +690,7 @@ class SynapseCmd(cmd.Cmd):
self._do_presence_state(2, line)
def _parse(self, line, keys, force_keys=False):
- """ Parses the given line.
+ """Parses the given line.
Args:
line : The line to parse
@@ -721,7 +721,7 @@ class SynapseCmd(cmd.Cmd):
query_params={"access_token": None},
alt_text=None,
):
- """ Runs an HTTP request and pretty prints the output.
+ """Runs an HTTP request and pretty prints the output.
Args:
method: HTTP method
diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py
index 345120b612..851e80c25b 100644
--- a/contrib/cmdclient/http.py
+++ b/contrib/cmdclient/http.py
@@ -23,11 +23,10 @@ from twisted.web.http_headers import Headers
class HttpClient:
- """ Interface for talking json over http
- """
+ """Interface for talking json over http"""
def put_json(self, url, data):
- """ Sends the specifed json data using PUT
+        """Sends the specified json data using PUT
Args:
url (str): The URL to PUT data to.
@@ -41,7 +40,7 @@ class HttpClient:
pass
def get_json(self, url, args=None):
- """ Gets some json from the given host homeserver and path
+ """Gets some json from the given host homeserver and path
Args:
url (str): The URL to GET data from.
@@ -58,7 +57,7 @@ class HttpClient:
class TwistedHttpClient(HttpClient):
- """ Wrapper around the twisted HTTP client api.
+ """Wrapper around the twisted HTTP client api.
Attributes:
agent (twisted.web.client.Agent): The twisted Agent used to send the
@@ -87,8 +86,7 @@ class TwistedHttpClient(HttpClient):
defer.returnValue(json.loads(body))
def _create_put_request(self, url, json_data, headers_dict={}):
- """ Wrapper of _create_request to issue a PUT request
- """
+ """Wrapper of _create_request to issue a PUT request"""
if "Content-Type" not in headers_dict:
raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
@@ -98,8 +96,7 @@ class TwistedHttpClient(HttpClient):
)
def _create_get_request(self, url, headers_dict={}):
- """ Wrapper of _create_request to issue a GET request
- """
+ """Wrapper of _create_request to issue a GET request"""
return self._create_request("GET", url, headers_dict=headers_dict)
@defer.inlineCallbacks
@@ -127,8 +124,7 @@ class TwistedHttpClient(HttpClient):
@defer.inlineCallbacks
def _create_request(self, method, url, producer=None, headers_dict={}):
- """ Creates and sends a request to the given url
- """
+ """Creates and sends a request to the given url"""
headers_dict["User-Agent"] = ["Synapse Cmd Client"]
retries_left = 5
@@ -185,8 +181,7 @@ class _RawProducer:
class _JsonProducer:
- """ Used by the twisted http client to create the HTTP body from json
- """
+ """Used by the twisted http client to create the HTTP body from json"""
def __init__(self, jsn):
self.data = jsn
diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py
index 15a22c3a0e..cff73650e6 100644
--- a/contrib/experiments/cursesio.py
+++ b/contrib/experiments/cursesio.py
@@ -63,8 +63,7 @@ class CursesStdIO:
self.redraw()
def redraw(self):
- """ method for redisplaying lines
- based on internal list of lines """
+ """method for redisplaying lines based on internal list of lines"""
self.stdscr.clear()
self.paintStatus(self.statusText)
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
index d4c35ff2fc..7fbc7d8fc6 100644
--- a/contrib/experiments/test_messaging.py
+++ b/contrib/experiments/test_messaging.py
@@ -56,7 +56,7 @@ def excpetion_errback(failure):
class InputOutput:
- """ This is responsible for basic I/O so that a user can interact with
+ """This is responsible for basic I/O so that a user can interact with
the example app.
"""
@@ -68,8 +68,7 @@ class InputOutput:
self.server = server
def on_line(self, line):
- """ This is where we process commands.
- """
+ """This is where we process commands."""
try:
m = re.match(r"^join (\S+)$", line)
@@ -133,7 +132,7 @@ class IOLoggerHandler(logging.Handler):
class Room:
- """ Used to store (in memory) the current membership state of a room, and
+ """Used to store (in memory) the current membership state of a room, and
which home servers we should send PDUs associated with the room to.
"""
@@ -148,8 +147,7 @@ class Room:
self.have_got_metadata = False
def add_participant(self, participant):
- """ Someone has joined the room
- """
+ """Someone has joined the room"""
self.participants.add(participant)
self.invited.discard(participant)
@@ -160,14 +158,13 @@ class Room:
self.oldest_server = server
def add_invited(self, invitee):
- """ Someone has been invited to the room
- """
+ """Someone has been invited to the room"""
self.invited.add(invitee)
self.servers.add(origin_from_ucid(invitee))
class HomeServer(ReplicationHandler):
- """ A very basic home server implentation that allows people to join a
+    """A very basic home server implementation that allows people to join a
room and then invite other people.
"""
@@ -181,8 +178,7 @@ class HomeServer(ReplicationHandler):
self.output = output
def on_receive_pdu(self, pdu):
- """ We just received a PDU
- """
+ """We just received a PDU"""
pdu_type = pdu.pdu_type
if pdu_type == "sy.room.message":
@@ -199,23 +195,20 @@ class HomeServer(ReplicationHandler):
)
def _on_message(self, pdu):
- """ We received a message
- """
+ """We received a message"""
self.output.print_line(
"#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"])
)
def _on_join(self, context, joinee):
- """ Someone has joined a room, either a remote user or a local user
- """
+ """Someone has joined a room, either a remote user or a local user"""
room = self._get_or_create_room(context)
room.add_participant(joinee)
self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED"))
def _on_invite(self, origin, context, invitee):
- """ Someone has been invited
- """
+ """Someone has been invited"""
room = self._get_or_create_room(context)
room.add_invited(invitee)
@@ -228,8 +221,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks
def send_message(self, room_name, sender, body):
- """ Send a message to a room!
- """
+ """Send a message to a room!"""
destinations = yield self.get_servers_for_context(room_name)
try:
@@ -247,8 +239,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks
def join_room(self, room_name, sender, joinee):
- """ Join a room!
- """
+ """Join a room!"""
self._on_join(room_name, joinee)
destinations = yield self.get_servers_for_context(room_name)
@@ -269,8 +260,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks
def invite_to_room(self, room_name, sender, invitee):
- """ Invite someone to a room!
- """
+ """Invite someone to a room!"""
self._on_invite(self.server_name, room_name, invitee)
destinations = yield self.get_servers_for_context(room_name)
diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py
index b3de468687..495fd4e10a 100644
--- a/contrib/jitsimeetbridge/jitsimeetbridge.py
+++ b/contrib/jitsimeetbridge/jitsimeetbridge.py
@@ -193,15 +193,12 @@ class TrivialXmppClient:
time.sleep(7)
print("SSRC spammer started")
while self.running:
- ssrcMsg = (
- "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>"
- % {
- "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
- "nick": self.userId,
- "assrc": self.ssrcs["audio"],
- "vssrc": self.ssrcs["video"],
- }
- )
+ ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % {
+ "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
+ "nick": self.userId,
+ "assrc": self.ssrcs["audio"],
+ "vssrc": self.ssrcs["video"],
+ }
res = self.sendIq(ssrcMsg)
print("reply from ssrc announce: ", res)
time.sleep(10)
diff --git a/contrib/prometheus/synapse-v2.rules b/contrib/prometheus/synapse-v2.rules
index 6ccca2daaf..7e405bf7f0 100644
--- a/contrib/prometheus/synapse-v2.rules
+++ b/contrib/prometheus/synapse-v2.rules
@@ -58,3 +58,21 @@ groups:
labels:
type: "PDU"
expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
+
+ - record: synapse_storage_events_persisted_by_source_type
+ expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_type="remote"})
+ labels:
+ type: remote
+ - record: synapse_storage_events_persisted_by_source_type
+ expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity="*client*",origin_type="local"})
+ labels:
+ type: local
+ - record: synapse_storage_events_persisted_by_source_type
+ expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity!="*client*",origin_type="local"})
+ labels:
+ type: bridges
+ - record: synapse_storage_events_persisted_by_event_type
+ expr: sum without(origin_entity, origin_type) (synapse_storage_events_persisted_events_sep)
+ - record: synapse_storage_events_persisted_by_origin
+ expr: sum without(type) (synapse_storage_events_persisted_events_sep)
+
diff --git a/contrib/purge_api/purge_history.sh b/contrib/purge_api/purge_history.sh
index e7dd5d6468..c45136ff53 100644
--- a/contrib/purge_api/purge_history.sh
+++ b/contrib/purge_api/purge_history.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
# this script will use the api:
# https://github.com/matrix-org/synapse/blob/master/docs/admin_api/purge_history_api.rst
diff --git a/contrib/purge_api/purge_remote_media.sh b/contrib/purge_api/purge_remote_media.sh
index 77220d3bd5..4930d9529c 100644
--- a/contrib/purge_api/purge_remote_media.sh
+++ b/contrib/purge_api/purge_remote_media.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
DOMAIN=yourserver.tld
# add this user as admin in your home server:
diff --git a/debian/build_virtualenv b/debian/build_virtualenv
index cbdde93f96..cad7d16883 100755
--- a/debian/build_virtualenv
+++ b/debian/build_virtualenv
@@ -33,11 +33,13 @@ esac
# Use --builtin-venv to use the better `venv` module from CPython 3.4+ rather
# than the 2/3 compatible `virtualenv`.
+# Pin pip to 20.3.4 to fix breakage in 21.0 on py3.5 (xenial)
+
dh_virtualenv \
--install-suffix "matrix-synapse" \
--builtin-venv \
--python "$SNAKE" \
- --upgrade-pip \
+ --upgrade-pip-to="20.3.4" \
--preinstall="lxml" \
--preinstall="mock" \
--extra-pip-arg="--no-cache-dir" \
@@ -56,10 +58,10 @@ trap "rm -r $tmpdir" EXIT
cp -r tests "$tmpdir"
PYTHONPATH="$tmpdir" \
- "${TARGET_PYTHON}" -B -m twisted.trial --reporter=text -j2 tests
+ "${TARGET_PYTHON}" -m twisted.trial --reporter=text -j2 tests
# build the config file
-"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_config" \
+"${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_config" \
--config-dir="/etc/matrix-synapse" \
--data-dir="/var/lib/matrix-synapse" |
perl -pe '
@@ -85,7 +87,7 @@ PYTHONPATH="$tmpdir" \
' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml"
# build the log config file
-"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_log_config" \
+"${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_log_config" \
--output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml"
# add a dependency on the right version of python to substvars.
diff --git a/debian/changelog b/debian/changelog
index 6b819d201d..09602ff54b 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,68 @@
+matrix-synapse-py3 (1.31.0) stable; urgency=medium
+
+ * New synapse release 1.31.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 06 Apr 2021 13:08:29 +0100
+
+matrix-synapse-py3 (1.30.1) stable; urgency=medium
+
+ * New synapse release 1.30.1.
+
+ -- Synapse Packaging team <packages@matrix.org> Fri, 26 Mar 2021 12:01:28 +0000
+
+matrix-synapse-py3 (1.30.0) stable; urgency=medium
+
+ * New synapse release 1.30.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Mon, 22 Mar 2021 13:15:34 +0000
+
+matrix-synapse-py3 (1.29.0) stable; urgency=medium
+
+ [ Jonathan de Jong ]
+ * Remove the python -B flag (don't generate bytecode) in scripts and documentation.
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.29.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Mon, 08 Mar 2021 13:51:50 +0000
+
+matrix-synapse-py3 (1.28.0) stable; urgency=medium
+
+ * New synapse release 1.28.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 25 Feb 2021 10:21:57 +0000
+
+matrix-synapse-py3 (1.27.0) stable; urgency=medium
+
+ [ Dan Callahan ]
+ * Fix build on Ubuntu 16.04 LTS (Xenial).
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.27.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 16 Feb 2021 13:11:28 +0000
+
+matrix-synapse-py3 (1.26.0) stable; urgency=medium
+
+ [ Richard van der Hoff ]
+ * Remove dependency on `python3-distutils`.
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.26.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Wed, 27 Jan 2021 12:43:35 -0500
+
+matrix-synapse-py3 (1.25.0) stable; urgency=medium
+
+ [ Dan Callahan ]
+ * Update dependencies to account for the removal of the transitional
+ dh-systemd package from Debian Bullseye.
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.25.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Wed, 13 Jan 2021 10:14:55 +0000
+
matrix-synapse-py3 (1.24.0) stable; urgency=medium
* New synapse release 1.24.0.
diff --git a/debian/control b/debian/control
index bae14b41e4..8167a901a4 100644
--- a/debian/control
+++ b/debian/control
@@ -3,9 +3,11 @@ Section: contrib/python
Priority: extra
Maintainer: Synapse Packaging team <packages@matrix.org>
# keep this list in sync with the build dependencies in docker/Dockerfile-dhvirtualenv.
+# TODO: Remove the dependency on dh-systemd after dropping support for Ubuntu xenial
+# On all other supported releases, it's merely a transitional package which
+# does nothing but depend on debhelper (> 9.20160709)
Build-Depends:
- debhelper (>= 9),
- dh-systemd,
+ debhelper (>= 9.20160709) | dh-systemd,
dh-virtualenv (>= 1.1),
libsystemd-dev,
libpq-dev,
@@ -29,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1)
Depends:
adduser,
debconf,
- python3-distutils|libpython3-stdlib (<< 3.6),
${misc:Depends},
${shlibs:Depends},
${synapse:pydepends},
diff --git a/debian/synctl.1 b/debian/synctl.1
index 437f8f9e0e..af58c8d224 100644
--- a/debian/synctl.1
+++ b/debian/synctl.1
@@ -44,7 +44,7 @@ Configuration file may be generated as follows:
.
.nf
-$ python \-B \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name>
+$ python \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name>
.
.fi
.
diff --git a/debian/synctl.ronn b/debian/synctl.ronn
index 1bad6094f3..10cbda988f 100644
--- a/debian/synctl.ronn
+++ b/debian/synctl.ronn
@@ -41,7 +41,7 @@ process.
Configuration file may be generated as follows:
- $ python -B -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name>
+ $ python -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name>
## ENVIRONMENT
diff --git a/demo/clean.sh b/demo/clean.sh
index 418ca9457e..6b809f6e83 100755
--- a/demo/clean.sh
+++ b/demo/clean.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/demo/start.sh b/demo/start.sh
index f6b5ea137f..621a5698b8 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
DIR="$( cd "$( dirname "$0" )" && pwd )"
diff --git a/demo/stop.sh b/demo/stop.sh
index 85a1d2c161..f9dddc5914 100755
--- a/demo/stop.sh
+++ b/demo/stop.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
DIR="$( cd "$( dirname "$0" )" && pwd )"
diff --git a/demo/webserver.py b/demo/webserver.py
deleted file mode 100644
index ba176d3bd2..0000000000
--- a/demo/webserver.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import argparse
-import BaseHTTPServer
-import os
-import SimpleHTTPServer
-import cgi, logging
-
-from daemonize import Daemonize
-
-
-class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
- UPLOAD_PATH = "upload"
-
- """
- Accept all post request as file upload
- """
-
- def do_POST(self):
-
- path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
- length = self.headers["content-length"]
- data = self.rfile.read(int(length))
-
- with open(path, "wb") as fh:
- fh.write(data)
-
- self.send_response(200)
- self.send_header("Content-Type", "application/json")
- self.end_headers()
-
- # Return the absolute path of the uploaded file
- self.wfile.write('{"url":"/%s"}' % path)
-
-
-def setup():
- parser = argparse.ArgumentParser()
- parser.add_argument("directory")
- parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
- parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
- args = parser.parse_args()
-
- # Get absolute path to directory to serve, as daemonize changes to '/'
- os.chdir(args.directory)
- dr = os.getcwd()
-
- httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
-
- def run():
- os.chdir(dr)
- httpd.serve_forever()
-
- daemon = Daemonize(
- app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
- )
-
- daemon.start()
-
-
-if __name__ == "__main__":
- setup()
diff --git a/docker/Dockerfile b/docker/Dockerfile
index afd896ffc1..5b7bf02776 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -18,6 +18,11 @@ ARG PYTHON_VERSION=3.8
###
FROM docker.io/python:${PYTHON_VERSION}-slim as builder
+LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
+LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'
+LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
+LABEL org.opencontainers.image.licenses='Apache-2.0'
+
# install the OS build deps
RUN apt-get update && apt-get install -y \
build-essential \
@@ -28,31 +33,32 @@ RUN apt-get update && apt-get install -y \
libwebp-dev \
libxml++2.6-dev \
libxslt1-dev \
+ openssl \
+ rustc \
zlib1g-dev \
- && rm -rf /var/lib/apt/lists/*
+ && rm -rf /var/lib/apt/lists/*
-# Build dependencies that are not available as wheels, to speed up rebuilds
-RUN pip install --prefix="/install" --no-warn-script-location \
- frozendict \
- jaeger-client \
- opentracing \
- # Match the version constraints of Synapse
- "prometheus_client>=0.4.0" \
- psycopg2 \
- pycparser \
- pyrsistent \
- pyyaml \
- simplejson \
- threadloop \
- thrift
-
-# now install synapse and all of the python deps to /install.
-COPY synapse /synapse/synapse/
+# Copy just what we need to pip install
COPY scripts /synapse/scripts/
COPY MANIFEST.in README.rst setup.py synctl /synapse/
+COPY synapse/__init__.py /synapse/synapse/__init__.py
+COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py
+# To speed up rebuilds, install all of the dependencies before we copy over
+# the whole synapse project, so that this layer in the Docker cache can be
+# reused while you develop on the source
+#
+# This aims to install the `install_requires` and `extras_require` from `setup.py`
RUN pip install --prefix="/install" --no-warn-script-location \
- /synapse[all]
+ /synapse[all]
+
+# Copy over the rest of the project
+COPY synapse /synapse/synapse/
+
+# Install the synapse package itself and all of its children packages.
+#
+# This aims to install only the `packages=find_packages(...)` from `setup.py`
+RUN pip install --prefix="/install" --no-deps --no-warn-script-location /synapse
###
### Stage 1: runtime
@@ -67,7 +73,10 @@ RUN apt-get update && apt-get install -y \
libpq5 \
libwebp6 \
xmlsec1 \
- && rm -rf /var/lib/apt/lists/*
+ libjemalloc2 \
+ libssl-dev \
+ openssl \
+ && rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local
COPY ./docker/start.py /start.py
@@ -80,4 +89,4 @@ EXPOSE 8008/tcp 8009/tcp 8448/tcp
ENTRYPOINT ["/start.py"]
HEALTHCHECK --interval=1m --timeout=5s \
- CMD curl -fSs http://localhost:8008/health || exit 1
+ CMD curl -fSs http://localhost:8008/health || exit 1
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index 619585d5fa..0d74630370 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -27,6 +27,7 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
wget
# fetch and unpack the package
+# TODO: Upgrade to 1.2.2 once xenial is dropped
RUN mkdir /dh-virtualenv
RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
@@ -50,17 +51,22 @@ FROM ${distro}
ARG distro=""
ENV distro ${distro}
+# Python < 3.7 assumes LANG="C" means ASCII-only and throws on printing unicode
+# http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
# Install the build dependencies
#
# NB: keep this list in sync with the list of build-deps in debian/control
# TODO: it would be nice to do that automatically.
+# TODO: Remove the dh-systemd stanza after dropping support for Ubuntu xenial
+# it's a transitional package on all other, more recent releases
RUN apt-get update -qq -o Acquire::Languages=none \
&& env DEBIAN_FRONTEND=noninteractive apt-get install \
-yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \
build-essential \
debhelper \
devscripts \
- dh-systemd \
libsystemd-dev \
lsb-release \
pkg-config \
@@ -69,7 +75,11 @@ RUN apt-get update -qq -o Acquire::Languages=none \
python3-setuptools \
python3-venv \
sqlite3 \
- libpq-dev
+ libpq-dev \
+ xmlsec1 \
+ && ( env DEBIAN_FRONTEND=noninteractive apt-get install \
+ -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \
+ dh-systemd || true )
COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /
diff --git a/docker/README.md b/docker/README.md
index c8f27b8566..3a7dc585e7 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -11,7 +11,6 @@ The image also does *not* provide a TURN server.
By default, the image expects a single volume, located at ``/data``, that will hold:
* configuration files;
-* temporary files during uploads;
* uploaded media and thumbnails;
* the SQLite database if you do not configure postgres;
* the appservices configuration.
@@ -205,3 +204,8 @@ healthcheck:
timeout: 10s
retries: 3
```
+
+## Using jemalloc
+
+Jemalloc is embedded in the image and will be used instead of the default allocator.
+You can read more about jemalloc in the Synapse [README](../README.md).
\ No newline at end of file
diff --git a/docker/build_debian.sh b/docker/build_debian.sh
index f312f0715f..f426d2b77b 100644
--- a/docker/build_debian.sh
+++ b/docker/build_debian.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
# The script to build the Debian package, as ran inside the Docker image.
diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index a808485c12..0dea62a87d 100644
--- a/docker/conf/homeserver.yaml
+++ b/docker/conf/homeserver.yaml
@@ -89,7 +89,6 @@ federation_rc_concurrent: 3
## Files ##
media_store_path: "/data/media"
-uploads_path: "/data/uploads"
max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}"
max_image_pixels: "32M"
dynamic_thumbnails: false
@@ -198,12 +197,10 @@ old_signing_keys: {}
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
-perspectives:
- servers:
- "matrix.org":
- verify_keys:
- "ed25519:auto":
- key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+trusted_key_servers:
+ - server_name: matrix.org
+ verify_keys:
+ "ed25519:auto": "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
password_config:
enabled: true
diff --git a/docker/run_pg_tests.sh b/docker/run_pg_tests.sh
index d18d1e4c8e..1fd08cb62b 100755
--- a/docker/run_pg_tests.sh
+++ b/docker/run_pg_tests.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
# This script runs the PostgreSQL tests inside a Docker container. It expects
# the relevant source files to be mounted into /src (done automatically by the
diff --git a/docker/start.py b/docker/start.py
index 0d2c590b88..16d6a8208a 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -3,6 +3,7 @@
import codecs
import glob
import os
+import platform
import subprocess
import sys
@@ -213,6 +214,13 @@ def main(args, environ):
if "-m" not in args:
args = ["-m", synapse_worker] + args
+ jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
+
+ if os.path.isfile(jemallocpath):
+ environ["LD_PRELOAD"] = jemallocpath
+ else:
+ log("Could not find %s, will not use" % (jemallocpath,))
+
# if there are no config files passed to synapse, try adding the default file
if not any(p.startswith("--config-path") or p.startswith("-c") for p in args):
config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
@@ -248,9 +256,9 @@ running with 'migrate_config'. See the README for more details.
args = ["python"] + args
if ownership is not None:
args = ["gosu", ownership] + args
- os.execv("/usr/sbin/gosu", args)
+ os.execve("/usr/sbin/gosu", args, environ)
else:
- os.execv("/usr/local/bin/python", args)
+ os.execve("/usr/local/bin/python", args, environ)
if __name__ == "__main__":
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 71137c6dfc..9dbec68c19 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -1,4 +1,22 @@
-# List all media in a room
+# Contents
+- [Querying media](#querying-media)
+ * [List all media in a room](#list-all-media-in-a-room)
+ * [List all media uploaded by a user](#list-all-media-uploaded-by-a-user)
+- [Quarantine media](#quarantine-media)
+ * [Quarantining media by ID](#quarantining-media-by-id)
+ * [Quarantining media in a room](#quarantining-media-in-a-room)
+ * [Quarantining all media of a user](#quarantining-all-media-of-a-user)
+ * [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
+- [Delete local media](#delete-local-media)
+ * [Delete a specific local media](#delete-a-specific-local-media)
+ * [Delete local media by date or size](#delete-local-media-by-date-or-size)
+- [Purge Remote Media API](#purge-remote-media-api)
+
+# Querying media
+
+These APIs allow extracting media information from the homeserver.
+
+## List all media in a room
This API gets a list of known media in a room.
However, it only shows media from unencrypted events or rooms.
@@ -11,19 +29,25 @@ To use it, you will need to authenticate by providing an `access_token` for a
server admin: see [README.rst](README.rst).
The API returns a JSON body like the following:
-```
+```json
{
- "local": [
- "mxc://localhost/xwvutsrqponmlkjihgfedcba",
- "mxc://localhost/abcdefghijklmnopqrstuvwx"
- ],
- "remote": [
- "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
- "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
- ]
+ "local": [
+ "mxc://localhost/xwvutsrqponmlkjihgfedcba",
+ "mxc://localhost/abcdefghijklmnopqrstuvwx"
+ ],
+ "remote": [
+ "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
+ "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
+ ]
}
```
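+
+For example, the endpoint can be called with `curl` along these lines (a
+sketch, assuming Synapse is listening on `localhost:8008`):
+
+```sh
+curl --header "Authorization: Bearer <access_token>" \
+    "http://localhost:8008/_synapse/admin/v1/room/<room_id>/media"
+```
+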
+## List all media uploaded by a user
+
+Listing all media that has been uploaded by a local user can be achieved through
+the use of the [List media of a user](user_admin_api.rst#list-media-of-a-user)
+Admin API.
+
# Quarantine media
Quarantining media means that it is marked as inaccessible by users. It applies
@@ -48,7 +72,7 @@ form of `abcdefg12345...`.
Response:
-```
+```json
{}
```
@@ -68,14 +92,18 @@ Where `room_id` is in the form of `!roomid12345:example.org`.
Response:
-```
+```json
{
- "num_quarantined": 10 # The number of media items successfully quarantined
+ "num_quarantined": 10
}
```
+The following fields are returned in the JSON response body:
+
+* `num_quarantined`: integer - The number of media items successfully quarantined
+
Note that there is a legacy endpoint, `POST
-/_synapse/admin/v1/quarantine_media/<room_id >`, that operates the same.
+/_synapse/admin/v1/quarantine_media/<room_id>`, that operates the same.
However, it is deprecated and may be removed in a future release.
## Quarantining all media of a user
@@ -92,23 +120,52 @@ POST /_synapse/admin/v1/user/<user_id>/media/quarantine
{}
```
-Where `user_id` is in the form of `@bob:example.org`.
+URL Parameters
+
+* `user_id`: string - User ID in the form of `@bob:example.org`
Response:
-```
+```json
{
- "num_quarantined": 10 # The number of media items successfully quarantined
+ "num_quarantined": 10
}
```
+The following fields are returned in the JSON response body:
+
+* `num_quarantined`: integer - The number of media items successfully quarantined
+
+## Protecting media from being quarantined
+
+This API protects a single piece of local media from being quarantined using the
+above APIs. This is useful for sticker packs and other shared media which you do
+not want to get quarantined, especially when
+[quarantining media in a room](#quarantining-media-in-a-room).
+
+Request:
+
+```
+POST /_synapse/admin/v1/media/protect/<media_id>
+
+{}
+```
+
+Where `media_id` is in the form of `abcdefg12345...`.
+
+Response:
+
+```json
+{}
+```
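+
+For example, with `curl` (a sketch, assuming Synapse is listening on
+`localhost:8008`):
+
+```sh
+curl --header "Authorization: Bearer <access_token>" --data '{}' \
+    "http://localhost:8008/_synapse/admin/v1/media/protect/<media_id>"
+```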
+
# Delete local media
This API deletes the *local* media from the disk of your own server.
This includes any local thumbnails and copies of media downloaded from
remote homeservers.
This API will not affect media that has been uploaded to external
media repositories (e.g https://github.com/turt2live/matrix-media-repo/).
-See also [purge_remote_media.rst](purge_remote_media.rst).
+See also [Purge Remote Media API](#purge-remote-media-api).
## Delete a specific local media
Delete a specific `media_id`.
@@ -129,12 +186,12 @@ URL Parameters
Response:
```json
- {
- "deleted_media": [
- "abcdefghijklmnopqrstuvwx"
- ],
- "total": 1
- }
+{
+ "deleted_media": [
+ "abcdefghijklmnopqrstuvwx"
+ ],
+ "total": 1
+}
```
The following fields are returned in the JSON response body:
@@ -167,16 +224,51 @@ If `false` these files will be deleted. Defaults to `true`.
Response:
```json
- {
- "deleted_media": [
- "abcdefghijklmnopqrstuvwx",
- "abcdefghijklmnopqrstuvwz"
- ],
- "total": 2
- }
+{
+ "deleted_media": [
+ "abcdefghijklmnopqrstuvwx",
+ "abcdefghijklmnopqrstuvwz"
+ ],
+ "total": 2
+}
```
The following fields are returned in the JSON response body:
* `deleted_media`: an array of strings - List of deleted `media_id`
* `total`: integer - Total number of deleted `media_id`
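+
+For example, to delete local media on `<server_name>` that was last accessed
+before a given timestamp and is larger than roughly 1 MB, something along
+these lines should work (a sketch, assuming Synapse is listening on
+`localhost:8008`):
+
+```sh
+curl --header "Authorization: Bearer <access_token>" --data '{}' \
+    "http://localhost:8008/_synapse/admin/v1/media/<server_name>/delete?before_ts=<unix_timestamp_in_ms>&size_gt=1048576"
+```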
+
+# Purge Remote Media API
+
+The purge remote media API allows server admins to purge old cached remote media.
+
+The API is:
+
+```
+POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>
+
+{}
+```
+
+URL Parameters
+
+* `unix_timestamp_in_ms`: string representing a positive integer - Unix timestamp in ms.
+All cached media that was last accessed before this timestamp will be removed.
+
+Response:
+
+```json
+{
+ "deleted": 10
+}
+```
+
+The following fields are returned in the JSON response body:
+
+* `deleted`: integer - The number of media items successfully deleted
+
+To use it, you will need to authenticate by providing an `access_token` for a
+server admin: see [README.rst](README.rst).
+
+If the user re-requests purged remote media, synapse will re-request the media
+from the originating server.
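+
+As an illustration, the request can be scripted as follows. This is a minimal
+sketch using Python's `requests` library; the homeserver URL, access token and
+timestamp below are placeholders to adapt.
+
+```python
+import requests
+
+base_url = "https://matrix.example.com"  # your homeserver's public base URL
+access_token = "<admin access token>"    # a server admin's access token
+before_ts = 1611263016761                # purge media last accessed before this time (ms)
+
+resp = requests.post(
+    f"{base_url}/_synapse/admin/v1/purge_media_cache",
+    params={"before_ts": before_ts},
+    headers={"Authorization": f"Bearer {access_token}"},
+    json={},
+)
+resp.raise_for_status()
+print(resp.json()["deleted"])  # number of media items successfully deleted
+```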
diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst
deleted file mode 100644
index 00cb6b0589..0000000000
--- a/docs/admin_api/purge_remote_media.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-Purge Remote Media API
-======================
-
-The purge remote media API allows server admins to purge old cached remote
-media.
-
-The API is::
-
- POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>
-
- {}
-
-\... which will remove all cached media that was last accessed before
-``<unix_timestamp_in_ms>``.
-
-To use it, you will need to authenticate by providing an ``access_token`` for a
-server admin: see `README.rst <README.rst>`_.
-
-If the user re-requests purged remote media, synapse will re-request the media
-from the originating server.
diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md
index ae01a543c6..54fea2db6d 100644
--- a/docs/admin_api/purge_room.md
+++ b/docs/admin_api/purge_room.md
@@ -1,12 +1,13 @@
-Purge room API
-==============
+Deprecated: Purge room API
+==========================
+
+**The old Purge room API is deprecated and will be removed in a future release.
+See the new [Delete Room API](rooms.md#delete-room-api) for more details.**
This API will remove all trace of a room from your database.
All local users must have left the room before it can be removed.
-See also: [Delete Room API](rooms.md#delete-room-api)
-
The API is:
```
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 004a802e17..bc737b30f5 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -1,3 +1,17 @@
+# Contents
+- [List Room API](#list-room-api)
+ * [Parameters](#parameters)
+ * [Usage](#usage)
+- [Room Details API](#room-details-api)
+- [Room Members API](#room-members-api)
+- [Delete Room API](#delete-room-api)
+ * [Parameters](#parameters-1)
+ * [Response](#response)
+ * [Undoing room shutdowns](#undoing-room-shutdowns)
+- [Make Room Admin API](#make-room-admin-api)
+- [Forward Extremities Admin API](#forward-extremities-admin-api)
+- [Event Context API](#event-context-api)
+
# List Room API
The List Room admin API allows server admins to get a list of rooms on their
@@ -76,7 +90,7 @@ GET /_synapse/admin/v1/rooms
Response:
-```
+```jsonc
{
"rooms": [
{
@@ -128,7 +142,7 @@ GET /_synapse/admin/v1/rooms?search_term=TWIM
Response:
-```
+```json
{
"rooms": [
{
@@ -163,7 +177,7 @@ GET /_synapse/admin/v1/rooms?order_by=size
Response:
-```
+```jsonc
{
"rooms": [
{
@@ -219,14 +233,14 @@ GET /_synapse/admin/v1/rooms?order_by=size&from=100
Response:
-```
+```jsonc
{
"rooms": [
{
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
"name": "Music Theory",
"canonical_alias": "#musictheory:matrix.org",
- "joined_members": 127
+ "joined_members": 127,
"joined_local_members": 2,
"version": "1",
"creator": "@foo:matrix.org",
@@ -243,7 +257,7 @@ Response:
"room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk",
"name": "weechat-matrix",
"canonical_alias": "#weechat-matrix:termina.org.uk",
- "joined_members": 137
+ "joined_members": 137,
"joined_local_members": 20,
"version": "4",
"creator": "@foo:termina.org.uk",
@@ -278,6 +292,7 @@ The following fields are possible in the JSON response body:
* `canonical_alias` - The canonical (main) alias address of the room.
* `joined_members` - How many users are currently in the room.
* `joined_local_members` - How many local users are currently in the room.
+* `joined_local_devices` - How many local devices are currently in the room.
* `version` - The version of the room as a string.
* `creator` - The `user_id` of the room creator.
* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
@@ -300,15 +315,16 @@ GET /_synapse/admin/v1/rooms/<room_id>
Response:
-```
+```json
{
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
"name": "Music Theory",
"avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV",
"topic": "Theory, Composition, Notation, Analysis",
"canonical_alias": "#musictheory:matrix.org",
- "joined_members": 127
+ "joined_members": 127,
"joined_local_members": 2,
+ "joined_local_devices": 2,
"version": "1",
"creator": "@foo:matrix.org",
"encryption": null,
@@ -342,23 +358,51 @@ GET /_synapse/admin/v1/rooms/<room_id>/members
Response:
-```
+```json
{
"members": [
"@foo:matrix.org",
"@bar:matrix.org",
- "@foobar:matrix.org
- ],
+ "@foobar:matrix.org"
+ ],
"total": 3
}
```
+# Room State API
+
+The Room State admin API allows server admins to get a list of all state events in a room.
+
+The response includes the following fields:
+
+* `state` - The current state of the room at the time of request.
+
+## Usage
+
+A standard request:
+
+```
+GET /_synapse/admin/v1/rooms/<room_id>/state
+
+{}
+```
+
+Response:
+
+```json
+{
+ "state": [
+ {"type": "m.room.create", "state_key": "", "etc": true},
+ {"type": "m.room.power_levels", "state_key": "", "etc": true},
+ {"type": "m.room.name", "state_key": "", "etc": true}
+ ]
+}
+```
+
# Delete Room API
The Delete Room admin API allows server admins to remove rooms from server
and block these rooms.
-It is a combination and improvement of "[Shutdown room](shutdown_room.md)"
-and "[Purge room](purge_room.md)" API.
Shuts down a room. Moves all local users and room aliases automatically to a
new room if `new_room_user_id` is set. Otherwise local users only
@@ -455,3 +499,217 @@ The following fields are returned in the JSON response body:
* `local_aliases` - An array of strings representing the local aliases that were migrated from
the old room to the new.
* `new_room_id` - A string representing the room ID of the new room.
+
+
+## Undoing room shutdowns
+
+*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level,
+the structure can and does change without notice.
+
+First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it
+never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible
+to recover at all:
+
+* If the room was invite-only, your users will need to be re-invited.
+* If the room no longer has any members at all, it'll be impossible to rejoin.
+* The first user to rejoin will have to do so via an alias on a different server.
+
+With all that being said, if you still want to try and recover the room:
+
+1. For safety reasons, shut down Synapse.
+2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
+ * As a precaution, it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got exactly one result, then `COMMIT;` (see the sketch after this list).
+ * The room ID is the same one supplied to the shutdown room API, not the Content Violation room.
+3. Restart Synapse.
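+
+If you prefer to script step 2, the following is a rough sketch using Python's
+`psycopg2` driver. It assumes a PostgreSQL database; the connection string and
+room ID are placeholders, and running the SQL by hand is just as good.
+
+```python
+import psycopg2
+
+room_id = "!example:example.org"  # the room ID supplied to the shutdown API
+
+conn = psycopg2.connect("dbname=synapse user=synapse_user host=localhost")
+try:
+    with conn:  # runs the statement in a transaction; commits only on success
+        with conn.cursor() as cur:
+            cur.execute(
+                "DELETE FROM blocked_rooms WHERE room_id = %s", (room_id,)
+            )
+            if cur.rowcount != 1:
+                # Unexpected row count: raising here makes psycopg2 roll back.
+                raise RuntimeError(f"expected 1 row, got {cur.rowcount}")
+finally:
+    conn.close()
+```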
+
+You will have to manually handle, if you so choose, the following:
+
+* Aliases that would have been redirected to the Content Violation room.
+* Users that would have been booted from the room (and will have been force-joined to the Content Violation room).
+* Removal of the Content Violation room if desired.
+
+
+# Make Room Admin API
+
+Grants another user the highest power available to a local user who is in the room.
+If the user is not in the room, and it is not publicly joinable, then invite the user.
+
+By default the server admin (the caller) is granted power, but another user can
+optionally be specified, e.g.:
+
+```
+ POST /_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
+ {
+ "user_id": "@foo:example.com"
+ }
+```
+
+# Forward Extremities Admin API
+
+Enables querying and deleting forward extremities from rooms. When a lot of forward
+extremities accumulate in a room, performance can become degraded. For details, see
+[#1760](https://github.com/matrix-org/synapse/issues/1760).
+
+## Check for forward extremities
+
+To check the status of forward extremities for a room:
+
+```
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response like the following will be returned:
+
+```json
+{
+ "count": 1,
+ "results": [
+ {
+ "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh",
+ "state_group": 439,
+ "depth": 123,
+ "received_ts": 1611263016761
+ }
+ ]
+}
+```
+
+## Deleting forward extremities
+
+**WARNING**: Please ensure you know what you're doing and have read
+the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
+Under no circumstances should this API be executed as an automated maintenance task!
+
+If a room has lots of forward extremities, the excess can be
+deleted as follows:
+
+```
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response like the following will be returned, indicating the number of forward
+extremities that were deleted.
+
+```json
+{
+ "deleted": 1
+}
+```
+
+# Event Context API
+
+This API lets a client find the context of an event. This is designed primarily to investigate abuse reports.
+
+```
+GET /_synapse/admin/v1/rooms/<room_id>/context/<event_id>
+```
+
+This API mimics [GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-rooms-roomid-context-eventid). Please refer to the link for all details on parameters and response.
+
+Example response:
+
+```json
+{
+ "end": "t29-57_2_0_2",
+ "events_after": [
+ {
+ "content": {
+ "body": "This is an example text message",
+ "msgtype": "m.text",
+ "format": "org.matrix.custom.html",
+ "formatted_body": "<b>This is an example text message</b>"
+ },
+ "type": "m.room.message",
+ "event_id": "$143273582443PhrSn:example.org",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ }
+ }
+ ],
+ "event": {
+ "content": {
+ "body": "filename.jpg",
+ "info": {
+ "h": 398,
+ "w": 394,
+ "mimetype": "image/jpeg",
+ "size": 31037
+ },
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ "msgtype": "m.image"
+ },
+ "type": "m.room.message",
+ "event_id": "$f3h4d129462ha:example.com",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ }
+ },
+ "events_before": [
+ {
+ "content": {
+ "body": "something-important.doc",
+ "filename": "something-important.doc",
+ "info": {
+ "mimetype": "application/msword",
+ "size": 46144
+ },
+ "msgtype": "m.file",
+ "url": "mxc://example.org/FHyPlCeYUSFFxlgbQYZmoEoe"
+ },
+ "type": "m.room.message",
+ "event_id": "$143273582443PhrSn:example.org",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ }
+ }
+ ],
+ "start": "t27-54_2_0_2",
+ "state": [
+ {
+ "content": {
+ "creator": "@example:example.org",
+ "room_version": "1",
+ "m.federate": true,
+ "predecessor": {
+ "event_id": "$something:example.org",
+ "room_id": "!oldroom:example.org"
+ }
+ },
+ "type": "m.room.create",
+ "event_id": "$143273582443PhrSn:example.org",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ },
+ "state_key": ""
+ },
+ {
+ "content": {
+ "membership": "join",
+ "avatar_url": "mxc://example.org/SEsfnsuifSDFSSEF",
+ "displayname": "Alice Margatroid"
+ },
+ "type": "m.room.member",
+ "event_id": "$143273582443PhrSn:example.org",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ },
+ "state_key": "@alice:example.org"
+ }
+ ]
+}
+```
diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md
index 9b1cb1c184..856a629487 100644
--- a/docs/admin_api/shutdown_room.md
+++ b/docs/admin_api/shutdown_room.md
@@ -1,4 +1,7 @@
-# Shutdown room API
+# Deprecated: Shutdown room API
+
+**The old Shutdown room API is deprecated and will be removed in a future release.
+See the new [Delete Room API](rooms.md#delete-room-api) for more details.**
Shuts down a room, preventing new joins and moves local users and room aliases automatically
to a new room. The new room will be created with the user specified by the
@@ -10,8 +13,6 @@ disallow any further invites or joins.
The local server will only have the power to move local user and room aliases to
the new room. Users on other servers will be unaffected.
-See also: [Delete Room API](rooms.md#delete-room-api)
-
## API
You will need to authenticate with an access token for an admin user.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 1473a3d4e3..8d4ec5a6f9 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -29,8 +29,14 @@ It returns a JSON body like the following:
}
],
"avatar_url": "<avatar_url>",
- "admin": false,
- "deactivated": false
+ "admin": 0,
+ "deactivated": 0,
+ "shadow_banned": 0,
+ "password_hash": "$2b$12$p9B4GkqYdRTPGD",
+ "creation_ts": 1560432506,
+ "appservice_id": null,
+ "consent_server_notice_sent": null,
+ "consent_version": null
}
URL parameters:
@@ -93,6 +99,8 @@ Body parameters:
- ``deactivated``, optional. If unspecified, deactivation state will be left
unchanged on existing accounts and set to ``false`` for new accounts.
+ A user cannot be erased by deactivating with this API. For details on deactivating users see
+ `Deactivate Account <#deactivate-account>`_.
If the user already exists then optional parameters default to the current value.
@@ -139,20 +147,20 @@ A JSON body is returned with the following shape:
"users": [
{
"name": "<user_id1>",
- "password_hash": "<password_hash1>",
"is_guest": 0,
"admin": 0,
"user_type": null,
"deactivated": 0,
+ "shadow_banned": 0,
"displayname": "<User One>",
"avatar_url": null
}, {
"name": "<user_id2>",
- "password_hash": "<password_hash2>",
"is_guest": 0,
"admin": 1,
"user_type": null,
"deactivated": 0,
+ "shadow_banned": 0,
"displayname": "<User Two>",
"avatar_url": "<avatar_url>"
}
@@ -245,6 +253,25 @@ server admin: see `README.rst <README.rst>`_.
The erase parameter is optional and defaults to ``false``.
An empty body may be passed for backwards compatibility.
+The following actions are performed when deactivating a user:
+
+- Try to unbind 3PIDs from the identity server
+- Remove all 3PIDs from the homeserver
+- Delete all devices and E2EE keys
+- Delete all access tokens
+- Delete the password hash
+- Remove the user from all rooms they are a member of
+- Remove the user from the user directory
+- Reject all pending invites
+- Remove all account validity information related to the user
+
+The following additional actions are performed during deactivation if ``erase``
+is set to ``true``:
+
+- Remove the user's display name
+- Remove the user's avatar URL
+- Mark the user as erased
+
Reset password
==============
@@ -334,6 +361,10 @@ A response body like the following is returned:
"total": 2
}
+The server returns the list of rooms of which both the user and the server
+are members. If the user is local, all of the rooms of which the user is
+a member are returned.
+
**Parameters**
The following parameters should be set in the URL:
@@ -348,11 +379,12 @@ The following fields are returned in the JSON response body:
- ``total`` - Number of rooms.
-List media of an user
-================================
+List media of a user
+====================
Gets a list of all local media that a specific ``user_id`` has created.
-The response is ordered by creation date descending and media ID descending.
-The newest media is on top.
+By default, the response is ordered by descending creation date and ascending media ID.
+The newest media is on top. You can change the order with parameters
+``order_by`` and ``dir``.
The API is::
@@ -409,6 +441,35 @@ The following parameters should be set in the URL:
denoting the offset in the returned results. This should be treated as an opaque value and
not explicitly set to anything other than the return value of ``next_token`` from a previous call.
Defaults to ``0``.
+- ``order_by`` - The method by which to sort the returned list of media.
+ If the ordered field has duplicates, the second order is always by ascending ``media_id``,
+ which guarantees a stable ordering. Valid values are:
+
+ - ``media_id`` - Media are ordered alphabetically by ``media_id``.
+ - ``upload_name`` - Media are ordered alphabetically by the name the media was uploaded with.
+ - ``created_ts`` - Media are ordered by when the content was uploaded in ms.
+ Smallest to largest. This is the default.
+ - ``last_access_ts`` - Media are ordered by when the content was last accessed in ms.
+ Smallest to largest.
+ - ``media_length`` - Media are ordered by length of the media in bytes.
+ Smallest to largest.
+ - ``media_type`` - Media are ordered alphabetically by MIME-type.
+ - ``quarantined_by`` - Media are ordered alphabetically by the user ID that
+ initiated the quarantine request for this media.
+ - ``safe_from_quarantine`` - Media are ordered by whether the media is marked as
+ safe from quarantining.
+
+- ``dir`` - Direction of media order. Either ``f`` for forwards or ``b`` for backwards.
+ Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``.
+
+If neither ``order_by`` nor ``dir`` is set, the default order is newest media on top
+(corresponds to ``order_by`` = ``created_ts`` and ``dir`` = ``b``).
+
+Caution: the database only has indexes on the columns ``media_id``,
+``user_id`` and ``created_ts``. This means that if a different sort order is used
+(``upload_name``, ``last_access_ts``, ``media_length``, ``media_type``,
+``quarantined_by`` or ``safe_from_quarantine``), this can cause a large load on the
+database, especially in large environments.
**Response**
@@ -732,3 +793,33 @@ The following fields are returned in the JSON response body:
- ``total`` - integer - Number of pushers.
See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
+
+Shadow-banning users
+====================
+
+Shadow-banning is a useful tool for moderating malicious or egregiously abusive users.
+A shadow-banned user receives successful responses to their client-server API requests,
+but the events are not propagated into rooms. This can be an effective tool as it
+(hopefully) takes longer for the user to realise they are being moderated before
+pivoting to another account.
+
+Shadow-banning a user should be used as a tool of last resort and may lead to confusing
+or broken behaviour for the client. A shadow-banned user will not receive any
+notification and it is generally more appropriate to ban or kick abusive users.
+A shadow-banned user will be unable to contact anyone on the server.
+
+The API is::
+
+ POST /_synapse/admin/v1/users/<user_id>/shadow_ban
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+ be local.
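+
+For example, the request can be made with a short Python script. This is a
+sketch only, using the ``requests`` library; the homeserver URL, access token
+and user ID below are placeholders::
+
+    from urllib.parse import quote
+
+    import requests
+
+    user_id = "@user:server.com"  # must be a local user
+    resp = requests.post(
+        "https://matrix.example.com/_synapse/admin/v1/users/"
+        + quote(user_id, safe="")
+        + "/shadow_ban",
+        headers={"Authorization": "Bearer <admin access token>"},
+        json={},
+    )
+    resp.raise_for_status()  # an empty JSON dict is returned on success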
diff --git a/docs/auth_chain_diff.dot b/docs/auth_chain_diff.dot
new file mode 100644
index 0000000000..978d579ada
--- /dev/null
+++ b/docs/auth_chain_diff.dot
@@ -0,0 +1,32 @@
+digraph auth {
+ nodesep=0.5;
+ rankdir="RL";
+
+ C [label="Create (1,1)"];
+
+ BJ [label="Bob's Join (2,1)", color=red];
+ BJ2 [label="Bob's Join (2,2)", color=red];
+ BJ2 -> BJ [color=red, dir=none];
+
+ subgraph cluster_foo {
+ A1 [label="Alice's invite (4,1)", color=blue];
+ A2 [label="Alice's Join (4,2)", color=blue];
+ A3 [label="Alice's Join (4,3)", color=blue];
+ A3 -> A2 -> A1 [color=blue, dir=none];
+ color=none;
+ }
+
+ PL1 [label="Power Level (3,1)", color=darkgreen];
+ PL2 [label="Power Level (3,2)", color=darkgreen];
+ PL2 -> PL1 [color=darkgreen, dir=none];
+
+ {rank = same; C; BJ; PL1; A1;}
+
+ A1 -> C [color=grey];
+ A1 -> BJ [color=grey];
+ PL1 -> C [color=grey];
+ BJ2 -> PL1 [penwidth=2];
+
+ A3 -> PL2 [penwidth=2];
+ A1 -> PL1 -> BJ -> C [penwidth=2];
+}
diff --git a/docs/auth_chain_diff.dot.png b/docs/auth_chain_diff.dot.png
new file mode 100644
index 0000000000..771c07308f
--- /dev/null
+++ b/docs/auth_chain_diff.dot.png
Binary files differdiff --git a/docs/auth_chain_difference_algorithm.md b/docs/auth_chain_difference_algorithm.md
new file mode 100644
index 0000000000..30f72a70da
--- /dev/null
+++ b/docs/auth_chain_difference_algorithm.md
@@ -0,0 +1,108 @@
+# Auth Chain Difference Algorithm
+
+The auth chain difference algorithm is used by V2 state resolution, where a
+naive implementation can be a significant source of CPU and DB usage.
+
+### Definitions
+
+A *state set* is a set of state events; e.g. the input of a state resolution
+algorithm is a collection of state sets.
+
+The *auth chain* of a set of events is all of those events' auth events and
+*their* auth events, recursively (i.e. the events reachable by walking the
+graph induced by the events' auth event links).
+
+The *auth chain difference* of a collection of state sets is the union minus the
+intersection of the sets of auth chains corresponding to the state sets, i.e. an
+event is in the auth chain difference if it is reachable by walking the auth
+event graph from at least one of the state sets but not from *all* of the state
+sets.
+
+## Breadth First Walk Algorithm
+
+A way of calculating the auth chain difference without calculating the full auth
+chains for each state set is to do a parallel breadth first walk (ordered by
+depth) of each state set's auth chain. By tracking which events are reachable
+from each state set we can finish early if every pending event is reachable from
+every state set.
+
+This can work well for state sets that have a small auth chain difference, but
+can be very inefficient for larger differences. However, this algorithm is still
+used if we don't have a chain cover index for the room (e.g. because we're in
+the process of indexing it).
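+
+To make the definition concrete, here is a deliberately naive Python sketch
+that walks each state set's auth chain and collects the events not reachable
+from every set. It omits the depth ordering and the early-exit check that make
+the real walk cheaper, so it is illustrative only; the event-ID to auth-event
+mapping is an assumed input.
+
+```python
+from collections import deque
+from typing import Dict, List, Set
+
+
+def auth_chain_difference_walk(
+    auth_events: Dict[str, List[str]],  # event ID -> that event's auth event IDs
+    state_sets: List[Set[str]],
+) -> Set[str]:
+    """Walk each state set's auth chain; return events not reachable from all sets."""
+    # For each event in some auth chain, the indices of the state sets that reach it.
+    reachable: Dict[str, Set[int]] = {}
+
+    for i, state_set in enumerate(state_sets):
+        # Start from the auth events of the state events (the auth chain does
+        # not include the state events themselves).
+        queue = deque(e for ev in state_set for e in auth_events.get(ev, []))
+        while queue:
+            event_id = queue.popleft()
+            seen = reachable.setdefault(event_id, set())
+            if i in seen:
+                continue  # already visited from this state set
+            seen.add(i)
+            queue.extend(auth_events.get(event_id, []))
+
+    all_sets = set(range(len(state_sets)))
+    return {event_id for event_id, seen in reachable.items() if seen != all_sets}
+```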
+
+## Chain Cover Index
+
+Synapse computes auth chain differences by pre-computing a "chain cover" index
+for the auth chain in a room, allowing efficient reachability queries like "is
+event A in the auth chain of event B". This is done by assigning every event a
+*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
+between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
+is in the auth chain of `B`) if and only if either:
+
+1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
+ sequence number; or
+2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
+ `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
+
+There are actually two potential implementations, one where we store links from
+each chain to every other reachable chain (the transitive closure of the links
+graph), and one where we remove redundant links (the transitive reduction of the
+links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
+would not be stored. Synapse uses the former implementations so that it doesn't
+need to recurse to test reachability between chains.
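+
+As a rough illustration of the two rules above, here is a small Python sketch.
+The `links` table layout (a map from `(origin chain, target chain)` to
+`(start_seq_no, end_seq_no)`) is an assumption made for the example, not
+Synapse's actual schema.
+
+```python
+from typing import Dict, Tuple
+
+ChainPos = Tuple[int, int]  # (chain ID, sequence number)
+Links = Dict[Tuple[int, int], Tuple[int, int]]
+
+
+def is_in_auth_chain(a: ChainPos, b: ChainPos, links: Links) -> bool:
+    """Return True if event A is in the auth chain of event B."""
+    a_chain, a_seq = a
+    b_chain, b_seq = b
+
+    if a_chain == b_chain:
+        # Rule 1: same chain, A must come strictly before B.
+        return a_seq < b_seq
+
+    # Rule 2: a link from B's chain to A's chain covering both positions.
+    link = links.get((b_chain, a_chain))
+    if link is None:
+        return False
+    start_seq_no, end_seq_no = link
+    return start_seq_no <= b_seq and a_seq <= end_seq_no
+
+
+# Toy link: events at sequence number 2 or later in chain 2 can reach events
+# up to sequence number 1 in chain 3 (mirroring Bob's second join pointing at
+# the first power level in the example below).
+links = {(2, 3): (2, 1)}
+assert is_in_auth_chain((3, 1), (2, 2), links)
+assert not is_in_auth_chain((3, 2), (2, 2), links)
+```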
+
+### Example
+
+An example auth graph would look like the following, where chains have been
+formed based on type/state_key and are denoted by colour and are labelled with
+`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
+are those that would be removed in the second implementation described above).
+
+![Example](auth_chain_diff.dot.png)
+
+Note that we don't include all links between events and their auth events, as
+most of those links would be redundant. For example, all events point to the
+create event, but each chain only needs the one link from its base to the
+create event.
+
+## Using the Index
+
+This index can be used to calculate the auth chain difference of the state sets
+by looking at the chain ID and sequence numbers reachable from each state set:
+
+1. For every state set, look up the chain ID/sequence numbers of each state event.
+2. Use the index to find all chains and the maximum sequence number reachable
+ from each state set.
+3. The auth chain difference is then all events in each chain that have sequence
+ numbers between the maximum sequence number reachable from *any* state set and
+ the minimum reachable by *all* state sets (if any).
+
+Note that step 2 is effectively calculating the auth chain for each state set
+(in terms of chain IDs and sequence numbers), and step 3 is calculating the
+difference between the union and intersection of the auth chains.
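+
+The following Python sketch condenses these steps, starting from the point
+where steps 1 and 2 have already produced, for each state set, the maximum
+sequence number reachable in each chain. The dictionary layout is an
+assumption for illustration only; it reproduces the result of the worked
+example below.
+
+```python
+from typing import Dict, List, Set, Tuple
+
+
+def difference_from_index(reachable: List[Dict[int, int]]) -> Set[Tuple[int, int]]:
+    """Return the (chain ID, sequence number) pairs in the auth chain difference.
+
+    `reachable[i]` maps a chain ID to the maximum sequence number reachable
+    from state set `i`; chains that are missing are treated as unreachable.
+    """
+    chains = set().union(*(r.keys() for r in reachable))
+    difference: Set[Tuple[int, int]] = set()
+
+    for chain in chains:
+        # Maximum sequence number reachable from *any* state set.
+        max_any = max(r.get(chain, 0) for r in reachable)
+        # Minimum sequence number reachable from *all* state sets
+        # (0 if some state set cannot reach this chain at all).
+        min_all = min(r.get(chain, 0) for r in reachable)
+        # Everything strictly after min_all, up to and including max_any.
+        difference.update((chain, seq) for seq in range(min_all + 1, max_any + 1))
+
+    return difference
+
+
+# Worked example below: S1 reaches (1,1), (2,2), (3,1), (4,1);
+#                       S2 reaches (1,1), (2,1), (3,2), (4,3).
+s1 = {1: 1, 2: 2, 3: 1, 4: 1}
+s2 = {1: 1, 2: 1, 3: 2, 4: 3}
+assert difference_from_index([s1, s2]) == {(2, 2), (3, 2), (4, 2), (4, 3)}
+```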
+
+### Worked Example
+
+For example, given the above graph, we can calculate the difference between
+state sets consisting of:
+
+1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
+2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
+
+Using the index we see that the following auth chains are reachable from each
+state set:
+
+1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
+2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
+
+And so, for each chain, the ranges that are in the auth chain difference are:
+1. Chain 1: None (since everything can reach the create event).
+2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
+ sets and the maximum reachable is `2` (corresponding to Bob's second join).
+3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
+ level).
+4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
+
+So the final result is: Bob's second join `(2,2)`, the second power level
+`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
diff --git a/docs/code_style.md b/docs/code_style.md
index f6c825d7d4..190f8ab2de 100644
--- a/docs/code_style.md
+++ b/docs/code_style.md
@@ -8,16 +8,16 @@ errors in code.
The necessary tools are detailed below.
+First install them with:
+
+ pip install -e ".[lint,mypy]"
+
- **black**
The Synapse codebase uses [black](https://pypi.org/project/black/)
as an opinionated code formatter, ensuring all comitted code is
properly formatted.
- First install `black` with:
-
- pip install --upgrade black
-
Have `black` auto-format your code (it shouldn't change any
functionality) with:
@@ -28,10 +28,6 @@ The necessary tools are detailed below.
`flake8` is a code checking tool. We require code to pass `flake8`
before being merged into the codebase.
- Install `flake8` with:
-
- pip install --upgrade flake8 flake8-comprehensions
-
Check all application and test code with:
flake8 synapse tests
@@ -41,10 +37,6 @@ The necessary tools are detailed below.
`isort` ensures imports are nicely formatted, and can suggest and
auto-fix issues such as double-importing.
- Install `isort` with:
-
- pip install --upgrade isort
-
Auto-fix imports with:
isort -rc synapse tests
diff --git a/docs/deprecation_policy.md b/docs/deprecation_policy.md
new file mode 100644
index 0000000000..06ea340559
--- /dev/null
+++ b/docs/deprecation_policy.md
@@ -0,0 +1,33 @@
+Deprecation Policy for Platform Dependencies
+============================================
+
+Synapse has a number of platform dependencies, including Python and PostgreSQL.
+This document outlines the policy towards which versions we support, and when we
+drop support for versions in the future.
+
+
+Policy
+------
+
+Synapse follows the upstream support life cycles for Python and PostgreSQL,
+i.e. when a version reaches End of Life Synapse will withdraw support for that
+version in future releases.
+
+Details on the upstream support life cycles for Python and PostgreSQL are
+documented at https://endoflife.date/python and
+https://endoflife.date/postgresql.
+
+
+Context
+-------
+
+It is important for system admins to have a clear understanding of the platform
+requirements of Synapse and its deprecation policies so that they can
+effectively plan upgrading their infrastructure ahead of time. This is
+especially important in contexts where upgrading the infrastructure requires
+auditing and approval from a security team, or where otherwise upgrading is a
+long process.
+
+By following the upstream support life cycles Synapse can ensure that its
+dependencies continue to get security patches, while not requiring system admins
+to constantly update their platform dependencies to the latest versions.
diff --git a/docs/dev/cas.md b/docs/dev/cas.md
index f8d02cc82c..592b2d8d4f 100644
--- a/docs/dev/cas.md
+++ b/docs/dev/cas.md
@@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
You should now have a Django project configured to serve CAS authentication with
a single user created.
-## Configure Synapse (and Riot) to use CAS
+## Configure Synapse (and Element) to use CAS
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
running Django test server:
@@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
## Testing the configuration
-Then in Riot:
+Then in Element:
-1. Visit the login page with a Riot pointing at your homeserver.
+1. Visit the login page with an Element client pointing at your homeserver.
2. Click the Single Sign-On button.
3. Login using the credentials created with `createsuperuser`.
4. You should be logged in.
diff --git a/docs/openid.md b/docs/openid.md
index da391f74aa..cfaafc5015 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -42,40 +42,41 @@ as follows:
* For other installation mechanisms, see the documentation provided by the
maintainer.
-To enable the OpenID integration, you should then add an `oidc_config` section
-to your configuration file (or uncomment the `enabled: true` line in the
-existing section). See [sample_config.yaml](./sample_config.yaml) for some
-sample settings, as well as the text below for example configurations for
-specific providers.
+To enable the OpenID integration, you should then add a section to the `oidc_providers`
+setting in your configuration file (or uncomment one of the existing examples).
+See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
+the text below for example configurations for specific providers.
## Sample configs
Here are a few configs for providers that should work with Synapse.
### Microsoft Azure Active Directory
-Azure AD can act as an OpenID Connect Provider. Register a new application under
+Azure AD can act as an OpenID Connect Provider. Register a new application under
*App registrations* in the Azure AD management console. The RedirectURI for your
-application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback`
+application should point to your matrix server:
+`[synapse public baseurl]/_synapse/client/oidc/callback`
-Go to *Certificates & secrets* and register a new client secret. Make note of your
+Go to *Certificates & secrets* and register a new client secret. Make note of your
Directory (tenant) ID as it will be used in the Azure links.
Edit your Synapse config file and change the `oidc_config` section:
```yaml
-oidc_config:
- enabled: true
- issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
- client_id: "<client id>"
- client_secret: "<client secret>"
- scopes: ["openid", "profile"]
- authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
- token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
- userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
-
- user_mapping_provider:
- config:
- localpart_template: "{{ user.preferred_username.split('@')[0] }}"
- display_name_template: "{{ user.name }}"
+oidc_providers:
+ - idp_id: microsoft
+ idp_name: Microsoft
+ issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
+ client_id: "<client id>"
+ client_secret: "<client secret>"
+ scopes: ["openid", "profile"]
+ authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
+ token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
+ userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
+
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.preferred_username.split('@')[0] }}"
+ display_name_template: "{{ user.name }}"
```
### [Dex][dex-idp]
@@ -94,7 +95,7 @@ staticClients:
- id: synapse
secret: secret
redirectURIs:
- - '[synapse public baseurl]/_synapse/oidc/callback'
+ - '[synapse public baseurl]/_synapse/client/oidc/callback'
name: 'Synapse'
```
@@ -103,21 +104,22 @@ Run with `dex serve examples/config-dev.yaml`.
Synapse config:
```yaml
-oidc_config:
- enabled: true
- skip_verification: true # This is needed as Dex is served on an insecure endpoint
- issuer: "http://127.0.0.1:5556/dex"
- client_id: "synapse"
- client_secret: "secret"
- scopes: ["openid", "profile"]
- user_mapping_provider:
- config:
- localpart_template: "{{ user.name }}"
- display_name_template: "{{ user.name|capitalize }}"
+oidc_providers:
+ - idp_id: dex
+ idp_name: "My Dex server"
+ skip_verification: true # This is needed as Dex is served on an insecure endpoint
+ issuer: "http://127.0.0.1:5556/dex"
+ client_id: "synapse"
+ client_secret: "secret"
+ scopes: ["openid", "profile"]
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.name }}"
+ display_name_template: "{{ user.name|capitalize }}"
```
### [Keycloak][keycloak-idp]
-[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
+[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
@@ -139,7 +141,7 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
| Enabled | `On` |
| Client Protocol | `openid-connect` |
| Access Type | `confidential` |
-| Valid Redirect URIs | `[synapse public baseurl]/_synapse/oidc/callback` |
+| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` |
5. Click `Save`
6. On the Credentials tab, update the fields:
@@ -152,17 +154,22 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
8. Copy Secret
```yaml
-oidc_config:
- enabled: true
- issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
- client_id: "synapse"
- client_secret: "copy secret generated from above"
- scopes: ["openid", "profile"]
+oidc_providers:
+ - idp_id: keycloak
+ idp_name: "My KeyCloak server"
+ issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
+ client_id: "synapse"
+ client_secret: "copy secret generated from above"
+ scopes: ["openid", "profile"]
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.preferred_username }}"
+ display_name_template: "{{ user.name }}"
```
### [Auth0][auth0]
1. Create a regular web application for Synapse
-2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/oidc/callback`
+2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback`
3. Add a rule to add the `preferred_username` claim.
<details>
<summary>Code sample</summary>
@@ -187,16 +194,17 @@ oidc_config:
Synapse config:
```yaml
-oidc_config:
- enabled: true
- issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
- client_id: "your-client-id" # TO BE FILLED
- client_secret: "your-client-secret" # TO BE FILLED
- scopes: ["openid", "profile"]
- user_mapping_provider:
- config:
- localpart_template: "{{ user.preferred_username }}"
- display_name_template: "{{ user.name }}"
+oidc_providers:
+ - idp_id: auth0
+ idp_name: Auth0
+ issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ scopes: ["openid", "profile"]
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.preferred_username }}"
+ display_name_template: "{{ user.name }}"
```
### GitHub
@@ -210,26 +218,28 @@ login mechanism needs an attribute to uniquely identify users, and that endpoint
does not return a `sub` property, an alternative `subject_claim` has to be set.
1. Create a new OAuth application: https://github.com/settings/applications/new.
-2. Set the callback URL to `[synapse public baseurl]/_synapse/oidc/callback`.
+2. Set the callback URL to `[synapse public baseurl]/_synapse/client/oidc/callback`.
Synapse config:
```yaml
-oidc_config:
- enabled: true
- discover: false
- issuer: "https://github.com/"
- client_id: "your-client-id" # TO BE FILLED
- client_secret: "your-client-secret" # TO BE FILLED
- authorization_endpoint: "https://github.com/login/oauth/authorize"
- token_endpoint: "https://github.com/login/oauth/access_token"
- userinfo_endpoint: "https://api.github.com/user"
- scopes: ["read:user"]
- user_mapping_provider:
- config:
- subject_claim: "id"
- localpart_template: "{{ user.login }}"
- display_name_template: "{{ user.name }}"
+oidc_providers:
+ - idp_id: github
+ idp_name: Github
+ idp_brand: "github" # optional: styling hint for clients
+ discover: false
+ issuer: "https://github.com/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ authorization_endpoint: "https://github.com/login/oauth/authorize"
+ token_endpoint: "https://github.com/login/oauth/access_token"
+ userinfo_endpoint: "https://api.github.com/user"
+ scopes: ["read:user"]
+ user_mapping_provider:
+ config:
+ subject_claim: "id"
+ localpart_template: "{{ user.login }}"
+ display_name_template: "{{ user.name }}"
```
### [Google][google-idp]
@@ -239,60 +249,200 @@ oidc_config:
2. add an "OAuth Client ID" for a Web Application under "Credentials".
3. Copy the Client ID and Client Secret, and add the following to your synapse config:
```yaml
- oidc_config:
- enabled: true
- issuer: "https://accounts.google.com/"
- client_id: "your-client-id" # TO BE FILLED
- client_secret: "your-client-secret" # TO BE FILLED
- scopes: ["openid", "profile"]
- user_mapping_provider:
- config:
- localpart_template: "{{ user.given_name|lower }}"
- display_name_template: "{{ user.name }}"
+ oidc_providers:
+ - idp_id: google
+ idp_name: Google
+ idp_brand: "google" # optional: styling hint for clients
+ issuer: "https://accounts.google.com/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ scopes: ["openid", "profile"]
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.given_name|lower }}"
+ display_name_template: "{{ user.name }}"
```
4. Back in the Google console, add this Authorized redirect URI: `[synapse
- public baseurl]/_synapse/oidc/callback`.
+ public baseurl]/_synapse/client/oidc/callback`.
### Twitch
1. Setup a developer account on [Twitch](https://dev.twitch.tv/)
2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/)
-3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/oidc/callback`
+3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/client/oidc/callback`
Synapse config:
```yaml
-oidc_config:
- enabled: true
- issuer: "https://id.twitch.tv/oauth2/"
- client_id: "your-client-id" # TO BE FILLED
- client_secret: "your-client-secret" # TO BE FILLED
- client_auth_method: "client_secret_post"
- user_mapping_provider:
- config:
- localpart_template: "{{ user.preferred_username }}"
- display_name_template: "{{ user.name }}"
+oidc_providers:
+ - idp_id: twitch
+ idp_name: Twitch
+ issuer: "https://id.twitch.tv/oauth2/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ client_auth_method: "client_secret_post"
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.preferred_username }}"
+ display_name_template: "{{ user.name }}"
```
### GitLab
1. Create a [new application](https://gitlab.com/profile/applications).
2. Add the `read_user` and `openid` scopes.
-3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
+3. Add this Callback URL: `[synapse public baseurl]/_synapse/client/oidc/callback`
Synapse config:
```yaml
-oidc_config:
- enabled: true
- issuer: "https://gitlab.com/"
- client_id: "your-client-id" # TO BE FILLED
- client_secret: "your-client-secret" # TO BE FILLED
- client_auth_method: "client_secret_post"
- scopes: ["openid", "read_user"]
- user_profile_method: "userinfo_endpoint"
- user_mapping_provider:
- config:
- localpart_template: '{{ user.nickname }}'
- display_name_template: '{{ user.name }}'
+oidc_providers:
+ - idp_id: gitlab
+ idp_name: Gitlab
+ idp_brand: "gitlab" # optional: styling hint for clients
+ issuer: "https://gitlab.com/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ client_auth_method: "client_secret_post"
+ scopes: ["openid", "read_user"]
+ user_profile_method: "userinfo_endpoint"
+ user_mapping_provider:
+ config:
+ localpart_template: '{{ user.nickname }}'
+ display_name_template: '{{ user.name }}'
+```
+
+### Facebook
+
+Like GitHub, Facebook provides a custom OAuth2 API rather than an OIDC-compliant
+one, so it requires a little more configuration.
+
+0. You will need a Facebook developer account. You can register for one
+ [here](https://developers.facebook.com/async/registration/).
+1. On the [apps](https://developers.facebook.com/apps/) page of the developer
+ console, "Create App", and choose "Build Connected Experiences".
+2. Once the app is created, add "Facebook Login" and choose "Web". You don't
+ need to go through the whole form here.
+3. In the left-hand menu, open "Products"/"Facebook Login"/"Settings".
+ * Add `[synapse public baseurl]/_synapse/client/oidc/callback` as an OAuth Redirect
+ URL.
+4. In the left-hand menu, open "Settings/Basic". Here you can copy the "App ID"
+ and "App Secret" for use below.
+
+Synapse config:
+
+```yaml
+ - idp_id: facebook
+ idp_name: Facebook
+ idp_brand: "facebook" # optional: styling hint for clients
+ discover: false
+ issuer: "https://facebook.com"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ scopes: ["openid", "email"]
+ authorization_endpoint: https://facebook.com/dialog/oauth
+ token_endpoint: https://graph.facebook.com/v9.0/oauth/access_token
+ user_profile_method: "userinfo_endpoint"
+ userinfo_endpoint: "https://graph.facebook.com/v9.0/me?fields=id,name,email,picture"
+ user_mapping_provider:
+ config:
+ subject_claim: "id"
+ display_name_template: "{{ user.name }}"
+```
+
+Relevant documents:
+ * https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow
+ * Using Facebook's Graph API: https://developers.facebook.com/docs/graph-api/using-graph-api/
+ * Reference to the User endpoint: https://developers.facebook.com/docs/graph-api/reference/user
+
+### Gitea
+
+Gitea is, like GitHub, not an OpenID provider, but just an OAuth2 provider.
+
+The [`/user` API endpoint](https://try.gitea.io/api/swagger#/user/userGetCurrent)
+can be used to retrieve information on the authenticated user. As the Synapse
+login mechanism needs an attribute to uniquely identify users, and that endpoint
+does not return a `sub` property, an alternative `subject_claim` has to be set.
+
+1. Create a new application.
+2. Add this Callback URL: `[synapse public baseurl]/_synapse/client/oidc/callback`
+
+Synapse config:
+
+```yaml
+oidc_providers:
+ - idp_id: gitea
+ idp_name: Gitea
+ discover: false
+ issuer: "https://your-gitea.com/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ client_auth_method: client_secret_post
+ scopes: [] # Gitea doesn't support Scopes
+ authorization_endpoint: "https://your-gitea.com/login/oauth/authorize"
+ token_endpoint: "https://your-gitea.com/login/oauth/access_token"
+ userinfo_endpoint: "https://your-gitea.com/api/v1/user"
+ user_mapping_provider:
+ config:
+ subject_claim: "id"
+ localpart_template: "{{ user.login }}"
+ display_name_template: "{{ user.full_name }}"
+```
+
+### XWiki
+
+Install [OpenID Connect Provider](https://extensions.xwiki.org/xwiki/bin/view/Extension/OpenID%20Connect/OpenID%20Connect%20Provider/) extension in your [XWiki](https://www.xwiki.org) instance.
+
+Synapse config:
+
+```yaml
+oidc_providers:
+ - idp_id: xwiki
+ idp_name: "XWiki"
+ issuer: "https://myxwikihost/xwiki/oidc/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_auth_method: none
+ scopes: ["openid", "profile"]
+ user_profile_method: "userinfo_endpoint"
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.preferred_username }}"
+ display_name_template: "{{ user.name }}"
+```
+
+### Apple
+
+Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account.
+
+You will need to create a new "Services ID" for SiWA, and create and download a
+private key with "SiWA" enabled.
+
+As well as the private key file, you will need:
+ * Client ID: the "identifier" you gave the "Services ID"
+ * Team ID: a 10-character ID associated with your developer account.
+ * Key ID: the 10-character identifier for the key.
+
+https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
+documentation on setting up SiWA.
+
+The synapse config will look like this:
+
+```yaml
+ - idp_id: apple
+ idp_name: Apple
+ issuer: "https://appleid.apple.com"
+ client_id: "your-client-id" # Set to the "identifier" for your "ServicesID"
+ client_auth_method: "client_secret_post"
+ client_secret_jwt_key:
+ key_file: "/path/to/AuthKey_KEYIDCODE.p8" # point to your key file
+ jwt_header:
+ alg: ES256
+ kid: "KEYIDCODE" # Set to the 10-char Key ID
+ jwt_payload:
+ iss: TEAMIDCODE # Set to the 10-char Team ID
+ scopes: ["name", "email", "openid"]
+ authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
+ user_mapping_provider:
+ config:
+ email_template: "{{ user.email }}"
```
diff --git a/docs/postgres.md b/docs/postgres.md
index c30cc1fd8c..680685d04e 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -18,7 +18,7 @@ connect to a postgres database.
virtualenv](../INSTALL.md#installing-from-source), you can install
the library with:
- ~/synapse/env/bin/pip install matrix-synapse[postgres]
+ ~/synapse/env/bin/pip install "matrix-synapse[postgres]"
(substituting the path to your virtualenv for `~/synapse/env`, if
you used a different path). You will require the postgres
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index c7020f2df3..cf1b835b9d 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -3,29 +3,30 @@
It is recommended to put a reverse proxy such as
[nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
[Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
-[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
-[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
+[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy),
+[HAProxy](https://www.haproxy.org/) or
+[relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage
of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root
privileges.
-**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
-the requested URI in any way (for example, by decoding `%xx` escapes).
-Beware that Apache *will* canonicalise URIs unless you specify
-`nocanon`.
+You should configure your reverse proxy to forward requests to `/_matrix` or
+`/_synapse/client` to Synapse, and have it set the `X-Forwarded-For` and
+`X-Forwarded-Proto` request headers.
-When setting up a reverse proxy, remember that Matrix clients and other
-Matrix servers do not necessarily need to connect to your server via the
-same server name or port. Indeed, clients will use port 443 by default,
-whereas servers default to port 8448. Where these are different, we
-refer to the 'client port' and the 'federation port'. See [the Matrix
+You should remember that Matrix clients and other Matrix servers do not
+necessarily need to connect to your server via the same server name or
+port. Indeed, clients will use port 443 by default, whereas servers default to
+port 8448. Where these are different, we refer to the 'client port' and the
+'federation port'. See [the Matrix
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names)
for more details of the algorithm used for federation connections, and
[delegate.md](<delegate.md>) for instructions on setting up delegation.
-Endpoints that are part of the standardised Matrix specification are
-located under `/_matrix`, whereas endpoints specific to Synapse are
-located under `/_synapse/client`.
+**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
+the requested URI in any way (for example, by decoding `%xx` escapes).
+Beware that Apache *will* canonicalise URIs unless you specify
+`nocanon`.
Let's assume that we expect clients to connect to our server at
`https://matrix.example.com`, and other servers to connect at
@@ -40,18 +41,21 @@ the reverse proxy and the homeserver.
```
server {
- listen 443 ssl;
- listen [::]:443 ssl;
+ listen 443 ssl http2;
+ listen [::]:443 ssl http2;
# For the federation port
- listen 8448 ssl default_server;
- listen [::]:8448 ssl default_server;
+ listen 8448 ssl http2 default_server;
+ listen [::]:8448 ssl http2 default_server;
server_name matrix.example.com;
location ~* ^(\/_matrix|\/_synapse\/client) {
proxy_pass http://localhost:8008;
proxy_set_header X-Forwarded-For $remote_addr;
+ proxy_set_header X-Forwarded-Proto $scheme;
+ proxy_set_header Host $host;
+
# Nginx by default only allows file uploads up to 1M in size
# Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
client_max_body_size 50M;
@@ -100,9 +104,11 @@ example.com:8448 {
```
<VirtualHost *:443>
SSLEngine on
- ServerName matrix.example.com;
+ ServerName matrix.example.com
+ RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
AllowEncodedSlashes NoDecode
+ ProxyPreserveHost on
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
ProxyPass /_synapse/client http://127.0.0.1:8008/_synapse/client nocanon
@@ -111,8 +117,9 @@ example.com:8448 {
<VirtualHost *:8448>
SSLEngine on
- ServerName example.com;
+ ServerName example.com
+ RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
AllowEncodedSlashes NoDecode
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
@@ -129,11 +136,16 @@ example.com:8448 {
</IfModule>
```
+**NOTE 3**: Missing `ProxyPreserveHost on` can lead to a redirect loop.
+
### HAProxy
```
frontend https
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
+ http-request set-header X-Forwarded-Proto https if { ssl_fc }
+ http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
+ http-request set-header X-Forwarded-For %[src]
# Matrix client traffic
acl matrix-host hdr(host) -i matrix.example.com
@@ -144,12 +156,62 @@ frontend https
frontend matrix-federation
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
+ http-request set-header X-Forwarded-Proto https if { ssl_fc }
+ http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
+ http-request set-header X-Forwarded-For %[src]
+
default_backend matrix
backend matrix
server matrix 127.0.0.1:8008
```
+### Relayd
+
+```
+table <webserver> { 127.0.0.1 }
+table <matrixserver> { 127.0.0.1 }
+
+http protocol "https" {
+ tls { no tlsv1.0, ciphers "HIGH" }
+ tls keypair "example.com"
+ match header set "X-Forwarded-For" value "$REMOTE_ADDR"
+ match header set "X-Forwarded-Proto" value "https"
+
+ # set CORS header for .well-known/matrix/server, .well-known/matrix/client
+ # httpd does not support setting headers, so do it here
+ match request path "/.well-known/matrix/*" tag "matrix-cors"
+ match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*"
+
+ pass quick path "/_matrix/*" forward to <matrixserver>
+ pass quick path "/_synapse/client/*" forward to <matrixserver>
+
+ # pass on non-matrix traffic to webserver
+ pass forward to <webserver>
+}
+
+relay "https_traffic" {
+ listen on egress port 443 tls
+ protocol "https"
+ forward to <matrixserver> port 8008 check tcp
+ forward to <webserver> port 8080 check tcp
+}
+
+http protocol "matrix" {
+ tls { no tlsv1.0, ciphers "HIGH" }
+ tls keypair "example.com"
+ block
+ pass quick path "/_matrix/*" forward to <matrixserver>
+ pass quick path "/_synapse/client/*" forward to <matrixserver>
+}
+
+relay "matrix_federation" {
+ listen on egress port 8448 tls
+ protocol "matrix"
+ forward to <matrixserver> port 8008 check tcp
+}
+```
+
## Homeserver Configuration
You will also want to set `bind_addresses: ['127.0.0.1']` and
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 95e6a6bd2f..be5e84f0ad 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -67,11 +67,12 @@ pid_file: DATADIR/homeserver.pid
#
#web_client_location: https://riot.example.com/
-# The public-facing base URL that clients use to access this HS
-# (not including _matrix/...). This is the same URL a user would
-# enter into the 'custom HS URL' field on their client. If you
-# use synapse with a reverse proxy, this should be the URL to reach
-# synapse via the proxy.
+# The public-facing base URL that clients use to access this Homeserver (not
+# including _matrix/...). This is the same URL a user might enter into the
+# 'Custom Homeserver URL' field on their client. If you use Synapse with a
+# reverse proxy, this should be the URL to reach Synapse via the proxy.
+# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
+# 'listeners' below).
#
#public_baseurl: https://example.com/
@@ -88,8 +89,7 @@ pid_file: DATADIR/homeserver.pid
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
-# API, so this setting is of limited value if federation is enabled on
-# the server.
+# API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true
@@ -100,6 +100,14 @@ pid_file: DATADIR/homeserver.pid
#
#limit_profile_requests_to_users_who_share_rooms: true
+# Uncomment to prevent a user's profile data from being retrieved and
+# displayed in a room until they have joined it. By default, a user's
+# profile data is included in an invite event, regardless of the values
+# of the above two settings, and whether or not the users share a server.
+# Defaults to 'true'.
+#
+#include_profile_data_on_invite: false
+
# If set to 'true', removes the need for authentication to access the server's
# public rooms directory through the client API, meaning that anyone can
# query the room directory. Defaults to 'false'.
@@ -144,6 +152,51 @@ pid_file: DATADIR/homeserver.pid
#
#enable_search: false
+# Prevent outgoing requests from being sent to the following blacklisted IP address
+# CIDR ranges. If this option is not specified then it defaults to private IP
+# address ranges (see the example below).
+#
+# The blacklist applies to the outbound requests for federation, identity servers,
+# push servers, and for checking key validity for third-party invite events.
+#
+# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+# listed here, since they correspond to unroutable addresses.)
+#
+# This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
+#
+#ip_range_blacklist:
+# - '127.0.0.0/8'
+# - '10.0.0.0/8'
+# - '172.16.0.0/12'
+# - '192.168.0.0/16'
+# - '100.64.0.0/10'
+# - '192.0.0.0/24'
+# - '169.254.0.0/16'
+# - '192.88.99.0/24'
+# - '198.18.0.0/15'
+# - '192.0.2.0/24'
+# - '198.51.100.0/24'
+# - '203.0.113.0/24'
+# - '224.0.0.0/4'
+# - '::1/128'
+# - 'fe80::/10'
+# - 'fc00::/7'
+# - '2001:db8::/32'
+# - 'ff00::/8'
+# - 'fec0::/10'
+
+# List of IP address CIDR ranges that should be allowed for federation,
+# identity servers, push servers, and for checking key validity for
+# third-party invite events. This is useful for specifying exceptions to
+# wide-ranging blacklisted target IP ranges - e.g. for communication with
+# a push server only visible in your network.
+#
+# This whitelist overrides ip_range_blacklist and defaults to an empty
+# list.
+#
+#ip_range_whitelist:
+# - '192.168.1.1'
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
@@ -710,29 +763,6 @@ acme:
# - nyc.example.com
# - syd.example.com
-# Prevent outgoing requests from being sent to the following blacklisted IP address
-# CIDR ranges. If this option is not specified, or specified with an empty list,
-# no IP range blacklist will be enforced.
-#
-# The blacklist applies to the outbound requests for federation, identity servers,
-# push servers, and for checking key validitity for third-party invite events.
-#
-# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
-# listed here, since they correspond to unroutable addresses.)
-#
-# This option replaces federation_ip_range_blacklist in Synapse v1.24.0.
-#
-ip_range_blacklist:
- - '127.0.0.0/8'
- - '10.0.0.0/8'
- - '172.16.0.0/12'
- - '192.168.0.0/16'
- - '100.64.0.0/10'
- - '169.254.0.0/16'
- - '::1/128'
- - 'fe80::/64'
- - 'fc00::/7'
-
# Report prometheus metrics on the age of PDUs being sent to and received from
# the following domains. This can be used to give an idea of "delay" on inbound
# and outbound federation, though be aware that any delay can be due to problems
@@ -744,6 +774,12 @@ ip_range_blacklist:
# - matrix.org
# - example.com
+# Uncomment to disable profile lookup over federation. By default, the
+# Federation API allows other homeservers to obtain profile data of any user
+# on this homeserver. Defaults to 'true'.
+#
+#allow_profile_lookup_over_federation: false
+
## Caching ##
@@ -871,6 +907,9 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
+# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+# - two for ratelimiting how often invites can be sent in a room or to a
+# specific user.
#
# The defaults are as shown below.
#
@@ -904,11 +943,22 @@ log_config: "CONFDIR/SERVERNAME.log.config"
#rc_joins:
# local:
# per_second: 0.1
-# burst_count: 3
+# burst_count: 10
# remote:
# per_second: 0.01
-# burst_count: 3
-
+# burst_count: 10
+#
+#rc_3pid_validation:
+# per_second: 0.003
+# burst_count: 5
+#
+#rc_invites:
+# per_room:
+# per_second: 0.3
+# burst_count: 10
+# per_user:
+# per_second: 0.003
+# burst_count: 5
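The `per_second`/`burst_count` pairs above (and in the other `rc_*` settings) describe a token-bucket style limit: up to `burst_count` actions may happen in quick succession, after which capacity refills at `per_second`. A minimal sketch of those semantics, not Synapse's actual ratelimiter:

```python
import time

class TokenBucket:
    """Sketch of per_second / burst_count semantics (not Synapse's implementation)."""

    def __init__(self, per_second: float, burst_count: int):
        self.rate = per_second            # tokens regained per second
        self.capacity = burst_count       # maximum burst
        self.tokens = float(burst_count)
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill based on elapsed time, capped at the burst size.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

# rc_invites.per_user above: 5 invites straight away, then roughly one every ~5 minutes.
per_user_invites = TokenBucket(per_second=0.003, burst_count=5)
assert all(per_user_invites.allow() for _ in range(5))
assert not per_user_invites.allow()
```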
# Ratelimiting settings for incoming federation
#
@@ -1053,10 +1103,20 @@ media_store_path: "DATADIR/media_store"
# - '172.16.0.0/12'
# - '192.168.0.0/16'
# - '100.64.0.0/10'
+# - '192.0.0.0/24'
# - '169.254.0.0/16'
+# - '192.88.99.0/24'
+# - '198.18.0.0/15'
+# - '192.0.2.0/24'
+# - '198.51.100.0/24'
+# - '203.0.113.0/24'
+# - '224.0.0.0/4'
# - '::1/128'
-# - 'fe80::/64'
+# - 'fe80::/10'
# - 'fc00::/7'
+# - '2001:db8::/32'
+# - 'ff00::/8'
+# - 'fec0::/10'
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
@@ -1366,6 +1426,8 @@ account_threepid_delegates:
# By default, any room aliases included in this list will be created
# as a publicly joinable room when the first user registers for the
# homeserver. This behaviour can be customised with the settings below.
+# If the room already exists, make certain it is a publicly joinable
+# room. The join rule of the room must be set to 'public'.
#
#auto_join_rooms:
# - "#example:example.com"
@@ -1705,10 +1767,10 @@ trusted_key_servers:
# enable SAML login.
#
# Once SAML support is enabled, a metadata file will be exposed at
-# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
+# https://<server>:<port>/_synapse/client/saml2/metadata.xml, which you may be able to
# use to configure your SAML IdP with. Alternatively, you can manually configure
# the IdP to use an ACS location of
-# https://<server>:<port>/_matrix/saml2/authn_response.
+# https://<server>:<port>/_synapse/client/saml2/authn_response.
#
saml2_config:
# `sp_config` is the configuration for the pysaml2 Service Provider.
@@ -1865,140 +1927,188 @@ saml2_config:
#idp_entityid: 'https://our_idp/entityid'
-# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
+# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+# and login.
#
-# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-# for some example configurations.
+# Options for each entry include:
#
-oidc_config:
- # Uncomment the following to enable authorization against an OpenID Connect
- # server. Defaults to false.
- #
- #enabled: true
-
- # Uncomment the following to disable use of the OIDC discovery mechanism to
- # discover endpoints. Defaults to true.
- #
- #discover: false
-
- # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
- # discover the provider's endpoints.
- #
- # Required if 'enabled' is true.
- #
- #issuer: "https://accounts.example.com/"
-
- # oauth2 client id to use.
- #
- # Required if 'enabled' is true.
- #
- #client_id: "provided-by-your-issuer"
-
- # oauth2 client secret to use.
- #
- # Required if 'enabled' is true.
- #
- #client_secret: "provided-by-your-issuer"
-
- # auth method to use when exchanging the token.
- # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
- # 'none'.
- #
- #client_auth_method: client_secret_post
-
- # list of scopes to request. This should normally include the "openid" scope.
- # Defaults to ["openid"].
- #
- #scopes: ["openid", "profile"]
-
- # the oauth2 authorization endpoint. Required if provider discovery is disabled.
- #
- #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-
- # the oauth2 token endpoint. Required if provider discovery is disabled.
- #
- #token_endpoint: "https://accounts.example.com/oauth2/token"
-
- # the OIDC userinfo endpoint. Required if discovery is disabled and the
- # "openid" scope is not requested.
- #
- #userinfo_endpoint: "https://accounts.example.com/userinfo"
-
- # URI where to fetch the JWKS. Required if discovery is disabled and the
- # "openid" scope is used.
- #
- #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
-
- # Uncomment to skip metadata verification. Defaults to false.
- #
- # Use this if you are connecting to a provider that is not OpenID Connect
- # compliant.
- # Avoid this in production.
- #
- #skip_verification: true
-
- # Whether to fetch the user profile from the userinfo endpoint. Valid
- # values are: "auto" or "userinfo_endpoint".
- #
- # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
- # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
- #
- #user_profile_method: "userinfo_endpoint"
-
- # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
- # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
- #
- #allow_existing_users: true
-
- # An external module can be provided here as a custom solution to mapping
- # attributes returned from a OIDC provider onto a matrix user.
- #
- user_mapping_provider:
- # The custom module's class. Uncomment to use a custom module.
- # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
- #
- # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
- # for information on implementing a custom mapping provider.
- #
- #module: mapping_provider.OidcMappingProvider
-
- # Custom configuration values for the module. This section will be passed as
- # a Python dictionary to the user mapping provider module's `parse_config`
- # method.
- #
- # The examples below are intended for the default provider: they should be
- # changed if using a custom provider.
- #
- config:
- # name of the claim containing a unique identifier for the user.
- # Defaults to `sub`, which OpenID Connect compliant providers should provide.
- #
- #subject_claim: "sub"
-
- # Jinja2 template for the localpart of the MXID.
- #
- # When rendering, this template is given the following variables:
- # * user: The claims returned by the UserInfo Endpoint and/or in the ID
- # Token
- #
- # This must be configured if using the default mapping provider.
- #
- localpart_template: "{{ user.preferred_username }}"
-
- # Jinja2 template for the display name to set on first login.
- #
- # If unset, no displayname will be set.
- #
- #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
-
- # Jinja2 templates for extra attributes to send back to the client during
- # login.
- #
- # Note that these are non-standard and clients will ignore them without modifications.
- #
- #extra_attributes:
- #birthdate: "{{ user.birthdate }}"
-
+# idp_id: a unique identifier for this identity provider. Used internally
+# by Synapse; should be a single word such as 'github'.
+#
+# Note that, if this is changed, users authenticating via that provider
+# will no longer be recognised as the same user!
+#
+# (Use "oidc" here if you are migrating from an old "oidc_config"
+# configuration.)
+#
+# idp_name: A user-facing name for this identity provider, which is used to
+# offer the user a choice of login mechanisms.
+#
+# idp_icon: An optional icon for this identity provider, which is presented
+# by clients and Synapse's own IdP picker page. If given, must be an
+# MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
+# obtain such an MXC URI is to upload an image to an (unencrypted) room
+# and then copy the "url" from the source of the event.)
+#
+# idp_brand: An optional brand for this identity provider, allowing clients
+# to style the login flow according to the identity provider in question.
+# See the spec for possible options here.
+#
+# discover: set to 'false' to disable the use of the OIDC discovery mechanism
+# to discover endpoints. Defaults to true.
+#
+# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+# is enabled) to discover the provider's endpoints.
+#
+# client_id: Required. oauth2 client id to use.
+#
+# client_secret: oauth2 client secret to use. May be omitted if
+# client_secret_jwt_key is given, or if client_auth_method is 'none'.
+#
+# client_secret_jwt_key: Alternative to client_secret: details of a key used
+# to create a JSON Web Token to be used as an OAuth2 client secret. If
+# given, must be a dictionary with the following properties:
+#
+# key: a pem-encoded signing key. Must be a suitable key for the
+# algorithm specified. Required unless 'key_file' is given.
+#
+#      key_file: the path to a file containing a pem-encoded signing key.
+# Required unless 'key' is given.
+#
+# jwt_header: a dictionary giving properties to include in the JWT
+# header. Must include the key 'alg', giving the algorithm used to
+# sign the JWT, such as "ES256", using the JWA identifiers in
+# RFC7518.
+#
+# jwt_payload: an optional dictionary giving properties to include in
+# the JWT payload. Normally this should include an 'iss' key.
+#
+# client_auth_method: auth method to use when exchanging the token. Valid
+# values are 'client_secret_basic' (default), 'client_secret_post' and
+# 'none'.
+#
+# scopes: list of scopes to request. This should normally include the "openid"
+# scope. Defaults to ["openid"].
+#
+# authorization_endpoint: the oauth2 authorization endpoint. Required if
+# provider discovery is disabled.
+#
+# token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+# disabled.
+#
+# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+# disabled and the 'openid' scope is not requested.
+#
+# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+# the 'openid' scope is used.
+#
+# skip_verification: set to 'true' to skip metadata verification. Use this if
+# you are connecting to a provider that is not OpenID Connect compliant.
+# Defaults to false. Avoid this in production.
+#
+# user_profile_method: Whether to fetch the user profile from the userinfo
+# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+#
+# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+# userinfo endpoint.
+#
+# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+# match a pre-existing account instead of failing. This could be used if
+# switching from password logins to OIDC. Defaults to false.
+#
+# user_mapping_provider: Configuration for how attributes returned from an OIDC
+# provider are mapped onto a matrix user. This setting has the following
+# sub-properties:
+#
+# module: The class name of a custom mapping module. Default is
+# 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
+# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+# for information on implementing a custom mapping provider.
+#
+# config: Configuration for the mapping provider module. This section will
+# be passed as a Python dictionary to the user mapping provider
+# module's `parse_config` method.
+#
+# For the default provider, the following settings are available:
+#
+# subject_claim: name of the claim containing a unique identifier
+# for the user. Defaults to 'sub', which OpenID Connect
+# compliant providers should provide.
+#
+# localpart_template: Jinja2 template for the localpart of the MXID.
+# If this is not set, the user will be prompted to choose their
+# own username (see 'sso_auth_account_details.html' in the 'sso'
+# section of this file).
+#
+# display_name_template: Jinja2 template for the display name to set
+# on first login. If unset, no displayname will be set.
+#
+# email_template: Jinja2 template for the email address of the user.
+# If unset, no email address will be added to the account.
+#
+# extra_attributes: a map of Jinja2 templates for extra attributes
+# to send back to the client during login.
+# Note that these are non-standard and clients will ignore them
+# without modifications.
+#
+# When rendering, the Jinja2 templates are given a 'user' variable,
+# which is set to the claims returned by the UserInfo Endpoint and/or
+# in the ID Token.
+#
+# It is possible to configure Synapse to only allow logins if certain attributes
+# match particular values in the OIDC userinfo. The requirements can be listed under
+# `attribute_requirements` as shown below. All of the listed attributes must
+# match for the login to be permitted. Additional attributes can be added to
+# userinfo by expanding the `scopes` section of the OIDC config to retrieve
+# additional information from the OIDC provider.
+#
+# If the OIDC claim is a list, then the attribute must match any value in the list.
+# Otherwise, it must exactly match the value of the claim. Using the example
+# below, the `family_name` claim MUST be "Stephensson", but the `groups`
+# claim MUST contain "admin".
+#
+# attribute_requirements:
+# - attribute: family_name
+# value: "Stephensson"
+# - attribute: groups
+# value: "admin"
+#
+# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+# for information on how to configure these options.
+#
+# For backwards compatibility, it is also possible to configure a single OIDC
+# provider via an 'oidc_config' setting. This is now deprecated and admins are
+# advised to migrate to the 'oidc_providers' format. (When doing that migration,
+# use 'oidc' for the idp_id to ensure that existing users continue to be
+# recognised.)
+#
+oidc_providers:
+ # Generic example
+ #
+ #- idp_id: my_idp
+ # idp_name: "My OpenID provider"
+ # idp_icon: "mxc://example.com/mediaid"
+ # discover: false
+ # issuer: "https://accounts.example.com/"
+ # client_id: "provided-by-your-issuer"
+ # client_secret: "provided-by-your-issuer"
+ # client_auth_method: client_secret_post
+ # scopes: ["openid", "profile"]
+ # authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+ # token_endpoint: "https://accounts.example.com/oauth2/token"
+ # userinfo_endpoint: "https://accounts.example.com/userinfo"
+ # jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+ # skip_verification: true
+ # user_mapping_provider:
+ # config:
+ # subject_claim: "id"
+ # localpart_template: "{{ user.login }}"
+ # display_name_template: "{{ user.name }}"
+ # email_template: "{{ user.email }}"
+ # attribute_requirements:
+ # - attribute: userGroup
+ # value: "synapseUsers"
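To make the `attribute_requirements` matching rules above concrete, here is a small sketch of the documented semantics (exact match for scalar claims, membership for list-valued claims). It is an illustration only, not the code Synapse runs:

```python
from typing import Any, Mapping, Sequence

def attribute_requirements_met(
    requirements: Sequence[Mapping[str, Any]], userinfo: Mapping[str, Any]
) -> bool:
    """True only if every listed requirement is satisfied by the userinfo claims."""
    for req in requirements:
        claim = userinfo.get(req["attribute"])
        if isinstance(claim, (list, tuple)):
            # list-valued claim: the required value must appear somewhere in it
            if req["value"] not in claim:
                return False
        elif claim != req["value"]:
            # scalar claim: must match exactly
            return False
    return True

requirements = [
    {"attribute": "family_name", "value": "Stephensson"},
    {"attribute": "groups", "value": "admin"},
]
assert attribute_requirements_met(
    requirements, {"family_name": "Stephensson", "groups": ["admin", "staff"]}
)
assert not attribute_requirements_met(requirements, {"family_name": "Stephensson"})
```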
# Enable Central Authentication Service (CAS) for registration and login.
@@ -2013,10 +2123,6 @@ cas_config:
#
#server_url: "https://cas-server.com"
- # The public URL of the homeserver.
- #
- #service_url: "https://homeserver.domain.com:8448"
-
# The attribute of the CAS response to use as the display name.
#
# If unset, no displayname will be set.
@@ -2059,41 +2165,135 @@ sso:
# - https://my.custom.client/
# Directory in which Synapse will try to find the template files below.
- # If not set, default templates from within the Synapse package will be used.
- #
- # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
- # If you *do* uncomment it, you will need to make sure that all the templates
- # below are in the directory.
+ # If not set, or the files named below are not found within the template
+ # directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
+ # * HTML page to prompt the user to choose an Identity Provider during
+ # login: 'sso_login_idp_picker.html'.
+ #
+ # This is only used if multiple SSO Identity Providers are configured.
+ #
+ # When rendering, this template is given the following variables:
+ # * redirect_url: the URL that the user will be redirected to after
+ # login.
+ #
+ # * server_name: the homeserver's name.
+ #
+ # * providers: a list of available Identity Providers. Each element is
+ # an object with the following attributes:
+ #
+ # * idp_id: unique identifier for the IdP
+ # * idp_name: user-facing name for the IdP
+ # * idp_icon: if specified in the IdP config, an MXC URI for an icon
+ # for the IdP
+ # * idp_brand: if specified in the IdP config, a textual identifier
+ # for the brand of the IdP
+ #
+ # The rendered HTML page should contain a form which submits its results
+ # back as a GET request, with the following query parameters:
+ #
+ # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+ # to the template)
+ #
+ # * idp: the 'idp_id' of the chosen IDP.
+ #
+ # * HTML page to prompt new users to enter a userid and confirm other
+ # details: 'sso_auth_account_details.html'. This is only shown if the
+ # SSO implementation (with any user_mapping_provider) does not return
+ # a localpart.
+ #
+ # When rendering, this template is given the following variables:
+ #
+ # * server_name: the homeserver's name.
+ #
+ # * idp: details of the SSO Identity Provider that the user logged in
+ # with: an object with the following attributes:
+ #
+ # * idp_id: unique identifier for the IdP
+ # * idp_name: user-facing name for the IdP
+ # * idp_icon: if specified in the IdP config, an MXC URI for an icon
+ # for the IdP
+ # * idp_brand: if specified in the IdP config, a textual identifier
+ # for the brand of the IdP
+ #
+ # * user_attributes: an object containing details about the user that
+ # we received from the IdP. May have the following attributes:
+ #
+ # * display_name: the user's display_name
+ # * emails: a list of email addresses
+ #
+ # The template should render a form which submits the following fields:
+ #
+ # * username: the localpart of the user's chosen user id
+ #
+  #   * HTML page allowing the user to consent to the server's terms and
+  #     conditions: 'sso_new_user_consent.html'. This is only shown for new
+  #     users, and only if `user_consent.require_at_registration` is set.
+ #
+ # When rendering, this template is given the following variables:
+ #
+ # * server_name: the homeserver's name.
+ #
+  #   * user_id: the user's proposed matrix ID.
+ #
+ # * user_profile.display_name: the user's proposed display name, if any.
+ #
+ # * consent_version: the version of the terms that the user will be
+ # shown
+ #
+ # * terms_url: a link to the page showing the terms.
+ #
+ # The template should render a form which submits the following fields:
+ #
+ # * accepted_version: the version of the terms accepted by the user
+ # (ie, 'consent_version' from the input variables).
+ #
# * HTML page for a confirmation step before redirecting back to the client
# with the login token: 'sso_redirect_confirm.html'.
#
- # When rendering, this template is given three variables:
- # * redirect_url: the URL the user is about to be redirected to. Needs
- # manual escaping (see
- # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ # When rendering, this template is given the following variables:
+ #
+ # * redirect_url: the URL the user is about to be redirected to.
#
# * display_url: the same as `redirect_url`, but with the query
# parameters stripped. The intention is to have a
# human-readable URL to show to users, not to use it as
- # the final address to redirect to. Needs manual escaping
- # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ # the final address to redirect to.
#
# * server_name: the homeserver's name.
#
+ # * new_user: a boolean indicating whether this is the user's first time
+ # logging in.
+ #
+ # * user_id: the user's matrix ID.
+ #
+ # * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
+ # None if the user has not set an avatar.
+ #
+ # * user_profile.display_name: the user's display name. None if the user
+ # has not set a display name.
+ #
# * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'.
#
# When rendering, this template is given the following variables:
- # * redirect_url: the URL the user is about to be redirected to. Needs
- # manual escaping (see
- # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ # * redirect_url: the URL the user is about to be redirected to.
#
# * description: the operation which the user is being asked to confirm
#
+ # * idp: details of the Identity Provider that we will use to confirm
+ # the user's identity: an object with the following attributes:
+ #
+ # * idp_id: unique identifier for the IdP
+ # * idp_name: user-facing name for the IdP
+ # * idp_icon: if specified in the IdP config, an MXC URI for an icon
+ # for the IdP
+ # * idp_brand: if specified in the IdP config, a textual identifier
+ # for the brand of the IdP
+ #
# * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'.
#
@@ -2102,6 +2302,14 @@ sso:
#
# This template has no additional variables.
#
+ # * HTML page shown after a user-interactive authentication session which
+ # does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+ #
+ # When rendering, this template is given the following variables:
+ # * server_name: the homeserver's name.
+ # * user_id_to_verify: the MXID of the user that we are trying to
+ # validate.
+ #
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
@@ -2227,6 +2435,21 @@ password_config:
#
#require_uppercase: true
+ui_auth:
+ # The amount of time to allow a user-interactive authentication session
+ # to be active.
+ #
+ # This defaults to 0, meaning the user is queried for their credentials
+ # before every action, but this can be overridden to allow a single
+ # validation to be re-used. This weakens the protections afforded by
+ # the user-interactive authentication process, by allowing for multiple
+ # (and potentially different) operations to use the same validation session.
+ #
+ # Uncomment below to allow for credential validation to last for 15
+ # seconds.
+ #
+ #session_timeout: "15s"
+
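Purely to illustrate what `session_timeout` changes (hypothetical names, not Synapse internals): once credentials have been validated, later operations inside the timeout window can reuse that validation instead of prompting again.

```python
import time
from typing import Optional

SESSION_TIMEOUT = 15.0  # seconds, matching the "15s" example above

class UIAuthSession:
    """Hypothetical sketch of a user-interactive auth session."""

    def __init__(self) -> None:
        self.validated_at: Optional[float] = None

    def mark_validated(self) -> None:
        self.validated_at = time.monotonic()

    def needs_credentials(self) -> bool:
        if self.validated_at is None:
            return True
        # With session_timeout = 0 (the default) this is effectively always
        # True, i.e. every operation re-prompts for credentials.
        return time.monotonic() - self.validated_at > SESSION_TIMEOUT
```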
# Configuration for sending emails from Synapse.
#
@@ -2292,10 +2515,15 @@ email:
#
#validation_token_lifetime: 15m
- # Directory in which Synapse will try to find the template files below.
- # If not set, default templates from within the Synapse package will be used.
+ # The web client location to direct users to during an invite. This is passed
+ # to the identity server as the org.matrix.web_client_location key. Defaults
+ # to unset, giving no guidance to the identity server.
#
- # Do not uncomment this setting unless you want to customise the templates.
+ #invite_client_location: https://app.element.io
+
+ # Directory in which Synapse will try to find the template files below.
+ # If not set, or the files named below are not found within the template
+ # directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
@@ -2504,35 +2732,47 @@ spam_checker:
# If enabled, non server admins can only create groups with local parts
# starting with this prefix
#
-#group_creation_prefix: "unofficial/"
+#group_creation_prefix: "unofficial_"
# User Directory configuration
#
-# 'enabled' defines whether users can search the user directory. If
-# false then empty responses are returned to all queries. Defaults to
-# true.
-#
-# 'search_all_users' defines whether to search all users visible to your HS
-# when searching the user directory, rather than limiting to users visible
-# in public rooms. Defaults to false. If you set it True, you'll have to
-# rebuild the user_directory search indexes, see
-# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
-#
-# 'prefer_local_users' defines whether to prioritise local users in
-# search query results. If True, local users are more likely to appear above
-# remote users when searching the user directory. Defaults to false.
-#
-#user_directory:
-# enabled: true
-# search_all_users: false
-# prefer_local_users: false
-#
-# # If this is set, user search will be delegated to this ID server instead
-# # of synapse performing the search itself.
-# # This is an experimental API.
-# defer_to_id_server: https://id.example.com
+user_directory:
+ # Defines whether users can search the user directory. If false then
+ # empty responses are returned to all queries. Defaults to true.
+ #
+ # Uncomment to disable the user directory.
+ #
+ #enabled: false
+
+ # Defines whether to search all users visible to your HS when searching
+ # the user directory, rather than limiting to users visible in public
+ # rooms. Defaults to false.
+ #
+  # If you set it to true, you'll have to rebuild the user_directory search
+ # indexes, see:
+ # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
+ #
+ # Uncomment to return search results containing all known users, even if that
+ # user does not share a room with the requester.
+ #
+ #search_all_users: true
+
+ # If this is set, user search will be delegated to this ID server instead
+ # of Synapse performing the search itself.
+ # This is an experimental API.
+ #
+ #defer_to_id_server: https://id.example.com
+
+ # Defines whether to prefer local users in search query results.
+ # If True, local users are more likely to appear above remote users
+ # when searching the user directory. Defaults to false.
+ #
+ # Uncomment to prefer local over remote users in user directory search
+ # results.
+ #
+ #prefer_local_users: true
# User Consent configuration
@@ -2587,19 +2827,20 @@ spam_checker:
-# Local statistics collection. Used in populating the room directory.
-#
-# 'bucket_size' controls how large each statistics timeslice is. It can
-# be defined in a human readable short form -- e.g. "1d", "1y".
-#
-# 'retention' controls how long historical statistics will be kept for.
-# It can be defined in a human readable short form -- e.g. "1d", "1y".
-#
+# Settings for local room and user statistics collection. See
+# docs/room_and_user_statistics.md.
#
-#stats:
-# enabled: true
-# bucket_size: 1d
-# retention: 1y
+stats:
+ # Uncomment the following to disable room and user statistics. Note that doing
+ # so may cause certain features (such as the room directory) not to work
+ # correctly.
+ #
+ #enabled: false
+
+ # The size of each timeslice in the room_stats_historical and
+ # user_stats_historical tables, as a time period. Defaults to "1d".
+ #
+ #bucket_size: 1h
# Server Notices room configuration
@@ -2779,6 +3020,13 @@ opentracing:
#
#run_background_tasks_on: worker1
+# A shared secret used by the replication APIs to authenticate HTTP requests
+# from workers.
+#
+# By default this is unused and traffic is not authenticated.
+#
+#worker_replication_secret: ""
+
# Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration).
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 7fc08f1b70..52947f605e 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -14,6 +14,7 @@ The Python class is instantiated with two objects:
* An instance of `synapse.module_api.ModuleApi`.
It then implements methods which return a boolean to alter behavior in Synapse.
+All the methods must be defined.
There's a generic method for checking every event (`check_event_for_spam`), as
well as some specific methods:
@@ -22,38 +23,63 @@ well as some specific methods:
* `user_may_create_room`
* `user_may_create_room_alias`
* `user_may_publish_room`
+* `check_username_for_spam`
+* `check_registration_for_spam`
+* `check_media_file_for_spam`
-The details of the each of these methods (as well as their inputs and outputs)
+The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
The `ModuleApi` class provides a way for the custom spam checker class to
call back into the homeserver internals.
+Additionally, a `parse_config` method is mandatory and receives the plugin config
+dictionary. After parsing, it must return an object which will be
+passed to `__init__` later.
+
### Example
```python
+from synapse.spam_checker_api import RegistrationBehaviour
+
class ExampleSpamChecker:
def __init__(self, config, api):
self.config = config
self.api = api
- def check_event_for_spam(self, foo):
+ @staticmethod
+ def parse_config(config):
+ return config
+
+ async def check_event_for_spam(self, foo):
return False # allow all events
- def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
return True # allow all invites
- def user_may_create_room(self, userid):
+ async def user_may_create_room(self, userid):
return True # allow all room creations
- def user_may_create_room_alias(self, userid, room_alias):
+ async def user_may_create_room_alias(self, userid, room_alias):
return True # allow all room aliases
- def user_may_publish_room(self, userid, room_id):
+ async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
return False # allow all usernames
+
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.ALLOW # allow all registrations
+
+ async def check_media_file_for_spam(self, file_wrapper, file_info):
+ return False # allow all media
```
## Configuration
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index ab2a648910..e1d6ede7ba 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -15,12 +15,18 @@ where SAML mapping providers come into play.
SSO mapping providers are currently supported for OpenID and SAML SSO
configurations. Please see the details below for how to implement your own.
-It is the responsibility of the mapping provider to normalise the SSO attributes
-and map them to a valid Matrix ID. The
-[specification for Matrix IDs](https://matrix.org/docs/spec/appendices#user-identifiers)
-has some information about what is considered valid. Alternately an easy way to
-ensure it is valid is to use a Synapse utility function:
-`synapse.types.map_username_to_mxid_localpart`.
+It is up to the mapping provider whether the user should be assigned a predefined
+Matrix ID based on the SSO attributes, or if the user should be allowed to
+choose their own username.
+
+In the first case - where users are automatically allocated a Matrix ID - it is
+the responsibility of the mapping provider to normalise the SSO attributes and
+map them to a valid Matrix ID. The [specification for Matrix
+IDs](https://matrix.org/docs/spec/appendices#user-identifiers) has some
+information about what is considered valid.
+
+If the mapping provider does not assign a Matrix ID, then Synapse will
+automatically serve an HTML page allowing the user to pick their own username.
External mapping providers are provided to Synapse in the form of an external
Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere,
@@ -80,8 +86,9 @@ A custom mapping provider must specify the following methods:
with failures=1. The method should then return a different
`localpart` value, such as `john.doe1`.
- Returns a dictionary with two keys:
- - localpart: A required string, used to generate the Matrix ID.
- - displayname: An optional string, the display name for the user.
+ - `localpart`: A string, used to generate the Matrix ID. If this is
+ `None`, the user is prompted to pick their own username.
+ - `displayname`: An optional string, the display name for the user.
* `get_extra_attributes(self, userinfo, token)`
- This method must be async.
- Arguments:
@@ -116,11 +123,13 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
-* `__init__(self, parsed_config)`
+* `__init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
+ - `module_api` - a `synapse.module_api.ModuleApi` object which provides the
+ stable API available for extension modules.
* `parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
@@ -163,12 +172,13 @@ A custom mapping provider must specify the following methods:
redirected to.
- This method must return a dictionary, which will then be used by Synapse
to build a new user. The following keys are allowed:
- * `mxid_localpart` - Required. The mxid localpart of the new user.
+ * `mxid_localpart` - The mxid localpart of the new user. If this is
+ `None`, the user is prompted to pick their own username.
* `displayname` - The displayname of the new user. If not provided, will default to
the value of `mxid_localpart`.
* `emails` - A list of emails for the new user. If not provided, will
default to an empty list.
-
+
Alternatively it can raise a `synapse.api.errors.RedirectException` to
redirect the user to another page. This is useful to prompt the user for
additional information, e.g. if you want them to provide their own username.
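Below is a rough sketch of a SAML mapping provider satisfying the interface described above. The method names follow this document; the SAML attribute names (`uid`, `displayName`, `email`) and the use of pysaml2's `ava` attribute dictionary are assumptions for the example, and other methods the full document requires are omitted for brevity:

```python
class ExampleSamlMappingProvider:
    def __init__(self, parsed_config, module_api):
        # parsed_config is whatever parse_config() returned; module_api exposes
        # the stable API available to extension modules.
        self.config = parsed_config
        self.api = module_api

    @staticmethod
    def parse_config(config):
        return config

    def saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url):
        # 'uid', 'displayName' and 'email' are placeholder attribute names: use
        # whatever your IdP actually sends. Returning mxid_localpart=None would
        # instead send the user to the username-picker page.
        attrs = saml_response.ava
        return {
            "mxid_localpart": attrs["uid"][0] if attrs.get("uid") else None,
            "displayname": (attrs.get("displayName") or [None])[0],
            "emails": attrs.get("email", []),
        }

    # Other required methods described elsewhere in this document (such as
    # get_remote_user_id and get_saml_attributes) are omitted from this sketch.
```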
diff --git a/docs/systemd-with-workers/README.md b/docs/systemd-with-workers/README.md
index 8e57d4f62e..cfa36be7b4 100644
--- a/docs/systemd-with-workers/README.md
+++ b/docs/systemd-with-workers/README.md
@@ -31,7 +31,7 @@ There is no need for a separate configuration file for the master process.
1. Adjust synapse configuration files as above.
1. Copy the `*.service` and `*.target` files in [system](system) to
`/etc/systemd/system`.
-1. Run `systemctl deamon-reload` to tell systemd to load the new unit files.
+1. Run `systemctl daemon-reload` to tell systemd to load the new unit files.
1. Run `systemctl enable matrix-synapse.service`. This will configure the
synapse master process to be started as part of the `matrix-synapse.target`
target.
diff --git a/docs/systemd-with-workers/system/matrix-synapse-worker@.service b/docs/systemd-with-workers/system/matrix-synapse-worker@.service
index cb5ac0ac87..d164e8ce1f 100644
--- a/docs/systemd-with-workers/system/matrix-synapse-worker@.service
+++ b/docs/systemd-with-workers/system/matrix-synapse-worker@.service
@@ -4,6 +4,7 @@ AssertPathExists=/etc/matrix-synapse/workers/%i.yaml
# This service should be restarted when the synapse target is restarted.
PartOf=matrix-synapse.target
+ReloadPropagatedFrom=matrix-synapse.target
# if this is started at the same time as the main, let the main process start
# first, to initialise the database schema.
diff --git a/docs/systemd-with-workers/system/matrix-synapse.service b/docs/systemd-with-workers/system/matrix-synapse.service
index c7b5ddfa49..f6b6dfd3ce 100644
--- a/docs/systemd-with-workers/system/matrix-synapse.service
+++ b/docs/systemd-with-workers/system/matrix-synapse.service
@@ -3,6 +3,7 @@ Description=Synapse master
# This service should be restarted when the synapse target is restarted.
PartOf=matrix-synapse.target
+ReloadPropagatedFrom=matrix-synapse.target
[Service]
Type=notify
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index ad145439b4..15df949deb 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -220,10 +220,6 @@ Asks the server for the current position of all streams.
Acknowledge receipt of some federation data
-#### REMOVE_PUSHER (C)
-
- Inform the server a pusher should be removed
-
### REMOTE_SERVER_UP (S, C)
Inform other processes that a remote server may have come back online.
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index a470c274a5..41738bbe69 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -187,7 +187,7 @@ After updating the homeserver configuration, you must restart synapse:
```
* If you use systemd:
```
- systemctl restart synapse.service
+ systemctl restart matrix-synapse.service
```
... and then reload any clients (or wait an hour for them to refresh their
settings).
@@ -232,6 +232,12 @@ Here are a few things to try:
(Understanding the output is beyond the scope of this document!)
+ * You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/.
+ Note that this test is not fully reliable yet, so don't be discouraged if
+ the test fails.
+   [Here](https://github.com/matrix-org/voip-tester) is the GitHub repository for
+   the tester's source code, where you can file bug reports.
+
* There is a WebRTC test tool at
https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
use it, you will need a username/password for your TURN server. You can
diff --git a/docs/workers.md b/docs/workers.md
index c53d1bd2ff..c6282165b0 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
be used for demo purposes and any admin considering workers should already be
running PostgreSQL.
+See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
+for a higher level overview.
+
## Main process/worker communication
The processes communicate with each other via a Synapse-specific protocol called
@@ -37,6 +40,9 @@ which relays replication commands between processes. This can give a significant
cpu saving on the main process and will be a prerequisite for upcoming
performance improvements.
+If Redis support is enabled Synapse will use it as a shared cache, as well as a
+pub/sub mechanism.
+
See the [Architectural diagram](#architectural-diagram) section at the end for
a visualisation of what this looks like.
@@ -56,7 +62,7 @@ The appropriate dependencies must also be installed for Synapse. If using a
virtualenv, these can be installed with:
```sh
-pip install matrix-synapse[redis]
+pip install "matrix-synapse[redis]"
```
Note that these dependencies are included when synapse is installed with `pip
@@ -89,7 +95,8 @@ shared configuration file.
Normally, only a couple of changes are needed to make an existing configuration
file suitable for use with workers. First, you need to enable an "HTTP replication
listener" for the main process; and secondly, you need to enable redis-based
-replication. For example:
+replication. Optionally, a shared secret can be used to authenticate HTTP
+traffic between workers. For example:
```yaml
@@ -103,6 +110,9 @@ listeners:
resources:
- names: [replication]
+# Add a random shared secret to authenticate traffic.
+worker_replication_secret: ""
+
redis:
enabled: true
```
@@ -210,6 +220,7 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
+ ^/_matrix/client/(api/v1|r0|unstable)/devices$
^/_matrix/client/(api/v1|r0|unstable)/keys/query$
^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
^/_matrix/client/versions$
@@ -217,14 +228,13 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
- ^/_synapse/client/password_reset/email/submit_token$
# Registration/login requests
^/_matrix/client/(api/v1|r0|unstable)/login$
^/_matrix/client/(r0|unstable)/register$
- ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$
# Event sending requests
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
@@ -247,25 +257,30 @@ Additionally, the following endpoints should be included if Synapse is configure
to use SSO (you only need to include the ones for whichever SSO provider you're
using):
+ # for all SSO providers
+ ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect
+ ^/_synapse/client/pick_idp$
+ ^/_synapse/client/pick_username
+ ^/_synapse/client/new_user_consent$
+ ^/_synapse/client/sso_register$
+
# OpenID Connect requests.
- ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
- ^/_synapse/oidc/callback$
+ ^/_synapse/client/oidc/callback$
# SAML requests.
- ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
- ^/_matrix/saml2/authn_response$
+ ^/_synapse/client/saml2/authn_response$
# CAS requests.
- ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$
^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
+Ensure that all SSO logins go to a single process.
+For the problems that can arise when multiple workers handle the SSO endpoints, see
+[#7530](https://github.com/matrix-org/synapse/issues/7530) and
+[#9427](https://github.com/matrix-org/synapse/issues/9427).
+
Note that a HTTP listener with `client` and `federation` resources must be
configured in the `worker_listeners` option in the worker config.
-Ensure that all SSO logins go to a single process (usually the main process).
-For multiple workers not handling the SSO endpoints properly, see
-[#7530](https://github.com/matrix-org/synapse/issues/7530).
-
#### Load balancing
It is possible to run multiple instances of this worker app, with incoming requests
@@ -358,7 +373,15 @@ Handles sending push notifications to sygnal and email. Doesn't handle any
REST endpoints itself, but you should set `start_pushers: False` in the
shared configuration file to stop the main synapse sending push notifications.
-Note this worker cannot be load-balanced: only one instance should be active.
+To run multiple instances at once the `pusher_instances` option should list all
+pusher instances by their worker name, e.g.:
+
+```yaml
+pusher_instances:
+ - pusher_worker1
+ - pusher_worker2
+```
+
### `synapse.app.appservice`
diff --git a/mypy.ini b/mypy.ini
index 3c8d303064..3ae5d45787 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,50 +1,32 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
-follow_imports = silent
+follow_imports = normal
check_untyped_defs = True
show_error_codes = True
show_traceback = True
mypy_path = stubs
warn_unreachable = True
+local_partial_types = True
+
+# To find all folders that pass mypy you run:
+#
+# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print
+
files =
scripts-dev/sign_json,
synapse/api,
synapse/appservice,
synapse/config,
+ synapse/crypto,
synapse/event_auth.py,
synapse/events/builder.py,
- synapse/events/validator.py,
synapse/events/spamcheck.py,
+ synapse/events/third_party_rules.py,
+ synapse/events/validator.py,
synapse/federation,
- synapse/handlers/_base.py,
- synapse/handlers/account_data.py,
- synapse/handlers/account_validity.py,
- synapse/handlers/appservice.py,
- synapse/handlers/auth.py,
- synapse/handlers/cas_handler.py,
- synapse/handlers/deactivate_account.py,
- synapse/handlers/device.py,
- synapse/handlers/devicemessage.py,
- synapse/handlers/directory.py,
- synapse/handlers/events.py,
- synapse/handlers/federation.py,
- synapse/handlers/identity.py,
- synapse/handlers/initial_sync.py,
- synapse/handlers/message.py,
- synapse/handlers/oidc_handler.py,
- synapse/handlers/pagination.py,
- synapse/handlers/password_policy.py,
- synapse/handlers/presence.py,
- synapse/handlers/profile.py,
- synapse/handlers/read_marker.py,
- synapse/handlers/register.py,
- synapse/handlers/room.py,
- synapse/handlers/room_member.py,
- synapse/handlers/room_member_worker.py,
- synapse/handlers/saml_handler.py,
- synapse/handlers/sync.py,
- synapse/handlers/ui_auth,
+ synapse/groups,
+ synapse/handlers,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
@@ -55,32 +37,48 @@ files =
synapse/metrics,
synapse/module_api,
synapse/notifier.py,
- synapse/push/pusherpool.py,
- synapse/push/push_rule_evaluator.py,
+ synapse/push,
synapse/replication,
synapse/rest,
+ synapse/secrets.py,
synapse/server.py,
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
+ synapse/storage/__init__.py,
+ synapse/storage/_base.py,
+ synapse/storage/background_updates.py,
synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py,
+ synapse/storage/databases/main/keys.py,
+ synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,
+ synapse/storage/keys.py,
synapse/storage/persist_events.py,
+ synapse/storage/prepare_database.py,
+ synapse/storage/purge_events.py,
+ synapse/storage/push_rule.py,
+ synapse/storage/relations.py,
+ synapse/storage/roommember.py,
synapse/storage/state.py,
+ synapse/storage/types.py,
synapse/storage/util,
synapse/streams,
synapse/types.py,
synapse/util/async_helpers.py,
synapse/util/caches,
synapse/util/metrics.py,
+ synapse/util/macaroons.py,
+ synapse/util/stringutils.py,
+ synapse/visibility.py,
tests/replication,
tests/test_utils,
tests/handlers/test_password_providers.py,
+ tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_stream_change_cache.py
@@ -108,6 +106,9 @@ ignore_missing_imports = True
[mypy-h11]
ignore_missing_imports = True
+[mypy-msgpack]
+ignore_missing_imports = True
+
[mypy-opentracing]
ignore_missing_imports = True
@@ -120,9 +121,6 @@ ignore_missing_imports = True
[mypy-saml2.*]
ignore_missing_imports = True
-[mypy-unpaddedbase64]
-ignore_missing_imports = True
-
[mypy-canonicaljson]
ignore_missing_imports = True
@@ -167,3 +165,9 @@ ignore_missing_imports = True
[mypy-hiredis]
ignore_missing_imports = True
+
+[mypy-josepy.*]
+ignore_missing_imports = True
+
+[mypy-txacme.*]
+ignore_missing_imports = True
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index d742c522b5..47fc99efcf 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# A script which checks that an appropriate news file has been added on this
# branch.
diff --git a/scripts-dev/config-lint.sh b/scripts-dev/config-lint.sh
index 189ca66535..8c6323e59a 100755
--- a/scripts-dev/config-lint.sh
+++ b/scripts-dev/config-lint.sh
@@ -1,10 +1,15 @@
-#!/bin/bash
+#!/usr/bin/env bash
# Find linting errors in Synapse's default config file.
# Exits with 0 if there are no problems, or another code otherwise.
+# cd to the root of the repository
+cd `dirname $0`/..
+
+# Restore backup of sample config upon script exit
+trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT
+
# Fix non-lowercase true/false values
sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml
-rm docs/sample_config.yaml.bak
# Check if anything changed
-git diff --exit-code docs/sample_config.yaml
+diff docs/sample_config.yaml docs/sample_config.yaml.bak
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index abcec48c4f..6f76c08fcf 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -22,8 +22,8 @@ import sys
from typing import Any, Optional
from urllib import parse as urlparse
-import nacl.signing
import requests
+import signedjson.key
import signedjson.types
import srvlookup
import yaml
@@ -44,18 +44,6 @@ def encode_base64(input_bytes):
return output_string
-def decode_base64(input_string):
- """Decode a base64 string to bytes inferring padding from the length of the
- string."""
-
- input_bytes = input_string.encode("ascii")
- input_len = len(input_bytes)
- padding = b"=" * (3 - ((input_len + 3) % 4))
- output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2
- output_bytes = base64.b64decode(input_bytes + padding)
- return output_bytes[:output_len]
-
-
def encode_canonical_json(value):
return json.dumps(
value,
@@ -88,42 +76,6 @@ def sign_json(
return json_object
-NACL_ED25519 = "ed25519"
-
-
-def decode_signing_key_base64(algorithm, version, key_base64):
- """Decode a base64 encoded signing key
- Args:
- algorithm (str): The algorithm the key is for (currently "ed25519").
- version (str): Identifies this key out of the keys for this entity.
- key_base64 (str): Base64 encoded bytes of the key.
- Returns:
- A SigningKey object.
- """
- if algorithm == NACL_ED25519:
- key_bytes = decode_base64(key_base64)
- key = nacl.signing.SigningKey(key_bytes)
- key.version = version
- key.alg = NACL_ED25519
- return key
- else:
- raise ValueError("Unsupported algorithm %s" % (algorithm,))
-
-
-def read_signing_keys(stream):
- """Reads a list of keys from a stream
- Args:
- stream : A stream to iterate for keys.
- Returns:
- list of SigningKey objects.
- """
- keys = []
- for line in stream:
- algorithm, version, key_base64 = line.split()
- keys.append(decode_signing_key_base64(algorithm, version, key_base64))
- return keys
-
-
def request(
method: Optional[str],
origin_name: str,
@@ -223,23 +175,28 @@ def main():
parser.add_argument("--body", help="Data to send as the body of the HTTP request")
parser.add_argument(
- "path", help="request path. We will add '/_matrix/federation/v1/' to this."
+ "path", help="request path, including the '/_matrix/federation/...' prefix."
)
args = parser.parse_args()
- if not args.server_name or not args.signing_key_path:
+ args.signing_key = None
+ if args.signing_key_path:
+ with open(args.signing_key_path) as f:
+ args.signing_key = f.readline()
+
+ if not args.server_name or not args.signing_key:
read_args_from_config(args)
- with open(args.signing_key_path) as f:
- key = read_signing_keys(f)[0]
+ algorithm, version, key_base64 = args.signing_key.split()
+ key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
result = request(
args.method,
args.server_name,
key,
args.destination,
- "/_matrix/federation/v1/" + args.path,
+ args.path,
content=args.body,
)
@@ -255,10 +212,16 @@ def main():
def read_args_from_config(args):
with open(args.config, "r") as fh:
config = yaml.safe_load(fh)
+
if not args.server_name:
args.server_name = config["server_name"]
- if not args.signing_key_path:
- args.signing_key_path = config["signing_key_path"]
+
+ if not args.signing_key:
+ if "signing_key" in config:
+ args.signing_key = config["signing_key"]
+ else:
+ with open(config["signing_key_path"]) as f:
+ args.signing_key = f.readline()
class MatrixConnectionAdapter(HTTPAdapter):
diff --git a/scripts-dev/generate_sample_config b/scripts-dev/generate_sample_config
index 9cb4630a5c..02739894b5 100755
--- a/scripts-dev/generate_sample_config
+++ b/scripts-dev/generate_sample_config
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Update/check the docs/sample_config.yaml
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index f328ab57d5..9761e97594 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Runs linting scripts over the local Synapse checkout
# isort - sorts import statements
@@ -80,7 +80,8 @@ else
# then lint everything!
if [[ -z ${files+x} ]]; then
# Lint all source code files and directories
- files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
+      # Note: this list aims to mirror the one in tox.ini
+ files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
fi
fi
diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh
index 60e8970a35..bc8f978660 100755
--- a/scripts-dev/make_full_schema.sh
+++ b/scripts-dev/make_full_schema.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# This script generates SQL files for creating a brand new Synapse DB with the latest
# schema, on both SQLite3 and Postgres.
@@ -162,12 +162,23 @@ else
fi
# Delete schema_version, applied_schema_deltas and applied_module_schemas tables
+# Also delete any shadow tables from fts4
# This needs to be done after synapse_port_db is run
echo "Dropping unwanted db tables..."
SQL="
DROP TABLE schema_version;
DROP TABLE applied_schema_deltas;
DROP TABLE applied_module_schemas;
+DROP TABLE event_search_content;
+DROP TABLE event_search_segments;
+DROP TABLE event_search_segdir;
+DROP TABLE event_search_docsize;
+DROP TABLE event_search_stat;
+DROP TABLE user_directory_search_content;
+DROP TABLE user_directory_search_segments;
+DROP TABLE user_directory_search_segdir;
+DROP TABLE user_directory_search_docsize;
+DROP TABLE user_directory_search_stat;
"
sqlite3 "$SQLITE_DB" <<< "$SQL"
psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL"
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index 5882f3a0b0..18df68305b 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
+ ) or fullname.startswith(
+ "synapse.util.caches.descriptors._LruCachedFunction.__call__"
):
return cached_function_method_signature
return None
@@ -85,7 +87,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
signature = signature.copy_modified(
- arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
+ arg_types=arg_types,
+ arg_names=arg_names,
+ arg_kinds=arg_kinds,
)
return signature
diff --git a/scripts-dev/next_github_number.sh b/scripts-dev/next_github_number.sh
index 376280025a..00e9b14569 100755
--- a/scripts-dev/next_github_number.sh
+++ b/scripts-dev/next_github_number.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
@@ -6,4 +6,4 @@ set -e
# next PR number.
CURRENT_NUMBER=`curl -s "https://api.github.com/repos/matrix-org/synapse/issues?state=all&per_page=1" | jq -r ".[0].number"`
CURRENT_NUMBER=$((CURRENT_NUMBER+1))
-echo $CURRENT_NUMBER
\ No newline at end of file
+echo $CURRENT_NUMBER
diff --git a/scripts/generate_log_config b/scripts/generate_log_config
index b6957f48a3..a13a5634a3 100755
--- a/scripts/generate_log_config
+++ b/scripts/generate_log_config
@@ -40,4 +40,6 @@ if __name__ == "__main__":
)
args = parser.parse_args()
- args.output_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file))
+ out = args.output_file
+ out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file))
+ out.flush()
diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py
index ab2e763386..8477955a90 100755
--- a/scripts/move_remote_media_to_new_store.py
+++ b/scripts/move_remote_media_to_new_store.py
@@ -51,7 +51,7 @@ def main(src_repo, dest_repo):
parts = line.split("|")
if len(parts) != 2:
print("Unable to parse input line %s" % line, file=sys.stderr)
- exit(1)
+ sys.exit(1)
move_media(parts[0], parts[1], src_paths, dest_paths)
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index c982ca9350..bc1a989fc1 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -22,7 +22,7 @@ import logging
import sys
import time
import traceback
-from typing import Dict, Optional, Set
+from typing import Dict, Iterable, Optional, Set
import yaml
@@ -48,6 +48,7 @@ from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
from synapse.storage.databases.main.profile import ProfileStore
+from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
@@ -71,7 +72,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier", "contains_url"],
- "rooms": ["is_public"],
+ "rooms": ["is_public", "has_auth_chain_index"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],
@@ -179,6 +180,7 @@ class Store(
UserDirectoryBackgroundUpdateStore,
EndToEndKeyBackgroundStore,
StatsStore,
+ PusherWorkerStore,
):
def execute(self, f, *args, **kwargs):
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
@@ -631,6 +633,13 @@ class Porter(object):
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()
await self._setup_events_stream_seqs()
+ await self._setup_sequence(
+ "device_inbox_sequence", ("device_inbox", "device_federation_outbox")
+ )
+ await self._setup_sequence(
+ "account_data_sequence", ("room_account_data", "room_tags_revisions", "account_data"))
+ await self._setup_sequence("receipts_sequence", ("receipts_linearized", ))
+ await self._setup_auth_chain_sequence()
# Step 3. Get tables.
self.progress.set_state("Fetching tables")
@@ -855,7 +864,7 @@ class Porter(object):
return done, remaining + done
- async def _setup_state_group_id_seq(self):
+ async def _setup_state_group_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
)
@@ -869,7 +878,7 @@ class Porter(object):
await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
- async def _setup_user_id_seq(self):
+ async def _setup_user_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.runInteraction(
"setup_user_id_seq", find_max_generated_user_id_localpart
)
@@ -878,9 +887,9 @@ class Porter(object):
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
- return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
+ await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
- async def _setup_events_stream_seqs(self):
+ async def _setup_events_stream_seqs(self) -> None:
"""Set the event stream sequences to the correct values.
"""
@@ -909,10 +918,47 @@ class Porter(object):
(curr_backward_id + 1,),
)
- return await self.postgres_store.db_pool.runInteraction(
+ await self.postgres_store.db_pool.runInteraction(
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
)
+ async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None:
+ """Set a sequence to the correct value.
+ """
+ current_stream_ids = []
+ for stream_id_table in stream_id_tables:
+ max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+ table=stream_id_table,
+ keyvalues={},
+ retcol="COALESCE(MAX(stream_id), 1)",
+ allow_none=True,
+ )
+ current_stream_ids.append(max_stream_id)
+
+ next_id = max(current_stream_ids) + 1
+
+ def r(txn):
+ sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name, )
+ txn.execute(sql + " %s", (next_id, ))
+
+ await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r)
+
+ async def _setup_auth_chain_sequence(self) -> None:
+ curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+ table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True
+ )
+
+ def r(txn):
+ txn.execute(
+ "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
+ (curr_chain_id,),
+ )
+
+ await self.postgres_store.db_pool.runInteraction(
+ "_setup_event_auth_chain_id", r,
+ )
+
+
##############################################
# The following is simply UI stuff
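The new `_setup_sequence` helper finds the largest `stream_id` across the given SQLite tables and restarts the matching Postgres sequence just past it. A rough standalone sketch of the same idea, using psycopg2 directly rather than Synapse's `db_pool`; the DSN and sample values are placeholders:

```python
# Standalone sketch of the sequence reset performed by _setup_sequence above.
from typing import Iterable

import psycopg2


def reset_sequence(conn, sequence_name: str, maxima: Iterable[int]) -> None:
    # Pick the next free stream ID across all tables that feed this sequence.
    next_id = max(maxima, default=0) + 1
    with conn.cursor() as cur:
        # The sequence name cannot be passed as a bind parameter, so it is
        # interpolated into the SQL, exactly as in the diff above.
        cur.execute(
            "ALTER SEQUENCE %s RESTART WITH %%s" % (sequence_name,), (next_id,)
        )


conn = psycopg2.connect("dbname=synapse user=synapse_user")
reset_sequence(conn, "receipts_sequence", [41230])
conn.commit()
```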
diff --git a/setup.cfg b/setup.cfg
index f46e43fad0..7329eed213 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,6 +3,7 @@ test_suite = tests
[check-manifest]
ignore =
+ .git-blame-ignore-revs
contrib
contrib/*
docs/*
@@ -17,7 +18,8 @@ ignore =
# E203: whitespace before ':' (which is contrary to pep8?)
# E731: do not assign a lambda expression, use a def
# E501: Line too long (black enforces this for us)
-ignore=W503,W504,E203,E731,E501
+# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes)
+ignore=W503,W504,E203,E731,E501,B006,B007,B008
[isort]
line_length = 88
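B006, B007 and B008 are flake8-bugbear checks; they are added to the ignore list because the codebase has not yet been cleaned up for them (hence the TODO). For reference, the kind of code each rule flags (illustrative snippets, not taken from Synapse):

```python
import time


def add_item(item, items=[]):  # B006: mutable default argument, shared between calls
    items.append(item)
    return items


def count_rooms(rooms):
    total = 0
    for room_id in rooms:  # B007: loop variable `room_id` never used in the body
        total += 1
    return total


def log_event(event, ts=time.time()):  # B008: call in a default, evaluated once at import
    print(ts, event)
```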
diff --git a/setup.py b/setup.py
index 9730afb41b..29e9971dc1 100755
--- a/setup.py
+++ b/setup.py
@@ -96,13 +96,14 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
#
# We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [
- "isort==5.0.3",
- "black==19.10b0",
+ "isort==5.7.0",
+ "black==20.8b1",
"flake8-comprehensions",
+ "flake8-bugbear==21.3.2",
"flake8",
]
-CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"]
# Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests.
@@ -121,6 +122,7 @@ setup(
include_package_data=True,
zip_safe=False,
long_description=long_description,
+ long_description_content_type="text/x-rst",
python_requires="~=3.5",
classifiers=[
"Development Status :: 5 - Production/Stable",
diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi
index 3f3af59f26..0368ba4703 100644
--- a/stubs/frozendict.pyi
+++ b/stubs/frozendict.pyi
@@ -15,16 +15,7 @@
# Stub for frozendict.
-from typing import (
- Any,
- Hashable,
- Iterable,
- Iterator,
- Mapping,
- overload,
- Tuple,
- TypeVar,
-)
+from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload
_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.
diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi
index 68779f968e..0eaef00498 100644
--- a/stubs/sortedcontainers/sorteddict.pyi
+++ b/stubs/sortedcontainers/sorteddict.pyi
@@ -7,17 +7,17 @@ from typing import (
Callable,
Dict,
Hashable,
- Iterator,
- Iterable,
ItemsView,
+ Iterable,
+ Iterator,
KeysView,
List,
Mapping,
Optional,
Sequence,
+ Tuple,
Type,
TypeVar,
- Tuple,
Union,
ValuesView,
overload,
@@ -89,12 +89,16 @@ class SortedDict(Dict[_KT, _VT]):
def __reduce__(
self,
) -> Tuple[
- Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
+ Type[SortedDict[_KT, _VT]],
+ Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
]: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
def islice(
- self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
+ self,
+ start: Optional[int] = ...,
+ stop: Optional[int] = ...,
+ reverse=bool,
) -> Iterator[_KT]: ...
def bisect_left(self, value: _KT) -> int: ...
def bisect_right(self, value: _KT) -> int: ...
diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi
index 8f6086b3ff..f80a3a72ce 100644
--- a/stubs/sortedcontainers/sortedlist.pyi
+++ b/stubs/sortedcontainers/sortedlist.pyi
@@ -31,7 +31,9 @@ class SortedList(MutableSequence[_T]):
DEFAULT_LOAD_FACTOR: int = ...
def __init__(
- self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ...,
+ self,
+ iterable: Optional[Iterable[_T]] = ...,
+ key: Optional[_Key[_T]] = ...,
): ...
# NB: currently mypy does not honour return type, see mypy #3307
@overload
@@ -76,10 +78,18 @@ class SortedList(MutableSequence[_T]):
def __len__(self) -> int: ...
def reverse(self) -> None: ...
def islice(
- self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
+ self,
+ start: Optional[int] = ...,
+ stop: Optional[int] = ...,
+ reverse=bool,
) -> Iterator[_T]: ...
def _islice(
- self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool,
+ self,
+ min_pos: int,
+ min_idx: int,
+ max_pos: int,
+ max_idx: int,
+ reverse: bool,
) -> Iterator[_T]: ...
def irange(
self,
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 522244bb57..080ca40287 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -15,13 +15,25 @@
"""Contains *incomplete* type hints for txredisapi.
"""
-from typing import List, Optional, Union, Type
+from typing import Any, List, Optional, Type, Union
+from twisted.internet import protocol
-class RedisProtocol:
+class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
+ async def ping(self) -> None: ...
+ async def set(
+ self,
+ key: str,
+ value: Any,
+ expire: Optional[int] = None,
+ pexpire: Optional[int] = None,
+ only_if_not_exists: bool = False,
+ only_if_exists: bool = False,
+ ) -> None: ...
+ async def get(self, key: str) -> Any: ...
-class SubscriberProtocol:
+class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
@@ -40,14 +52,13 @@ def lazyConnection(
convertNumbers: bool = ...,
) -> RedisProtocol: ...
-class SubscriberFactory:
- def buildProtocol(self, addr): ...
-
class ConnectionHandler: ...
-class RedisFactory:
+class RedisFactory(protocol.ReconnectingClientFactory):
continueTrying: bool
handler: RedisProtocol
+ pool: List[RedisProtocol]
+ replyTimeout: Optional[int]
def __init__(
self,
uuid: str,
@@ -60,3 +71,7 @@ class RedisFactory:
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
): ...
+ def buildProtocol(self, addr) -> RedisProtocol: ...
+
+class SubscriberFactory(RedisFactory):
+ def __init__(self): ...
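The stub now gives `RedisProtocol` awaitable `ping`/`set`/`get` methods and types `lazyConnection` as returning one. A hedged usage sketch based only on those signatures; the host, port, key and TTL values are made up, and the reactor plumbing is left as a comment:

```python
# Usage sketch based on the stub above; values are placeholders.
import txredisapi


async def cache_round_trip():
    redis = txredisapi.lazyConnection(host="localhost", port=6379)
    await redis.ping()
    await redis.set("synapse:example", b"value", expire=60)  # expire: seconds, per txredisapi
    return await redis.get("synapse:example")


# Run under a Twisted reactor, for example:
#   from twisted.internet import defer, task
#   task.react(lambda reactor: defer.ensureDeferred(cache_round_trip()))
```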
diff --git a/synapse/__init__.py b/synapse/__init__.py
index f2d3ac68eb..1d2883acf6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.24.0"
+__version__ = "1.31.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index bc73982dc2..26cb1bc657 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,7 @@ from twisted.web.server import Request
import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -31,11 +31,15 @@ from synapse.api.errors import (
MissingClientTokenError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.appservice import ApplicationService
from synapse.events import EventBase
+from synapse.http import get_request_user_agent
+from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -160,12 +164,12 @@ class Auth:
async def get_user_by_req(
self,
- request: Request,
+ request: SynapseRequest,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
) -> synapse.types.Requester:
- """ Get a registered user's ID.
+ """Get a registered user's ID.
Args:
request: An HTTP request with an access_token query parameter.
@@ -184,8 +188,8 @@ class Auth:
AuthError if access is denied for the user in the access token
"""
try:
- ip_addr = self.hs.get_ip_from_request(request)
- user_agent = request.get_user_agent("")
+ ip_addr = request.getClientIP()
+ user_agent = get_request_user_agent(request)
access_token = self.get_access_token_from_request(request)
@@ -274,7 +278,7 @@ class Auth:
return None, None
if app_service.ip_range_whitelist:
- ip_address = IPAddress(self.hs.get_ip_from_request(request))
+ ip_address = IPAddress(request.getClientIP())
if ip_address not in app_service.ip_range_whitelist:
return None, None
@@ -296,9 +300,12 @@ class Auth:
return user_id, app_service
async def get_user_by_access_token(
- self, token: str, rights: str = "access", allow_expired: bool = False,
+ self,
+ token: str,
+ rights: str = "access",
+ allow_expired: bool = False,
) -> TokenLookupResult:
- """ Validate access token and get user_id from it
+ """Validate access token and get user_id from it
Args:
token: The access token to get the user by
@@ -407,7 +414,7 @@ class Auth:
raise _InvalidMacaroonException()
try:
- user_id = self.get_user_id_from_macaroon(macaroon)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
guest = False
for caveat in macaroon.caveats:
@@ -415,7 +422,12 @@ class Auth:
guest = True
self.validate_macaroon(macaroon, rights, user_id=user_id)
- except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ except (
+ pymacaroons.exceptions.MacaroonException,
+ KeyError,
+ TypeError,
+ ValueError,
+ ):
raise InvalidClientTokenError("Invalid macaroon passed.")
if rights == "access":
@@ -423,27 +435,6 @@ class Auth:
return user_id, guest
- def get_user_id_from_macaroon(self, macaroon):
- """Retrieve the user_id given by the caveats on the macaroon.
-
- Does *not* validate the macaroon.
-
- Args:
- macaroon (pymacaroons.Macaroon): The macaroon to validate
-
- Returns:
- (str) user id
-
- Raises:
- InvalidClientCredentialsError if there is no user_id caveat in the
- macaroon
- """
- user_prefix = "user_id = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix) :]
- raise InvalidClientTokenError("No user caveat in macaroon")
-
def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -464,22 +455,14 @@ class Auth:
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self.clock.time_msec)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self._macaroon_secret_key)
- def _verify_expiry(self, caveat):
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self.hs.get_clock().time_msec()
- return now < expiry
-
- def get_appservice_by_req(self, request):
+ def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
@@ -491,7 +474,7 @@ class Auth:
return service
async def is_server_admin(self, user: UserID) -> bool:
- """ Check if the given user is a local server admin.
+ """Check if the given user is a local server admin.
Args:
user: user to check
@@ -502,7 +485,10 @@ class Auth:
return await self.store.is_server_admin(user)
def compute_auth_events(
- self, event, current_state_ids: StateMap[str], for_verification: bool = False,
+ self,
+ event,
+ current_state_ids: StateMap[str],
+ for_verification: bool = False,
) -> List[str]:
"""Given an event and current state return the list of event IDs used
to auth an event.
@@ -577,6 +563,9 @@ class Auth:
Returns:
bool: False if no access_token was given, True otherwise.
"""
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@@ -593,6 +582,8 @@ class Auth:
MissingClientTokenError: If there isn't a single access_token in the
request
"""
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
@@ -651,7 +642,8 @@ class Auth:
)
if (
visibility
- and visibility.content["history_visibility"] == "world_readable"
+ and visibility.content.get("history_visibility")
+ == HistoryVisibility.WORLD_READABLE
):
return Membership.JOIN, None
raise AuthError(
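The inline macaroon helpers (`get_user_id_from_macaroon`, `_verify_expiry`) are replaced by `synapse.util.macaroons.get_value_from_macaroon` and `satisfy_expiry`. A small hedged sketch of the replacement for the user-id lookup; the macaroon contents are placeholders, and the fact that a missing caveat surfaces as `KeyError` is inferred from the `except` tuple added above:

```python
# Sketch only: builds a throwaway macaroon and reads the user_id caveat back
# with the new helper. Location/identifier/secret are placeholders.
import pymacaroons

from synapse.util.macaroons import get_value_from_macaroon

macaroon = pymacaroons.Macaroon(
    location="example.com", identifier="key", key="macaroon-secret"
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = @alice:example.com")

user_id = get_value_from_macaroon(macaroon, "user_id")
assert user_id == "@alice:example.com"
```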
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 9c227218e0..d8088f524a 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -36,6 +36,7 @@ class AuthBlocking:
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname
+ self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
async def check_auth_blocking(
self,
@@ -76,6 +77,12 @@ class AuthBlocking:
# We never block the server from doing actions on behalf of
# users.
return
+ elif requester.app_service and not self._track_appservice_user_ips:
+ # If we're authenticated as an appservice then we only block
+ # auth if `track_appservice_user_ips` is set, as that option
+ # implicitly means that application services are part of MAU
+ # limits.
+ return
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 592abd844b..8f37d2cf3b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -27,6 +27,11 @@ MAX_ALIAS_LENGTH = 255
# the maximum length for a user id is 255 characters
MAX_USERID_LENGTH = 255
+# The maximum length for a group id is 255 characters
+MAX_GROUPID_LENGTH = 255
+MAX_GROUP_CATEGORYID_LENGTH = 255
+MAX_GROUP_ROLEID_LENGTH = 255
+
class Membership:
@@ -46,6 +51,7 @@ class PresenceState:
OFFLINE = "offline"
UNAVAILABLE = "unavailable"
ONLINE = "online"
+ BUSY = "org.matrix.msc3026.busy"
class JoinRules:
@@ -93,7 +99,15 @@ class EventTypes:
Retention = "m.room.retention"
+ Dummy = "org.matrix.dummy_event"
+
+ MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child"
+ MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent"
+
+
+class EduTypes:
Presence = "m.presence"
+ RoomKeyRequest = "m.room_key_request"
class RejectedReason:
@@ -126,8 +140,7 @@ class UserTypes:
class RelationTypes:
- """The types of relations known to this server.
- """
+ """The types of relations known to this server."""
ANNOTATION = "m.annotation"
REPLACE = "m.replace"
@@ -151,6 +164,9 @@ class EventContentFields:
# cf https://github.com/matrix-org/matrix-doc/pull/2228
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
+ # cf https://github.com/matrix-org/matrix-doc/pull/1772
+ MSC1772_ROOM_TYPE = "org.matrix.msc1772.type"
+
class RoomEncryptionAlgorithms:
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
@@ -160,3 +176,10 @@ class RoomEncryptionAlgorithms:
class AccountDataTypes:
DIRECT = "m.direct"
IGNORED_USER_LIST = "m.ignored_user_list"
+
+
+class HistoryVisibility:
+ INVITED = "invited"
+ JOINED = "joined"
+ SHARED = "shared"
+ WORLD_READABLE = "world_readable"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 90bb01f414..a71e518f90 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -391,8 +391,7 @@ class InvalidCaptchaError(SynapseError):
class LimitExceededError(SynapseError):
- """A client has sent too many requests and is being throttled.
- """
+ """A client has sent too many requests and is being throttled."""
def __init__(
self,
@@ -409,8 +408,7 @@ class LimitExceededError(SynapseError):
class RoomKeysVersionError(SynapseError):
- """A client has tried to upload to a non-current version of the room_keys store
- """
+ """A client has tried to upload to a non-current version of the room_keys store"""
def __init__(self, current_version: str):
"""
@@ -427,7 +425,9 @@ class UnsupportedRoomVersionError(SynapseError):
def __init__(self, msg: str = "Homeserver does not support this room version"):
super().__init__(
- code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
+ code=400,
+ msg=msg,
+ errcode=Codes.UNSUPPORTED_ROOM_VERSION,
)
@@ -462,8 +462,7 @@ class IncompatibleRoomVersionError(SynapseError):
class PasswordRefusedError(SynapseError):
- """A password has been refused, either during password reset/change or registration.
- """
+ """A password has been refused, either during password reset/change or registration."""
def __init__(
self,
@@ -471,7 +470,9 @@ class PasswordRefusedError(SynapseError):
errcode: str = Codes.WEAK_PASSWORD,
):
super().__init__(
- code=400, msg=msg, errcode=errcode,
+ code=400,
+ msg=msg,
+ errcode=errcode,
)
@@ -494,7 +495,7 @@ class RequestSendFailed(RuntimeError):
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
- """ Utility method for constructing an error response for client-server
+ """Utility method for constructing an error response for client-server
interactions.
Args:
@@ -511,7 +512,7 @@ def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
class FederationError(RuntimeError):
- """ This class is used to inform remote homeservers about erroneous
+ """This class is used to inform remote homeservers about erroneous
PDUs they sent us.
FATAL: The remote server could not interpret the source event.
diff --git a/synapse/api/presence.py b/synapse/api/presence.py
index 18a462f0ee..b9a8e29460 100644
--- a/synapse/api/presence.py
+++ b/synapse/api/presence.py
@@ -56,8 +56,7 @@ class UserPresenceState(
@classmethod
def default(cls, user_id):
- """Returns a default presence state.
- """
+ """Returns a default presence state."""
return cls(
user_id=user_id,
state=PresenceState.OFFLINE,
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 5d9d5a228f..c3f07bc1a3 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -14,7 +14,7 @@
# limitations under the License.
from collections import OrderedDict
-from typing import Any, Optional, Tuple
+from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.types import Requester
@@ -42,7 +42,9 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
- self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
+ self.actions = (
+ OrderedDict()
+ ) # type: OrderedDict[Hashable, Tuple[float, int, float]]
def can_requester_do_action(
self,
@@ -82,7 +84,7 @@ class Ratelimiter:
def can_do_action(
self,
- key: Any,
+ key: Hashable,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
@@ -175,7 +177,7 @@ class Ratelimiter:
def ratelimit(
self,
- key: Any,
+ key: Hashable,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 1cdb153f63..139fbf5524 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -51,11 +51,11 @@ class RoomDisposition:
class RoomVersion:
"""An object which describes the unique attributes of a room version."""
- identifier = attr.ib() # str; the identifier for this version
- disposition = attr.ib() # str; one of the RoomDispositions
- event_format = attr.ib() # int; one of the EventFormatVersions
- state_res = attr.ib() # int; one of the StateResolutionVersions
- enforce_key_validity = attr.ib() # bool
+ identifier = attr.ib(type=str) # the identifier for this version
+ disposition = attr.ib(type=str) # one of the RoomDispositions
+ event_format = attr.ib(type=int) # one of the EventFormatVersions
+ state_res = attr.ib(type=int) # one of the StateResolutionVersions
+ enforce_key_validity = attr.ib(type=bool)
# Before MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool)
@@ -64,9 +64,11 @@ class RoomVersion:
# * Floats
# * NaN, Infinity, -Infinity
strict_canonicaljson = attr.ib(type=bool)
- # bool: MSC2209: Check 'notifications' key while verifying
+ # MSC2209: Check 'notifications' key while verifying
# m.room.power_levels auth rules.
limit_notifications_power_levels = attr.ib(type=bool)
+ # MSC2174/MSC2176: Apply updated redaction rules algorithm.
+ msc2176_redaction_rules = attr.ib(type=bool)
# MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
# m.room.membership event with membership 'knock'.
allow_knocking = attr.ib(type=bool)
@@ -183,5 +185,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V5,
RoomVersions.V6,
RoomVersions.V7,
+ RoomVersions.MSC2176,
)
} # type: Dict[str, RoomVersion]
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index a01bac2997..d1a2cd5e19 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -17,14 +17,14 @@ import sys
from synapse import python_dependencies # noqa: E402
-sys.dont_write_bytecode = True
-
logger = logging.getLogger(__name__)
try:
python_dependencies.check_requirements()
except python_dependencies.DependencyException as e:
- sys.stderr.writelines(e.message)
+ sys.stderr.writelines(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
sys.exit(1)
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 895b38ae76..3912c8994c 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,12 +16,15 @@
import gc
import logging
import os
+import platform
import signal
import socket
import sys
import traceback
-from typing import Iterable
+import warnings
+from typing import Awaitable, Callable, Iterable
+from cryptography.utils import CryptographyDeprecationWarning
from typing_extensions import NoReturn
from twisted.internet import defer, error, reactor
@@ -56,7 +60,7 @@ def register_sighup(func, *args, **kwargs):
def start_worker_reactor(appname, config, run_command=reactor.run):
- """ Run the reactor in the main process
+ """Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
the reactor. Pulls configuration from the 'worker' settings in 'config'.
@@ -91,7 +95,7 @@ def start_reactor(
logger,
run_command=reactor.run,
):
- """ Run the reactor in the main process
+ """Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
the reactor
@@ -143,6 +147,45 @@ def quit_with_error(error_string: str) -> NoReturn:
sys.exit(1)
+def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
+ """Register a callback with the reactor, to be called once it is running
+
+ This can be used to initialise parts of the system which require an asynchronous
+ setup.
+
+ Any exception raised by the callback will be printed and logged, and the process
+ will exit.
+ """
+
+ async def wrapper():
+ try:
+ await cb(*args, **kwargs)
+ except Exception:
+ # previously, we used Failure().printTraceback() here, in the hope that
+ # would give better tracebacks than traceback.print_exc(). However, that
+ # doesn't handle chained exceptions (with a __cause__ or __context__) well,
+ # and I *think* the need for Failure() is reduced now that we mostly use
+ # async/await.
+
+ # Write the exception to both the logs *and* the unredirected stderr,
+ # because people tend to get confused if it only goes to one or the other.
+ #
+ # One problem with this is that if people are using a logging config that
+ # logs to the console (as is common eg under docker), they will get two
+ # copies of the exception. We could maybe try to detect that, but it's
+ # probably a cost we can bear.
+ logger.fatal("Error during startup", exc_info=True)
+ print("Error during startup:", file=sys.__stderr__)
+ traceback.print_exc(file=sys.__stderr__)
+
+ # it's no use calling sys.exit here, since that just raises a SystemExit
+ # exception which is then caught by the reactor, and everything carries
+ # on as normal.
+ os._exit(1)
+
+ reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
+
+
def listen_metrics(bind_addresses, port):
"""
Start Prometheus metrics server.
@@ -154,6 +197,25 @@ def listen_metrics(bind_addresses, port):
start_http_server(port, addr=host, registry=RegistryProxy)
+def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: dict):
+ # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
+ # warning. It's fixed by https://github.com/twisted/twisted/pull/1522, so
+ # suppress the warning for now.
+ warnings.filterwarnings(
+ action="ignore",
+ category=CryptographyDeprecationWarning,
+ message="int_from_bytes is deprecated",
+ )
+
+ from synapse.util.manhole import manhole
+
+ listen_tcp(
+ bind_addresses,
+ port,
+ manhole(username="matrix", password="rabbithole", globals=manhole_globals),
+ )
+
+
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
"""
Create a TCP socket for a port and several addresses
@@ -227,7 +289,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
-def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
"""
Start a Synapse server or worker.
@@ -241,71 +303,65 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
hs: homeserver instance
listeners: Listener configuration ('listeners' in homeserver.yaml)
"""
- try:
- # Set up the SIGHUP machinery.
- if hasattr(signal, "SIGHUP"):
+ # Set up the SIGHUP machinery.
+ if hasattr(signal, "SIGHUP"):
+ reactor = hs.get_reactor()
- @wrap_as_background_process("sighup")
- def handle_sighup(*args, **kwargs):
- # Tell systemd our state, if we're using it. This will silently fail if
- # we're not using systemd.
- sdnotify(b"RELOADING=1")
+ @wrap_as_background_process("sighup")
+ def handle_sighup(*args, **kwargs):
+ # Tell systemd our state, if we're using it. This will silently fail if
+ # we're not using systemd.
+ sdnotify(b"RELOADING=1")
- for i, args, kwargs in _sighup_callbacks:
- i(*args, **kwargs)
+ for i, args, kwargs in _sighup_callbacks:
+ i(*args, **kwargs)
- sdnotify(b"READY=1")
+ sdnotify(b"READY=1")
- # We defer running the sighup handlers until next reactor tick. This
- # is so that we're in a sane state, e.g. flushing the logs may fail
- # if the sighup happens in the middle of writing a log entry.
- def run_sighup(*args, **kwargs):
- hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
+ # We defer running the sighup handlers until next reactor tick. This
+ # is so that we're in a sane state, e.g. flushing the logs may fail
+ # if the sighup happens in the middle of writing a log entry.
+ def run_sighup(*args, **kwargs):
+ # `callFromThread` should be "signal safe" as well as thread
+ # safe.
+ reactor.callFromThread(handle_sighup, *args, **kwargs)
- signal.signal(signal.SIGHUP, run_sighup)
+ signal.signal(signal.SIGHUP, run_sighup)
- register_sighup(refresh_certificate, hs)
+ register_sighup(refresh_certificate, hs)
- # Load the certificate from disk.
- refresh_certificate(hs)
+ # Load the certificate from disk.
+ refresh_certificate(hs)
- # Start the tracer
- synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
- hs
- )
+ # Start the tracer
+ synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
- # It is now safe to start your Synapse.
- hs.start_listening(listeners)
- hs.get_datastore().db_pool.start_profiling()
- hs.get_pusherpool().start()
+ # It is now safe to start your Synapse.
+ hs.start_listening(listeners)
+ hs.get_datastore().db_pool.start_profiling()
+ hs.get_pusherpool().start()
- # Log when we start the shut down process.
- hs.get_reactor().addSystemEventTrigger(
- "before", "shutdown", logger.info, "Shutting down..."
- )
+ # Log when we start the shut down process.
+ hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", logger.info, "Shutting down..."
+ )
- setup_sentry(hs)
- setup_sdnotify(hs)
-
- # If background tasks are running on the main process, start collecting the
- # phone home stats.
- if hs.config.run_background_tasks:
- start_phone_stats_home(hs)
-
- # We now freeze all allocated objects in the hopes that (almost)
- # everything currently allocated are things that will be used for the
- # rest of time. Doing so means less work each GC (hopefully).
- #
- # This only works on Python 3.7
- if sys.version_info >= (3, 7):
- gc.collect()
- gc.freeze()
- except Exception:
- traceback.print_exc(file=sys.stderr)
- reactor = hs.get_reactor()
- if reactor.running:
- reactor.stop()
- sys.exit(1)
+ setup_sentry(hs)
+ setup_sdnotify(hs)
+
+ # If background tasks are running on the main process, start collecting the
+ # phone home stats.
+ if hs.config.run_background_tasks:
+ start_phone_stats_home(hs)
+
+ # We now freeze all allocated objects in the hopes that (almost)
+ # everything currently allocated are things that will be used for the
+ # rest of time. Doing so means less work each GC (hopefully).
+ #
+ # This only works on Python 3.7
+ if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
+ gc.collect()
+ gc.freeze()
def setup_sentry(hs):
@@ -333,8 +389,7 @@ def setup_sentry(hs):
def setup_sdnotify(hs):
- """Adds process state hooks to tell systemd what we are up to.
- """
+ """Adds process state hooks to tell systemd what we are up to."""
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
@@ -368,8 +423,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
class _LimitedHostnameResolver:
- """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
- """
+ """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups."""
def __init__(self, resolver, max_dns_requests_in_flight):
self._resolver = resolver
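`register_start` replaces the old `reactor.addSystemEventTrigger("before", "startup", ...)` pattern used elsewhere in this patch: the callback runs once the reactor is up, and any exception is logged and printed to the unredirected stderr before the process exits. A minimal sketch mirroring how the worker and homeserver hunks later in this patch call it:

```python
# Mirrors the call sites changed later in this patch (generic_worker.py / homeserver.py).
from synapse.app import _base
from synapse.app._base import register_start


def start_worker(hs, listeners) -> None:
    # _base.start is now a coroutine; register_start schedules it once the reactor
    # is running and turns any start-up exception into a logged, fatal exit.
    register_start(_base.start, hs, listeners)
```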
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index b4bd4d8e7a..9f99651aa2 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -210,7 +210,9 @@ def start(config_options):
config.update_user_directory = False
config.run_background_tasks = False
config.start_pushers = False
+ config.pusher_shard_config.instances = []
config.send_federation = False
+ config.federation_shard_config.instances = []
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 82ec325a13..b2d21acefd 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,9 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
-from twisted.internet import address, reactor
+from twisted.internet import address
+from twisted.web.resource import IResource
+from twisted.web.server import Request
import synapse
import synapse.events
@@ -34,6 +36,7 @@ from synapse.api.urls import (
SERVER_KEY_V2_PREFIX,
)
from synapse.app import _base
+from synapse.app._base import register_start
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
@@ -89,45 +92,47 @@ from synapse.replication.tcp.streams import (
ToDeviceStream,
)
from synapse.rest.admin import register_servlets_for_media_repo
-from synapse.rest.client.v1 import events
+from synapse.rest.client.v1 import events, login, room
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
-from synapse.rest.client.v1.login import LoginRestServlet
from synapse.rest.client.v1.profile import (
ProfileAvatarURLRestServlet,
ProfileDisplaynameRestServlet,
ProfileRestServlet,
)
from synapse.rest.client.v1.push_rule import PushRuleRestServlet
-from synapse.rest.client.v1.room import (
- JoinedRoomMemberListRestServlet,
- JoinRoomAliasServlet,
- PublicRoomListRestServlet,
- RoomEventContextServlet,
- RoomInitialSyncRestServlet,
- RoomMemberListRestServlet,
- RoomMembershipRestServlet,
- RoomMessageListRestServlet,
- RoomSendEventRestServlet,
- RoomStateEventRestServlet,
- RoomStateRestServlet,
- RoomTypingRestServlet,
-)
from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, sync, user_directory
+from synapse.rest.client.v2_alpha import (
+ account_data,
+ groups,
+ read_marker,
+ receipts,
+ room_keys,
+ sync,
+ tags,
+ user_directory,
+)
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.account_data import (
AccountDataServlet,
RoomAccountDataServlet,
)
-from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
+from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
+from synapse.rest.client.v2_alpha.keys import (
+ KeyChangesServlet,
+ KeyQueryServlet,
+ OneTimeKeyServlet,
+)
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
from synapse.storage.databases.main.metrics import ServerMetricsStore
from synapse.storage.databases.main.monthly_active_users import (
@@ -142,7 +147,6 @@ from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.generic_worker")
@@ -186,7 +190,7 @@ class KeyUploadServlet(RestServlet):
self.http_client = hs.get_simple_http_client()
self.main_uri = hs.config.worker_main_http_uri
- async def on_POST(self, request, device_id):
+ async def on_POST(self, request: Request, device_id: Optional[str]):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -219,10 +223,12 @@ class KeyUploadServlet(RestServlet):
header: request.requestHeaders.getRawHeaders(header, [])
for header in (b"Authorization", b"User-Agent")
}
- # Add the previous hop the the X-Forwarded-For header.
+ # Add the previous hop to the X-Forwarded-For header.
x_forwarded_for = request.requestHeaders.getRawHeaders(
b"X-Forwarded-For", []
)
+ # we use request.client here, since we want the previous hop, not the
+ # original client (as returned by request.getClientAddress()).
if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
previous_host = request.client.host.encode("ascii")
# If the header exists, add to the comma-separated list of the first
@@ -235,6 +241,14 @@ class KeyUploadServlet(RestServlet):
x_forwarded_for = [previous_host]
headers[b"X-Forwarded-For"] = x_forwarded_for
+ # Replicate the original X-Forwarded-Proto header. Note that
+ # XForwardedForRequest overrides isSecure() to give us the original protocol
+ # used by the client, as opposed to the protocol used by our upstream proxy
+ # - which is what we want here.
+ headers[b"X-Forwarded-Proto"] = [
+ b"https" if request.isSecure() else b"http"
+ ]
+
try:
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers
@@ -287,6 +301,8 @@ class GenericWorkerPresence(BasePresenceHandler):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
+ self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
+
hs.get_reactor().addSystemEventTrigger(
"before",
"shutdown",
@@ -417,16 +433,19 @@ class GenericWorkerPresence(BasePresenceHandler):
]
async def set_state(self, target_user, state, ignore_status_msg=False):
- """Set the presence state of the user.
- """
+ """Set the presence state of the user."""
presence = state["presence"]
valid_presence = (
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
+ PresenceState.BUSY,
)
- if presence not in valid_presence:
+
+ if presence not in valid_presence or (
+ presence == PresenceState.BUSY and not self._busy_presence_enabled
+ ):
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
@@ -459,6 +478,7 @@ class GenericWorkerSlavedStore(
UserDirectoryStore,
StatsStore,
UIAuthWorkerStore,
+ EndToEndRoomKeyStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
@@ -503,7 +523,7 @@ class GenericWorkerServer(HomeServer):
site_tag = port
# We always include a health resource.
- resources = {"/health": HealthResource()}
+ resources = {"/health": HealthResource()} # type: Dict[str, IResource]
for res in listener_config.http_options.resources:
for name in res.names:
@@ -512,36 +532,36 @@ class GenericWorkerServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
- PublicRoomListRestServlet(self).register(resource)
- RoomMemberListRestServlet(self).register(resource)
- JoinedRoomMemberListRestServlet(self).register(resource)
- RoomStateRestServlet(self).register(resource)
- RoomEventContextServlet(self).register(resource)
- RoomMessageListRestServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
- LoginRestServlet(self).register(resource)
+ login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
+ DevicesRestServlet(self).register(resource)
KeyQueryServlet(self).register(resource)
+ OneTimeKeyServlet(self).register(resource)
KeyChangesServlet(self).register(resource)
VoipRestServlet(self).register(resource)
PushRuleRestServlet(self).register(resource)
VersionsRestServlet(self).register(resource)
- RoomSendEventRestServlet(self).register(resource)
- RoomMembershipRestServlet(self).register(resource)
- RoomStateEventRestServlet(self).register(resource)
- JoinRoomAliasServlet(self).register(resource)
+
ProfileAvatarURLRestServlet(self).register(resource)
ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource)
KeyUploadServlet(self).register(resource)
AccountDataServlet(self).register(resource)
RoomAccountDataServlet(self).register(resource)
- RoomTypingRestServlet(self).register(resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
+ room.register_servlets(self, resource, True)
+ room.register_deprecated_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
- RoomInitialSyncRestServlet(self).register(resource)
+ room_keys.register_servlets(self, resource)
+ tags.register_servlets(self, resource)
+ account_data.register_servlets(self, resource)
+ receipts.register_servlets(self, resource)
+ read_marker.register_servlets(self, resource)
+
+ SendToDeviceRestServlet(self).register(resource)
user_directory.register_servlets(self, resource)
@@ -553,6 +573,8 @@ class GenericWorkerServer(HomeServer):
groups.register_servlets(self, resource)
resources.update({CLIENT_API_PREFIX: resource})
+
+ resources.update(build_synapse_client_resource_tree(self))
elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
elif name == "media":
@@ -617,12 +639,8 @@ class GenericWorkerServer(HomeServer):
if listener.type == "http":
self._listen_http(listener)
elif listener.type == "manhole":
- _base.listen_tcp(
- listener.bind_addresses,
- listener.port,
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
+ _base.listen_manhole(
+ listener.bind_addresses, listener.port, manhole_globals={"hs": self}
)
elif listener.type == "metrics":
if not self.get_config().enable_metrics:
@@ -639,9 +657,6 @@ class GenericWorkerServer(HomeServer):
self.get_tcp_replication().start_replication(self)
- async def remove_pusher(self, app_id, push_key, user_id):
- self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
-
@cache_in_self
def get_replication_data_handler(self):
return GenericWorkerReplicationHandler(self)
@@ -772,13 +787,6 @@ class FederationSenderHandler:
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
- def on_start(self):
- # There may be some events that are persisted but haven't been sent,
- # so send them now.
- self.federation_sender.notify_new_events(
- self.store.get_room_max_stream_ordering()
- )
-
def wake_destination(self, server: str):
self.federation_sender.wake_destination(server)
@@ -916,22 +924,6 @@ def start(config_options):
# For other worker types we force this to off.
config.appservice.notify_appservices = False
- if config.worker_app == "synapse.app.pusher":
- if config.server.start_pushers:
- sys.stderr.write(
- "\nThe pushers must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``start_pushers: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.server.start_pushers = True
- else:
- # For other worker types we force this to off.
- config.server.start_pushers = False
-
if config.worker_app == "synapse.app.user_dir":
if config.server.update_user_directory:
sys.stderr.write(
@@ -948,22 +940,6 @@ def start(config_options):
# For other worker types we force this to off.
config.server.update_user_directory = False
- if config.worker_app == "synapse.app.federation_sender":
- if config.worker.send_federation:
- sys.stderr.write(
- "\nThe send_federation must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``send_federation: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.worker.send_federation = True
- else:
- # For other worker types we force this to off.
- config.worker.send_federation = False
-
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
hs = GenericWorkerServer(
@@ -980,9 +956,7 @@ def start(config_options):
# streams. Will no-op if no streams can be written to by this worker.
hs.get_replication_streamer()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, hs, config.worker_listeners
- )
+ register_start(_base.start, hs, config.worker_listeners)
_base.start_worker_reactor("synapse-generic-worker", config)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2b5465417f..3bfe9d507f 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -15,15 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
import logging
import os
import sys
-from typing import Iterable
+from typing import Iterable, Iterator
-from twisted.application import service
-from twisted.internet import defer, reactor
-from twisted.python.failure import Failure
+from twisted.internet import reactor
from twisted.web.resource import EncodingResourceWrapper, IResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@@ -40,7 +37,7 @@ from synapse.api.urls import (
WEB_CLIENT_PREFIX,
)
from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
+from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
from synapse.config._base import ConfigError
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
@@ -63,15 +60,14 @@ from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.homeserver")
@@ -90,7 +86,7 @@ class SynapseHomeServer(HomeServer):
tls = listener_config.tls
site_tag = listener_config.http_options.tag
if site_tag is None:
- site_tag = port
+ site_tag = str(port)
# We always include a health resource.
resources = {"/health": HealthResource()}
@@ -107,7 +103,10 @@ class SynapseHomeServer(HomeServer):
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = self.get_module_api()
for path, resmodule in additional_resources.items():
- handler_cls, config = load_module(resmodule)
+ handler_cls, config = load_module(
+ resmodule,
+ ("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
+ )
handler = handler_cls(config, module_api)
if IResource.providedBy(handler):
resource = handler
@@ -189,19 +188,10 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self),
+ **build_synapse_client_resource_tree(self),
}
)
- if self.get_config().oidc_enabled:
- from synapse.rest.oidc import OIDCResource
-
- resources["/_synapse/oidc"] = OIDCResource(self)
-
- if self.get_config().saml2_enabled:
- from synapse.rest.saml2 import SAML2Resource
-
- resources["/_matrix/saml2"] = SAML2Resource(self)
-
if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL:
from synapse.rest.synapse.client.password_reset import (
PasswordResetSubmitTokenResource,
@@ -297,12 +287,8 @@ class SynapseHomeServer(HomeServer):
if listener.type == "http":
self._listening_services.extend(self._listener_http(config, listener))
elif listener.type == "manhole":
- listen_tcp(
- listener.bind_addresses,
- listener.port,
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
+ _base.listen_manhole(
+ listener.bind_addresses, listener.port, manhole_globals={"hs": self}
)
elif listener.type == "replication":
services = listen_tcp(
@@ -342,7 +328,10 @@ def setup(config_options):
"Synapse Homeserver", config_options
)
except ConfigError as e:
- sys.stderr.write("\nERROR: %s\n" % (e,))
+ sys.stderr.write("\n")
+ for f in format_config_error(e):
+ sys.stderr.write(f)
+ sys.stderr.write("\n")
sys.exit(1)
if not config:
@@ -407,61 +396,62 @@ def setup(config_options):
_base.refresh_certificate(hs)
async def start():
- try:
- # Run the ACME provisioning code, if it's enabled.
- if hs.config.acme_enabled:
- acme = hs.get_acme_handler()
- # Start up the webservices which we will respond to ACME
- # challenges with, and then provision.
- await acme.start_listening()
- await do_acme()
+ # Run the ACME provisioning code, if it's enabled.
+ if hs.config.acme_enabled:
+ acme = hs.get_acme_handler()
+ # Start up the webservices which we will respond to ACME
+ # challenges with, and then provision.
+ await acme.start_listening()
+ await do_acme()
- # Check if it needs to be reprovisioned every day.
- hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
+ # Check if it needs to be reprovisioned every day.
+ hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
- # Load the OIDC provider metadatas, if OIDC is enabled.
- if hs.config.oidc_enabled:
- oidc = hs.get_oidc_handler()
- # Loading the provider metadata also ensures the provider config is valid.
- await oidc.load_metadata()
- await oidc.load_jwks()
+ # Load the OIDC provider metadatas, if OIDC is enabled.
+ if hs.config.oidc_enabled:
+ oidc = hs.get_oidc_handler()
+ # Loading the provider metadata also ensures the provider config is valid.
+ await oidc.load_metadata()
- _base.start(hs, config.listeners)
+ await _base.start(hs, config.listeners)
- hs.get_datastore().db_pool.updates.start_doing_background_updates()
- except Exception:
- # Print the exception and bail out.
- print("Error during startup:", file=sys.stderr)
+ hs.get_datastore().db_pool.updates.start_doing_background_updates()
- # this gives better tracebacks than traceback.print_exc()
- Failure().printTraceback(file=sys.stderr)
-
- if reactor.running:
- reactor.stop()
- sys.exit(1)
-
- reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
+ register_start(start)
return hs
-class SynapseService(service.Service):
+def format_config_error(e: ConfigError) -> Iterator[str]:
"""
- A twisted Service class that will start synapse. Used to run synapse
- via twistd and a .tac.
+ Formats a config error neatly
+
+ The idea is to format the immediate error, plus the "causes" of those errors,
+ hopefully in a way that makes sense to the user. For example:
+
+ Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
+ Failed to parse config for module 'JinjaOidcMappingProvider':
+ invalid jinja template:
+ unexpected end of template, expected 'end of print statement'.
+
+ Args:
+ e: the error to be formatted
+
+ Returns: An iterator which yields string fragments to be formatted
"""
+ yield "Error in configuration"
- def __init__(self, config):
- self.config = config
+ if e.path:
+ yield " at '%s'" % (".".join(e.path),)
- def startService(self):
- hs = setup(self.config)
- change_resource_limit(hs.config.soft_file_limit)
- if hs.config.gc_thresholds:
- gc.set_threshold(*hs.config.gc_thresholds)
+ yield ":\n %s" % (e.msg,)
- def stopService(self):
- return self._port.stopListening()
+ e = e.__cause__
+ indent = 1
+ while e:
+ indent += 1
+ yield ":\n%s%s" % (" " * indent, str(e))
+ e = e.__cause__
def run(hs):
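`format_config_error` walks the exception's `__cause__` chain and yields indented fragments. A sketch of how a nested error renders, using the `ConfigError(msg, path)` constructor added to `synapse/config/_base.py` later in this patch; the module and path names mirror the docstring's own example:

```python
import sys

from synapse.app.homeserver import format_config_error
from synapse.config._base import ConfigError

try:
    try:
        raise ValueError("unexpected end of template, expected 'end of print statement'.")
    except ValueError as cause:
        raise ConfigError(
            "Failed to parse config for module 'JinjaOidcMappingProvider'",
            path=(
                "oidc_config",
                "user_mapping_provider",
                "config",
                "display_name_template",
            ),
        ) from cause
except ConfigError as e:
    # Same rendering loop as the new setup() error path above.
    sys.stderr.write("\n")
    for fragment in format_config_error(e):
        sys.stderr.write(fragment)
    sys.stderr.write("\n")
```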
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index c38cf8231f..8f86cecb76 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
+ stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
+ stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
+ daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
+ stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
r30_results = await hs.get_datastore().count_r30_users()
for name, count in r30_results.items():
stats["r30_users_" + name] = count
- daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 3944780a42..0bfc5e445f 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -166,7 +166,10 @@ class ApplicationService:
@cached(num_args=1, cache_context=True)
async def matches_user_in_member_list(
- self, room_id: str, store: "DataStore", cache_context: _CacheContext,
+ self,
+ room_id: str,
+ store: "DataStore",
+ cache_context: _CacheContext,
) -> bool:
"""Check if this service is interested a room based upon it's membership
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 5dcd9ea290..8c5d178e05 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -76,9 +76,6 @@ def _is_valid_3pe_result(r, field):
fields = r["fields"]
if not isinstance(fields, dict):
return False
- for k in fields.keys():
- if not isinstance(fields[k], str):
- return False
return True
@@ -93,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(
- hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
+ hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id):
@@ -230,7 +227,9 @@ class ApplicationServiceApi(SimpleHttpClient):
try:
await self.put_json(
- uri=uri, json_body=body, args={"access_token": service.hs_token},
+ uri=uri,
+ json_body=body,
+ args={"access_token": service.hs_token},
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 58291afc22..366c476f80 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -68,7 +68,7 @@ MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
class ApplicationServiceScheduler:
- """ Public facing API for this module. Does the required DI to tie the
+ """Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
case is a simple array.
"""
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 33a917a035..ba9cd63cf2 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,22 +18,31 @@
import argparse
import errno
import os
-import time
-import urllib.parse
from collections import OrderedDict
from hashlib import sha256
-from io import open as io_open
from textwrap import dedent
-from typing import Any, Callable, List, MutableMapping, Optional
+from typing import Any, Iterable, List, MutableMapping, Optional, Union
import attr
import jinja2
import pkg_resources
import yaml
+from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter
+
class ConfigError(Exception):
- pass
+ """Represents a problem parsing the configuration
+
+ Args:
+ msg: A textual description of the error.
+ path: Where appropriate, an indication of where in the configuration
+ the problem lies.
+ """
+
+ def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+ self.msg = msg
+ self.path = path
# We split these messages out to allow packages to override with package
@@ -138,7 +147,20 @@ class Config:
return int(value) * size
@staticmethod
- def parse_duration(value):
+ def parse_duration(value: Union[str, int]) -> int:
+ """Convert a duration as a string or integer to a number of milliseconds.
+
+ If an integer is provided it is treated as milliseconds and is unchanged.
+
+ String durations can have a suffix of 's', 'm', 'h', 'd', 'w', or 'y'.
+ No suffix is treated as milliseconds.
+
+ Args:
+ value: The duration to parse.
+
+ Returns:
+ The number of milliseconds in the duration.
+ """
if isinstance(value, int):
return value
second = 1000
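A few illustrative values for the documented behaviour (a standalone sketch, not part of the patch):

```python
from synapse.config._base import Config

# Integers are already milliseconds and pass through unchanged; strings may
# carry one of the documented unit suffixes, and a bare number means ms.
assert Config.parse_duration(500) == 500
assert Config.parse_duration("500") == 500
assert Config.parse_duration("15s") == 15 * 1000
assert Config.parse_duration("1h") == 60 * 60 * 1000
assert Config.parse_duration("7d") == 7 * 24 * 60 * 60 * 1000
```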
@@ -190,15 +212,33 @@ class Config:
@classmethod
def read_file(cls, file_path, config_name):
- cls.check_file(file_path, config_name)
- with io_open(file_path, encoding="utf-8") as file_stream:
- return file_stream.read()
+ """Deprecated: call read_file directly"""
+ return read_file(file_path, (config_name,))
+
+ def read_template(self, filename: str) -> jinja2.Template:
+ """Load a template file from disk.
+
+ This function will attempt to load the given template from the default Synapse
+ template directory.
+
+    Files read are treated as Jinja templates. The template is not rendered yet
+ and has autoescape enabled.
+
+ Args:
+ filename: A template filename to read.
+
+ Raises:
+ ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+ Returns:
+ A jinja2 template.
+ """
+ return self.read_templates([filename])[0]
def read_templates(
self,
filenames: List[str],
custom_template_directory: Optional[str] = None,
- autoescape: bool = False,
) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
@@ -206,7 +246,8 @@ class Config:
template directory. If `custom_template_directory` is supplied, that directory
is tried first.
- Files read are treated as Jinja templates. These templates are not rendered yet.
+ Files read are treated as Jinja templates. The templates are not rendered yet
+ and have autoescape enabled.
Args:
filenames: A list of template filenames to read.
@@ -214,16 +255,12 @@ class Config:
custom_template_directory: A directory to try to look for the templates
before using the default Synapse template directory instead.
- autoescape: Whether to autoescape variables before inserting them into the
- template.
-
Raises:
ConfigError: if the file's path is incorrect or otherwise cannot be read.
Returns:
A list of jinja2 templates.
"""
- templates = []
search_directories = [self.default_template_dir]
# The loader will first look in the custom template directory (if specified) for the
@@ -239,54 +276,23 @@ class Config:
# Search the custom template directory as well
search_directories.insert(0, custom_template_directory)
+ # TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(search_directories)
- env = jinja2.Environment(loader=loader, autoescape=autoescape)
+ env = jinja2.Environment(
+ loader=loader,
+ autoescape=jinja2.select_autoescape(),
+ )
# Update the environment with our custom filters
- env.filters.update({"format_ts": _format_ts_filter})
- if self.public_baseurl:
- env.filters.update(
- {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)}
- )
-
- for filename in filenames:
- # Load the template
- template = env.get_template(filename)
- templates.append(template)
-
- return templates
-
-
-def _format_ts_filter(value: int, format: str):
- return time.strftime(format, time.localtime(value / 1000))
-
-
-def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
- """Create and return a jinja2 filter that converts MXC urls to HTTP
-
- Args:
- public_baseurl: The public, accessible base URL of the homeserver
- """
-
- def mxc_to_http_filter(value, width, height, resize_method="crop"):
- if value[0:6] != "mxc://":
- return ""
-
- server_and_media_id = value[6:]
- fragment = None
- if "#" in server_and_media_id:
- server_and_media_id, fragment = server_and_media_id.split("#", 1)
- fragment = "#" + fragment
-
- params = {"width": width, "height": height, "method": resize_method}
- return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
- public_baseurl,
- server_and_media_id,
- urllib.parse.urlencode(params),
- fragment or "",
+ env.filters.update(
+ {
+ "format_ts": _format_ts_filter,
+ "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+ }
)
- return mxc_to_http_filter
+ # Load the templates
+ return [env.get_template(filename) for filename in filenames]
class RootConfig:
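The practical effect of switching to `jinja2.select_autoescape()` can be shown with a standalone snippet (illustrative, independent of the Config class):

```python
import jinja2

# select_autoescape() enables escaping based on the template's file extension,
# so values substituted into *.html templates cannot inject markup.
env = jinja2.Environment(
    loader=jinja2.DictLoader({"example.html": "<p>Hello {{ name }}</p>"}),
    autoescape=jinja2.select_autoescape(),
)
template = env.get_template("example.html")
print(template.render(name="<alice>"))  # -> <p>Hello &lt;alice&gt;</p>
```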
@@ -836,24 +842,24 @@ class ShardedWorkerHandlingConfig:
instances = attr.ib(type=List[str])
def should_handle(self, instance_name: str, key: str) -> bool:
- """Whether this instance is responsible for handling the given key.
- """
- # If multiple instances are not defined we always return true
- if not self.instances or len(self.instances) == 1:
- return True
+ """Whether this instance is responsible for handling the given key."""
+ # If no instances are defined we assume some other worker is handling
+ # this.
+ if not self.instances:
+ return False
- return self.get_instance(key) == instance_name
+ return self._get_instance(key) == instance_name
- def get_instance(self, key: str) -> str:
+ def _get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key.
- Note: For things like federation sending the config for which instance
- is sending is known only to the sender instance if there is only one.
- Therefore `should_handle` should be used where possible.
+ Note: For federation sending and pushers the config for which instance
+ is sending is known only to the sender instance, so we don't expose this
+ method by default.
"""
if not self.instances:
- return "master"
+ raise Exception("Unknown worker")
if len(self.instances) == 1:
return self.instances[0]
@@ -870,4 +876,52 @@ class ShardedWorkerHandlingConfig:
return self.instances[remainder]
-__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+@attr.s
+class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
+ """A version of `ShardedWorkerHandlingConfig` that is used for config
+ options where all instances know which instances are responsible for the
+ sharded work.
+ """
+
+ def __attrs_post_init__(self):
+ # We require that `self.instances` is non-empty.
+ if not self.instances:
+ raise Exception("Got empty list of instances for shard config")
+
+ def get_instance(self, key: str) -> str:
+ """Get the instance responsible for handling the given key."""
+ return self._get_instance(key)
+
+
+def read_file(file_path: Any, config_path: Iterable[str]) -> str:
+ """Check the given file exists, and read it into a string
+
+ If it does not, emit an error indicating the problem
+
+ Args:
+ file_path: the file to be read
+ config_path: where in the configuration file_path came from, so that a useful
+ error can be emitted if it does not exist.
+ Returns:
+ content of the file.
+ Raises:
+ ConfigError if there is a problem reading the file.
+ """
+ if not isinstance(file_path, str):
+ raise ConfigError("%r is not a string", config_path)
+
+ try:
+ os.stat(file_path)
+ with open(file_path) as file_stream:
+ return file_stream.read()
+ except OSError as e:
+ raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
+
+
+__all__ = [
+ "Config",
+ "RootConfig",
+ "ShardedWorkerHandlingConfig",
+ "RoutableShardedWorkerHandlingConfig",
+ "read_file",
+]
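A quick sketch of the new sharding semantics (illustrative; the instance names and key are made up):

```python
from synapse.config._base import (
    RoutableShardedWorkerHandlingConfig,
    ShardedWorkerHandlingConfig,
)

# With no instances configured, this worker now assumes some other process is
# responsible, so should_handle() returns False rather than True.
empty = ShardedWorkerHandlingConfig(instances=[])
assert empty.should_handle("worker1", "@user:example.com") is False

# The routable variant is for options where every worker knows the full list of
# instances, so an empty list is rejected at construction time.
sharded = RoutableShardedWorkerHandlingConfig(instances=["sender1", "sender2"])
owner = sharded.get_instance("@user:example.com")  # deterministic for a given key
assert owner in ("sender1", "sender2")
assert sharded.should_handle(owner, "@user:example.com")
```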
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 3124135c13..ddec356a07 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,9 +1,10 @@
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional
from synapse.config import (
account_validity,
api,
appservice,
+ auth,
captcha,
cas,
consent_config,
@@ -16,10 +17,10 @@ from synapse.config import (
logger,
metrics,
oidc_config,
- password,
password_auth_providers,
push,
ratelimiting,
+ redis,
registration,
repository,
room_directory,
@@ -37,7 +38,10 @@ from synapse.config import (
workers,
)
-class ConfigError(Exception): ...
+class ConfigError(Exception):
+ def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+ self.msg = msg
+ self.path = path
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
@@ -51,7 +55,7 @@ class RootConfig:
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
- ratelimit: ratelimiting.RatelimitConfig
+ ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
@@ -66,7 +70,7 @@ class RootConfig:
sso: sso.SSOConfig
oidc: oidc_config.OIDCConfig
jwt: jwt_config.JWTConfig
- password: password.PasswordConfig
+ auth: auth.AuthConfig
email: emailconfig.EmailConfig
worker: workers.WorkerConfig
authproviders: password_auth_providers.PasswordAuthProviderConfig
@@ -80,6 +84,7 @@ class RootConfig:
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
+ redis: redis.RedisConfig
config_classes: List = ...
def __init__(self) -> None: ...
@@ -146,4 +151,8 @@ class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
+
+class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ...
+
+def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index c74969a977..8fce7f6bb1 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -38,14 +38,27 @@ def validate_config(
try:
jsonschema.validate(config, json_schema)
except jsonschema.ValidationError as e:
- # copy `config_path` before modifying it.
- path = list(config_path)
- for p in list(e.path):
- if isinstance(p, int):
- path.append("<item %i>" % p)
- else:
- path.append(str(p))
-
- raise ConfigError(
- "Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
- )
+ raise json_error_to_config_error(e, config_path)
+
+
+def json_error_to_config_error(
+ e: jsonschema.ValidationError, config_path: Iterable[str]
+) -> ConfigError:
+ """Converts a json validation error to a user-readable ConfigError
+
+ Args:
+ e: the exception to be converted
+ config_path: the path within the config file. This will be used as a basis
+ for the error message.
+
+ Returns:
+ a ConfigError
+ """
+ # copy `config_path` before modifying it.
+ path = list(config_path)
+ for p in list(e.absolute_path):
+ if isinstance(p, int):
+ path.append("<item %i>" % p)
+ else:
+ path.append(str(p))
+ return ConfigError(e.message, path)
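An illustrative use of the new helper (a sketch; the schema and config path here are made up):

```python
import jsonschema

from synapse.config._util import json_error_to_config_error

schema = {"type": "object", "properties": {"port": {"type": "integer"}}}

try:
    jsonschema.validate({"port": "not-a-number"}, schema)
except jsonschema.ValidationError as e:
    err = json_error_to_config_error(e, ("listeners", "<item 0>"))
    # err.msg is the jsonschema message and err.path is the combined path,
    # e.g. ['listeners', '<item 0>', 'port'], ready for format_config_error().
    print(err.msg, list(err.path))
```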
diff --git a/synapse/config/password.py b/synapse/config/auth.py
index 6b2dae78b0..9aabaadf9e 100644
--- a/synapse/config/password.py
+++ b/synapse/config/auth.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2015-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +17,10 @@
from ._base import Config
-class PasswordConfig(Config):
- """Password login configuration
- """
+class AuthConfig(Config):
+ """Password and login configuration"""
- section = "password"
+ section = "auth"
def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
@@ -37,6 +35,12 @@ class PasswordConfig(Config):
self.password_policy = password_config.get("policy") or {}
self.password_policy_enabled = self.password_policy.get("enabled", False)
+ # User-interactive authentication
+ ui_auth = config.get("ui_auth") or {}
+ self.ui_auth_session_timeout = self.parse_duration(
+ ui_auth.get("session_timeout", 0)
+ )
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -89,4 +93,19 @@ class PasswordConfig(Config):
# Defaults to 'false'.
#
#require_uppercase: true
+
+ ui_auth:
+ # The amount of time to allow a user-interactive authentication session
+ # to be active.
+ #
+ # This defaults to 0, meaning the user is queried for their credentials
+ # before every action, but this can be overridden to allow a single
+ # validation to be re-used. This weakens the protections afforded by
+ # the user-interactive authentication process, by allowing for multiple
+ # (and potentially different) operations to use the same validation session.
+ #
+ # Uncomment below to allow for credential validation to last for 15
+ # seconds.
+ #
+ #session_timeout: "15s"
"""
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 8e03f14005..4e8abbf88a 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -24,7 +24,7 @@ from ._base import Config, ConfigError
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
# Map from canonicalised cache name to cache.
-_CACHES = {}
+_CACHES = {} # type: Dict[str, Callable[[float], None]]
# a lock on the contents of _CACHES
_CACHES_LOCK = threading.Lock()
@@ -59,7 +59,9 @@ def _canonicalise_cache_name(cache_name: str) -> str:
return cache_name.lower()
-def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
+def add_resizable_cache(
+ cache_name: str, cache_resize_callback: Callable[[float], None]
+):
"""Register a cache that's size can dynamically change
Args:
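A minimal sketch of registering a resizable cache against the signature above (the cache object here is hypothetical):

```python
from synapse.config.cache import add_resizable_cache

class DemoCache:
    """Hypothetical stand-in for an LruCache; only the resize hook matters."""

    def __init__(self, base_size: int) -> None:
        self.base_size = base_size
        self.max_size = base_size

    def set_cache_factor(self, factor: float) -> None:
        # The callback receives the effective cache factor as a float.
        self.max_size = int(self.base_size * factor)

cache = DemoCache(base_size=1000)

# The name is canonicalised (see _canonicalise_cache_name above) so it can be
# matched against per-cache config or SYNAPSE_CACHE_FACTOR_* environment variables.
add_resizable_cache("demo_cache", cache.set_cache_factor)
```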
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index cb00958165..9e48f865cc 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -28,9 +28,7 @@ class CaptchaConfig(Config):
"recaptcha_siteverify_api",
"https://www.recaptcha.net/recaptcha/api/siteverify",
)
- self.recaptcha_template = self.read_templates(
- ["recaptcha.html"], autoescape=True
- )[0]
+ self.recaptcha_template = self.read_template("recaptcha.html")
def generate_config_section(self, **kwargs):
return """\
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 2f97e6d258..dbf5085965 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from typing import Any, List
+
+from synapse.config.sso import SsoAttributeRequirement
+
+from ._base import Config, ConfigError
+from ._util import validate_config
class CasConfig(Config):
@@ -30,17 +35,29 @@ class CasConfig(Config):
if self.cas_enabled:
self.cas_server_url = cas_config["server_url"]
- self.cas_service_url = cas_config["service_url"]
+
+ # The public baseurl is required because it is used by the redirect
+ # template.
+ public_baseurl = self.public_baseurl
+ if not public_baseurl:
+ raise ConfigError("cas_config requires a public_baseurl to be set")
+
+ # TODO Update this to a _synapse URL.
+ self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
- self.cas_required_attributes = cas_config.get("required_attributes") or {}
+ required_attributes = cas_config.get("required_attributes") or {}
+ self.cas_required_attributes = _parsed_required_attributes_def(
+ required_attributes
+ )
+
else:
self.cas_server_url = None
self.cas_service_url = None
self.cas_displayname_attribute = None
- self.cas_required_attributes = {}
+ self.cas_required_attributes = []
def generate_config_section(self, config_dir_path, server_name, **kwargs):
- return """
+ return """\
# Enable Central Authentication Service (CAS) for registration and login.
#
cas_config:
@@ -53,10 +70,6 @@ class CasConfig(Config):
#
#server_url: "https://cas-server.com"
- # The public URL of the homeserver.
- #
- #service_url: "https://homeserver.domain.com:8448"
-
# The attribute of the CAS response to use as the display name.
#
# If unset, no displayname will be set.
@@ -73,3 +86,22 @@ class CasConfig(Config):
# userGroup: "staff"
# department: None
"""
+
+
+# CAS uses a legacy required attributes mapping, not the one provided by
+# SsoAttributeRequirement.
+REQUIRED_ATTRIBUTES_SCHEMA = {
+ "type": "object",
+ "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
+}
+
+
+def _parsed_required_attributes_def(
+ required_attributes: Any,
+) -> List[SsoAttributeRequirement]:
+ validate_config(
+ REQUIRED_ATTRIBUTES_SCHEMA,
+ required_attributes,
+ config_path=("cas_config", "required_attributes"),
+ )
+ return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]
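For illustration, the legacy CAS mapping shown in the generated config above normalises as follows (a sketch, not part of the patch):

```python
from synapse.config.cas import _parsed_required_attributes_def

# A dict of attribute name -> required value; the generated config example above
# shows both a concrete value ("staff") and None.
reqs = _parsed_required_attributes_def({"userGroup": "staff", "department": None})
for req in reqs:
    # Each entry becomes a SsoAttributeRequirement instance.
    print(req)
```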
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 6efa59b110..c47f364b14 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -89,7 +89,7 @@ class ConsentConfig(Config):
def read_config(self, config, **kwargs):
consent_config = config.get("user_consent")
- self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0]
+ self.terms_template = self.read_template("terms.html")
if consent_config is None:
return
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 8a18a9ca2a..e7889b9c20 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -207,8 +207,7 @@ class DatabaseConfig(Config):
)
def get_single_database(self) -> DatabaseConnectionConfig:
- """Returns the database if there is only one, useful for e.g. tests
- """
+ """Returns the database if there is only one, useful for e.g. tests"""
if not self.databases:
raise Exception("More than one database exists")
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index ad96c55798..5431691831 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -289,7 +289,8 @@ class EmailConfig(Config):
self.email_notif_template_html,
self.email_notif_template_text,
) = self.read_templates(
- [notif_template_html, notif_template_text], template_dir,
+ [notif_template_html, notif_template_text],
+ template_dir,
)
self.email_notif_for_new_users = email_config.get(
@@ -311,7 +312,8 @@ class EmailConfig(Config):
self.account_validity_template_html,
self.account_validity_template_text,
) = self.read_templates(
- [expiry_template_html, expiry_template_text], template_dir,
+ [expiry_template_html, expiry_template_text],
+ template_dir,
)
subjects_config = email_config.get("subjects", {})
@@ -322,6 +324,22 @@ class EmailConfig(Config):
self.email_subjects = EmailSubjectConfig(**subjects)
+        # The invite client location should be an HTTP(S) URL or None.
+ self.invite_client_location = email_config.get("invite_client_location") or None
+ if self.invite_client_location:
+ if not isinstance(self.invite_client_location, str):
+ raise ConfigError(
+ "Config option email.invite_client_location must be type str"
+ )
+ if not (
+ self.invite_client_location.startswith("http://")
+ or self.invite_client_location.startswith("https://")
+ ):
+ raise ConfigError(
+ "Config option email.invite_client_location must be a http or https URL",
+ path=("email", "invite_client_location"),
+ )
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return (
"""\
@@ -389,10 +407,15 @@ class EmailConfig(Config):
#
#validation_token_lifetime: 15m
- # Directory in which Synapse will try to find the template files below.
- # If not set, default templates from within the Synapse package will be used.
+ # The web client location to direct users to during an invite. This is passed
+ # to the identity server as the org.matrix.web_client_location key. Defaults
+ # to unset, giving no guidance to the identity server.
#
- # Do not uncomment this setting unless you want to customise the templates.
+ #invite_client_location: https://app.element.io
+
+ # Directory in which Synapse will try to find the template files below.
+ # If not set, or the files named below are not found within the template
+ # directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 1ea11422af..001bddc6f6 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -31,3 +31,10 @@ class ExperimentalConfig(Config):
if self.msc2403_enabled:
# Enable the MSC2403 unstable room version
KNOWN_ROOM_VERSIONS.update({RoomVersions.V7.identifier: RoomVersions.V7})
+
+ # MSC2858 (multiple SSO identity providers)
+ self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
+ # Spaces (MSC1772, MSC2946, etc)
+ self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
+ # MSC3026 (busy presence state)
+ self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index 27ccf61c3c..55e4db5442 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -12,12 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Optional
-from netaddr import IPSet
-
-from synapse.config._base import Config, ConfigError
+from synapse.config._base import Config
from synapse.config._util import validate_config
@@ -36,31 +33,6 @@ class FederationConfig(Config):
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
- ip_range_blacklist = config.get("ip_range_blacklist", [])
-
- # Attempt to create an IPSet from the given ranges
- try:
- self.ip_range_blacklist = IPSet(ip_range_blacklist)
- except Exception as e:
- raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e)
- # Always blacklist 0.0.0.0, ::
- self.ip_range_blacklist.update(["0.0.0.0", "::"])
-
- # The federation_ip_range_blacklist is used for backwards-compatibility
- # and only applies to federation and identity servers. If it is not given,
- # default to ip_range_blacklist.
- federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", ip_range_blacklist
- )
- try:
- self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
- except Exception as e:
- raise ConfigError(
- "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
- )
- # Always blacklist 0.0.0.0, ::
- self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
-
federation_metrics_domains = config.get("federation_metrics_domains") or []
validate_config(
_METRICS_FOR_DOMAINS_SCHEMA,
@@ -69,6 +41,10 @@ class FederationConfig(Config):
)
self.federation_metrics_domains = set(federation_metrics_domains)
+ self.allow_profile_lookup_over_federation = config.get(
+ "allow_profile_lookup_over_federation", True
+ )
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
## Federation ##
@@ -84,29 +60,6 @@ class FederationConfig(Config):
# - nyc.example.com
# - syd.example.com
- # Prevent outgoing requests from being sent to the following blacklisted IP address
- # CIDR ranges. If this option is not specified, or specified with an empty list,
- # no IP range blacklist will be enforced.
- #
- # The blacklist applies to the outbound requests for federation, identity servers,
- # push servers, and for checking key validitity for third-party invite events.
- #
- # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
- # listed here, since they correspond to unroutable addresses.)
- #
- # This option replaces federation_ip_range_blacklist in Synapse v1.24.0.
- #
- ip_range_blacklist:
- - '127.0.0.0/8'
- - '10.0.0.0/8'
- - '172.16.0.0/12'
- - '192.168.0.0/16'
- - '100.64.0.0/10'
- - '169.254.0.0/16'
- - '::1/128'
- - 'fe80::/64'
- - 'fc00::/7'
-
# Report prometheus metrics on the age of PDUs being sent to and received from
# the following domains. This can be used to give an idea of "delay" on inbound
# and outbound federation, though be aware that any delay can be due to problems
@@ -117,6 +70,12 @@ class FederationConfig(Config):
#federation_metrics_domains:
# - matrix.org
# - example.com
+
+ # Uncomment to disable profile lookup over federation. By default, the
+ # Federation API allows other homeservers to obtain profile data of any user
+ # on this homeserver. Defaults to 'true'.
+ #
+ #allow_profile_lookup_over_federation: false
"""
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index d6862d9a64..7b7860ea71 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -32,5 +32,5 @@ class GroupsConfig(Config):
# If enabled, non server admins can only create groups with local parts
# starting with this prefix
#
- #group_creation_prefix: "unofficial/"
+ #group_creation_prefix: "unofficial_"
"""
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 09def49197..58961679ff 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -17,6 +17,7 @@ from ._base import RootConfig
from .account_validity import AccountValidityConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
+from .auth import AuthConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
from .cas import CasConfig
@@ -31,7 +32,6 @@ from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
from .oidc_config import OIDCConfig
-from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
from .ratelimiting import RatelimitConfig
@@ -79,7 +79,7 @@ class HomeServerConfig(RootConfig):
CasConfig,
SSOConfig,
JWTConfig,
- PasswordConfig,
+ AuthConfig,
EmailConfig,
PasswordAuthProviderConfig,
PushConfig,
diff --git a/synapse/config/key.py b/synapse/config/key.py
index de964dff13..350ff1d665 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -404,7 +404,11 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
try:
jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA)
except jsonschema.ValidationError as e:
- raise ConfigError("Unable to parse 'trusted_key_servers': " + e.message)
+ raise ConfigError(
+ "Unable to parse 'trusted_key_servers': {}".format(
+ e.message # noqa: B306, jsonschema.ValidationError.message is a valid attribute
+ )
+ )
for server in key_servers:
server_name = server["server_name"]
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3e0..999aecce5c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -21,8 +21,10 @@ import threading
from string import Template
import yaml
+from zope.interface import implementer
from twisted.logger import (
+ ILogObserver,
LogBeginner,
STDLibLogObserver,
eventAsText,
@@ -162,7 +164,10 @@ class LoggingConfig(Config):
)
logging_group.add_argument(
- "-f", "--log-file", dest="log_file", help=argparse.SUPPRESS,
+ "-f",
+ "--log-file",
+ dest="log_file",
+ help=argparse.SUPPRESS,
)
def generate_files(self, config, config_dir_path):
@@ -206,7 +211,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
    # filter options, but care must be taken when using e.g. MemoryHandler to buffer
# writes.
- log_context_filter = LoggingContextFilter(request="")
+ log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory()
@@ -224,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
threadlocal = threading.local()
- def _log(event):
+ @implementer(ILogObserver)
+ def _log(event: dict) -> None:
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index dfd27e1523..2b289f4208 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -56,7 +56,9 @@ class MetricsConfig(Config):
try:
check_requirements("sentry")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 69d188341c..05733ec41d 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,10 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections import Counter
+from typing import Iterable, List, Mapping, Optional, Tuple, Type
+
+import attr
+
+from synapse.config._util import validate_config
+from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
+from synapse.util.stringutils import parse_and_validate_mxc_uri
-from ._base import Config, ConfigError
+from ._base import Config, ConfigError, read_file
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
@@ -25,201 +35,557 @@ class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs):
- self.oidc_enabled = False
-
- oidc_config = config.get("oidc_config")
-
- if not oidc_config or not oidc_config.get("enabled", False):
+ self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
+ if not self.oidc_providers:
return
try:
check_requirements("oidc")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ ) from e
+
+ # check we don't have any duplicate idp_ids now. (The SSO handler will also
+ # check for duplicates when the REST listeners get registered, but that happens
+ # after synapse has forked so doesn't give nice errors.)
+ c = Counter([i.idp_id for i in self.oidc_providers])
+ for idp_id, count in c.items():
+ if count > 1:
+ raise ConfigError(
+ "Multiple OIDC providers have the idp_id %r." % idp_id
+ )
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("oidc_config requires a public_baseurl to be set")
- self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
-
- self.oidc_enabled = True
- self.oidc_discover = oidc_config.get("discover", True)
- self.oidc_issuer = oidc_config["issuer"]
- self.oidc_client_id = oidc_config["client_id"]
- self.oidc_client_secret = oidc_config["client_secret"]
- self.oidc_client_auth_method = oidc_config.get(
- "client_auth_method", "client_secret_basic"
- )
- self.oidc_scopes = oidc_config.get("scopes", ["openid"])
- self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
- self.oidc_token_endpoint = oidc_config.get("token_endpoint")
- self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
- self.oidc_jwks_uri = oidc_config.get("jwks_uri")
- self.oidc_skip_verification = oidc_config.get("skip_verification", False)
- self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
- self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
-
- ump_config = oidc_config.get("user_mapping_provider", {})
- ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
- ump_config.setdefault("config", {})
-
- (
- self.oidc_user_mapping_provider_class,
- self.oidc_user_mapping_provider_config,
- ) = load_module(ump_config)
-
- # Ensure loaded user mapping module has defined all necessary methods
- required_methods = [
- "get_remote_user_id",
- "map_user_attributes",
- ]
- missing_methods = [
- method
- for method in required_methods
- if not hasattr(self.oidc_user_mapping_provider_class, method)
- ]
- if missing_methods:
- raise ConfigError(
- "Class specified by oidc_config."
- "user_mapping_provider.module is missing required "
- "methods: %s" % (", ".join(missing_methods),)
- )
+ self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback"
+
+ @property
+ def oidc_enabled(self) -> bool:
+ # OIDC is enabled if we have a provider
+ return bool(self.oidc_providers)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
- # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
+ # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+ # and login.
+ #
+ # Options for each entry include:
+ #
+ # idp_id: a unique identifier for this identity provider. Used internally
+ # by Synapse; should be a single word such as 'github'.
+ #
+ # Note that, if this is changed, users authenticating via that provider
+ # will no longer be recognised as the same user!
+ #
+ # (Use "oidc" here if you are migrating from an old "oidc_config"
+ # configuration.)
+ #
+ # idp_name: A user-facing name for this identity provider, which is used to
+ # offer the user a choice of login mechanisms.
+ #
+ # idp_icon: An optional icon for this identity provider, which is presented
+ # by clients and Synapse's own IdP picker page. If given, must be an
+ # MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
+ # obtain such an MXC URI is to upload an image to an (unencrypted) room
+ # and then copy the "url" from the source of the event.)
+ #
+ # idp_brand: An optional brand for this identity provider, allowing clients
+ # to style the login flow according to the identity provider in question.
+ # See the spec for possible options here.
+ #
+ # discover: set to 'false' to disable the use of the OIDC discovery mechanism
+ # to discover endpoints. Defaults to true.
+ #
+ # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+ # is enabled) to discover the provider's endpoints.
+ #
+ # client_id: Required. oauth2 client id to use.
+ #
+ # client_secret: oauth2 client secret to use. May be omitted if
+ # client_secret_jwt_key is given, or if client_auth_method is 'none'.
+ #
+ # client_secret_jwt_key: Alternative to client_secret: details of a key used
+ # to create a JSON Web Token to be used as an OAuth2 client secret. If
+ # given, must be a dictionary with the following properties:
+ #
+ # key: a pem-encoded signing key. Must be a suitable key for the
+ # algorithm specified. Required unless 'key_file' is given.
+ #
+ # key_file: the path to file containing a pem-encoded signing key file.
+ # Required unless 'key' is given.
+ #
+ # jwt_header: a dictionary giving properties to include in the JWT
+ # header. Must include the key 'alg', giving the algorithm used to
+ # sign the JWT, such as "ES256", using the JWA identifiers in
+ # RFC7518.
+ #
+ # jwt_payload: an optional dictionary giving properties to include in
+ # the JWT payload. Normally this should include an 'iss' key.
+ #
+ # client_auth_method: auth method to use when exchanging the token. Valid
+ # values are 'client_secret_basic' (default), 'client_secret_post' and
+ # 'none'.
+ #
+ # scopes: list of scopes to request. This should normally include the "openid"
+ # scope. Defaults to ["openid"].
+ #
+ # authorization_endpoint: the oauth2 authorization endpoint. Required if
+ # provider discovery is disabled.
+ #
+ # token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+ # disabled.
+ #
+ # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+ # disabled and the 'openid' scope is not requested.
+ #
+ # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+ # the 'openid' scope is used.
+ #
+ # skip_verification: set to 'true' to skip metadata verification. Use this if
+ # you are connecting to a provider that is not OpenID Connect compliant.
+ # Defaults to false. Avoid this in production.
+ #
+ # user_profile_method: Whether to fetch the user profile from the userinfo
+ # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+ #
+ # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+ # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+ # userinfo endpoint.
+ #
+ # allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+ # match a pre-existing account instead of failing. This could be used if
+ # switching from password logins to OIDC. Defaults to false.
+ #
+ # user_mapping_provider: Configuration for how attributes returned from a OIDC
+ # provider are mapped onto a matrix user. This setting has the following
+ # sub-properties:
+ #
+ # module: The class name of a custom mapping module. Default is
+ # {mapping_provider!r}.
+ # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+ # for information on implementing a custom mapping provider.
+ #
+ # config: Configuration for the mapping provider module. This section will
+ # be passed as a Python dictionary to the user mapping provider
+ # module's `parse_config` method.
+ #
+ # For the default provider, the following settings are available:
+ #
+ # subject_claim: name of the claim containing a unique identifier
+ # for the user. Defaults to 'sub', which OpenID Connect
+ # compliant providers should provide.
+ #
+ # localpart_template: Jinja2 template for the localpart of the MXID.
+ # If this is not set, the user will be prompted to choose their
+ # own username (see 'sso_auth_account_details.html' in the 'sso'
+ # section of this file).
+ #
+ # display_name_template: Jinja2 template for the display name to set
+ # on first login. If unset, no displayname will be set.
+ #
+ # email_template: Jinja2 template for the email address of the user.
+ # If unset, no email address will be added to the account.
+ #
+ # extra_attributes: a map of Jinja2 templates for extra attributes
+ # to send back to the client during login.
+ # Note that these are non-standard and clients will ignore them
+ # without modifications.
+ #
+ # When rendering, the Jinja2 templates are given a 'user' variable,
+ # which is set to the claims returned by the UserInfo Endpoint and/or
+ # in the ID Token.
+ #
+ # It is possible to configure Synapse to only allow logins if certain attributes
+ # match particular values in the OIDC userinfo. The requirements can be listed under
+ # `attribute_requirements` as shown below. All of the listed attributes must
+ # match for the login to be permitted. Additional attributes can be added to
+ # userinfo by expanding the `scopes` section of the OIDC config to retrieve
+ # additional information from the OIDC provider.
+ #
+ # If the OIDC claim is a list, then the attribute must match any value in the list.
+ # Otherwise, it must exactly match the value of the claim. Using the example
+ # below, the `family_name` claim MUST be "Stephensson", but the `groups`
+ # claim MUST contain "admin".
+ #
+ # attribute_requirements:
+ # - attribute: family_name
+ # value: "Stephensson"
+ # - attribute: groups
+ # value: "admin"
#
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
- # for some example configurations.
+ # for information on how to configure these options.
#
- oidc_config:
- # Uncomment the following to enable authorization against an OpenID Connect
- # server. Defaults to false.
- #
- #enabled: true
-
- # Uncomment the following to disable use of the OIDC discovery mechanism to
- # discover endpoints. Defaults to true.
- #
- #discover: false
-
- # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
- # discover the provider's endpoints.
- #
- # Required if 'enabled' is true.
- #
- #issuer: "https://accounts.example.com/"
-
- # oauth2 client id to use.
- #
- # Required if 'enabled' is true.
- #
- #client_id: "provided-by-your-issuer"
-
- # oauth2 client secret to use.
- #
- # Required if 'enabled' is true.
- #
- #client_secret: "provided-by-your-issuer"
-
- # auth method to use when exchanging the token.
- # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
- # 'none'.
+ # For backwards compatibility, it is also possible to configure a single OIDC
+ # provider via an 'oidc_config' setting. This is now deprecated and admins are
+ # advised to migrate to the 'oidc_providers' format. (When doing that migration,
+ # use 'oidc' for the idp_id to ensure that existing users continue to be
+ # recognised.)
+ #
+ oidc_providers:
+ # Generic example
#
- #client_auth_method: client_secret_post
+ #- idp_id: my_idp
+ # idp_name: "My OpenID provider"
+ # idp_icon: "mxc://example.com/mediaid"
+ # discover: false
+ # issuer: "https://accounts.example.com/"
+ # client_id: "provided-by-your-issuer"
+ # client_secret: "provided-by-your-issuer"
+ # client_auth_method: client_secret_post
+ # scopes: ["openid", "profile"]
+ # authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+ # token_endpoint: "https://accounts.example.com/oauth2/token"
+ # userinfo_endpoint: "https://accounts.example.com/userinfo"
+ # jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+ # skip_verification: true
+ # user_mapping_provider:
+ # config:
+ # subject_claim: "id"
+ # localpart_template: "{{{{ user.login }}}}"
+ # display_name_template: "{{{{ user.name }}}}"
+ # email_template: "{{{{ user.email }}}}"
+ # attribute_requirements:
+ # - attribute: userGroup
+ # value: "synapseUsers"
+ """.format(
+ mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
+ )
- # list of scopes to request. This should normally include the "openid" scope.
- # Defaults to ["openid"].
- #
- #scopes: ["openid", "profile"]
- # the oauth2 authorization endpoint. Required if provider discovery is disabled.
- #
- #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+# jsonschema definition of the configuration settings for an oidc identity provider
+OIDC_PROVIDER_CONFIG_SCHEMA = {
+ "type": "object",
+ "required": ["issuer", "client_id"],
+ "properties": {
+ "idp_id": {
+ "type": "string",
+ "minLength": 1,
+ # MSC2858 allows a maxlen of 255, but we prefix with "oidc-"
+ "maxLength": 250,
+ "pattern": "^[A-Za-z0-9._~-]+$",
+ },
+ "idp_name": {"type": "string"},
+ "idp_icon": {"type": "string"},
+ "idp_brand": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 255,
+ "pattern": "^[a-z][a-z0-9_.-]*$",
+ },
+ "idp_unstable_brand": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 255,
+ "pattern": "^[a-z][a-z0-9_.-]*$",
+ },
+ "discover": {"type": "boolean"},
+ "issuer": {"type": "string"},
+ "client_id": {"type": "string"},
+ "client_secret": {"type": "string"},
+ "client_secret_jwt_key": {
+ "type": "object",
+ "required": ["jwt_header"],
+ "oneOf": [
+ {"required": ["key"]},
+ {"required": ["key_file"]},
+ ],
+ "properties": {
+ "key": {"type": "string"},
+ "key_file": {"type": "string"},
+ "jwt_header": {
+ "type": "object",
+ "required": ["alg"],
+ "properties": {
+ "alg": {"type": "string"},
+ },
+ "additionalProperties": {"type": "string"},
+ },
+ "jwt_payload": {
+ "type": "object",
+ "additionalProperties": {"type": "string"},
+ },
+ },
+ },
+ "client_auth_method": {
+ "type": "string",
+ # the following list is the same as the keys of
+ # authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
+ # to avoid importing authlib here.
+ "enum": ["client_secret_basic", "client_secret_post", "none"],
+ },
+ "scopes": {"type": "array", "items": {"type": "string"}},
+ "authorization_endpoint": {"type": "string"},
+ "token_endpoint": {"type": "string"},
+ "userinfo_endpoint": {"type": "string"},
+ "jwks_uri": {"type": "string"},
+ "skip_verification": {"type": "boolean"},
+ "user_profile_method": {
+ "type": "string",
+ "enum": ["auto", "userinfo_endpoint"],
+ },
+ "allow_existing_users": {"type": "boolean"},
+ "user_mapping_provider": {"type": ["object", "null"]},
+ "attribute_requirements": {
+ "type": "array",
+ "items": SsoAttributeRequirement.JSON_SCHEMA,
+ },
+ },
+}
+
+# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
+OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
+ "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
+}
+
+
+# the `oidc_providers` list can either be None (as it is in the default config), or
+# a list of provider configs, each of which requires an explicit ID and name.
+OIDC_PROVIDER_LIST_SCHEMA = {
+ "oneOf": [
+ {"type": "null"},
+ {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
+ ]
+}
+
+# the `oidc_config` setting can either be None (which it used to be in the default
+# config), or an object. If an object, it is ignored unless it has an "enabled: True"
+# property.
+#
+# It's *possible* to represent this with jsonschema, but the resultant errors aren't
+# particularly clear, so we just check for either an object or a null here, and do
+# additional checks in the code.
+OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
+
+# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
+MAIN_CONFIG_SCHEMA = {
+ "type": "object",
+ "properties": {
+ "oidc_config": OIDC_CONFIG_SCHEMA,
+ "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
+ },
+}
+
+
+def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
+ """extract and parse the OIDC provider configs from the config dict
+
+ The configuration may contain either a single `oidc_config` object with an
+ `enabled: True` property, or a list of provider configurations under
+ `oidc_providers`, *or both*.
+
+ Returns a generator which yields the OidcProviderConfig objects
+ """
+ validate_config(MAIN_CONFIG_SCHEMA, config, ())
+
+ for i, p in enumerate(config.get("oidc_providers") or []):
+ yield _parse_oidc_config_dict(p, ("oidc_providers", "<item %i>" % (i,)))
+
+ # for backwards-compatibility, it is also possible to provide a single "oidc_config"
+ # object with an "enabled: True" property.
+ oidc_config = config.get("oidc_config")
+ if oidc_config and oidc_config.get("enabled", False):
+ # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
+ # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
+ # above), so now we need to validate it.
+ validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
+ yield _parse_oidc_config_dict(oidc_config, ("oidc_config",))
+
+
+def _parse_oidc_config_dict(
+ oidc_config: JsonDict, config_path: Tuple[str, ...]
+) -> "OidcProviderConfig":
+ """Take the configuration dict and parse it into an OidcProviderConfig
+
+ Raises:
+ ConfigError if the configuration is malformed.
+ """
+ ump_config = oidc_config.get("user_mapping_provider", {})
+ ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+ ump_config.setdefault("config", {})
+
+ (
+ user_mapping_provider_class,
+ user_mapping_provider_config,
+ ) = load_module(ump_config, config_path + ("user_mapping_provider",))
+
+ # Ensure loaded user mapping module has defined all necessary methods
+ required_methods = [
+ "get_remote_user_id",
+ "map_user_attributes",
+ ]
+ missing_methods = [
+ method
+ for method in required_methods
+ if not hasattr(user_mapping_provider_class, method)
+ ]
+ if missing_methods:
+ raise ConfigError(
+ "Class %s is missing required "
+ "methods: %s"
+ % (
+ user_mapping_provider_class,
+ ", ".join(missing_methods),
+ ),
+ config_path + ("user_mapping_provider", "module"),
+ )
- # the oauth2 token endpoint. Required if provider discovery is disabled.
- #
- #token_endpoint: "https://accounts.example.com/oauth2/token"
+ idp_id = oidc_config.get("idp_id", "oidc")
- # the OIDC userinfo endpoint. Required if discovery is disabled and the
- # "openid" scope is not requested.
- #
- #userinfo_endpoint: "https://accounts.example.com/userinfo"
+ # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid
+ # clashes with other mechs (such as SAML, CAS).
+ #
+ # We allow "oidc" as an exception so that people migrating from old-style
+ # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to
+ # a new-style "oidc_providers" entry without changing the idp_id for their provider
+ # (and thereby invalidating their user_external_ids data).
- # URI where to fetch the JWKS. Required if discovery is disabled and the
- # "openid" scope is used.
- #
- #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+ if idp_id != "oidc":
+ idp_id = "oidc-" + idp_id
- # Uncomment to skip metadata verification. Defaults to false.
- #
- # Use this if you are connecting to a provider that is not OpenID Connect
- # compliant.
- # Avoid this in production.
- #
- #skip_verification: true
+ # MSC2858 also specifies that the idp_icon must be a valid MXC uri
+ idp_icon = oidc_config.get("idp_icon")
+ if idp_icon is not None:
+ try:
+ parse_and_validate_mxc_uri(idp_icon)
+ except ValueError as e:
+ raise ConfigError(
+ "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
+ ) from e
+
+ client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
+ client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
+ if client_secret_jwt_key_config is not None:
+ keyfile = client_secret_jwt_key_config.get("key_file")
+ if keyfile:
+ key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
+ else:
+ key = client_secret_jwt_key_config["key"]
+ client_secret_jwt_key = OidcProviderClientSecretJwtKey(
+ key=key,
+ jwt_header=client_secret_jwt_key_config["jwt_header"],
+ jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
+ )
+ # parse attribute_requirements from config (list of dicts) into a list of SsoAttributeRequirement
+ attribute_requirements = [
+ SsoAttributeRequirement(**x)
+ for x in oidc_config.get("attribute_requirements", [])
+ ]
+
+ return OidcProviderConfig(
+ idp_id=idp_id,
+ idp_name=oidc_config.get("idp_name", "OIDC"),
+ idp_icon=idp_icon,
+ idp_brand=oidc_config.get("idp_brand"),
+        unstable_idp_brand=oidc_config.get("idp_unstable_brand"),
+ discover=oidc_config.get("discover", True),
+ issuer=oidc_config["issuer"],
+ client_id=oidc_config["client_id"],
+ client_secret=oidc_config.get("client_secret"),
+ client_secret_jwt_key=client_secret_jwt_key,
+ client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
+ scopes=oidc_config.get("scopes", ["openid"]),
+ authorization_endpoint=oidc_config.get("authorization_endpoint"),
+ token_endpoint=oidc_config.get("token_endpoint"),
+ userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
+ jwks_uri=oidc_config.get("jwks_uri"),
+ skip_verification=oidc_config.get("skip_verification", False),
+ user_profile_method=oidc_config.get("user_profile_method", "auto"),
+ allow_existing_users=oidc_config.get("allow_existing_users", False),
+ user_mapping_provider_class=user_mapping_provider_class,
+ user_mapping_provider_config=user_mapping_provider_config,
+ attribute_requirements=attribute_requirements,
+ )
+
+
+@attr.s(slots=True, frozen=True)
+class OidcProviderClientSecretJwtKey:
+ # a pem-encoded signing key
+ key = attr.ib(type=str)
+
+ # properties to include in the JWT header
+ jwt_header = attr.ib(type=Mapping[str, str])
+
+ # properties to include in the JWT payload.
+ jwt_payload = attr.ib(type=Mapping[str, str])
+
+
+@attr.s(slots=True, frozen=True)
+class OidcProviderConfig:
+ # a unique identifier for this identity provider. Used in the 'user_external_ids'
+ # table, as well as the query/path parameter used in the login protocol.
+ idp_id = attr.ib(type=str)
+
+ # user-facing name for this identity provider.
+ idp_name = attr.ib(type=str)
+
+ # Optional MXC URI for icon for this IdP.
+ idp_icon = attr.ib(type=Optional[str])
+
+ # Optional brand identifier for this IdP.
+ idp_brand = attr.ib(type=Optional[str])
+
+ # Optional brand identifier for the unstable API (see MSC2858).
+ unstable_idp_brand = attr.ib(type=Optional[str])
+
+ # whether the OIDC discovery mechanism is used to discover endpoints
+ discover = attr.ib(type=bool)
+
+ # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
+ # discover the provider's endpoints.
+ issuer = attr.ib(type=str)
+
+ # oauth2 client id to use
+ client_id = attr.ib(type=str)
+
+ # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
+ # a secret.
+ client_secret = attr.ib(type=Optional[str])
+
+ # key to use to construct a JWT to use as a client secret. May be `None` if
+ # `client_secret` is set.
+ client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
+
+ # auth method to use when exchanging the token.
+ # Valid values are 'client_secret_basic', 'client_secret_post' and
+ # 'none'.
+ client_auth_method = attr.ib(type=str)
+
+ # list of scopes to request
+ scopes = attr.ib(type=Collection[str])
+
+ # the oauth2 authorization endpoint. Required if discovery is disabled.
+ authorization_endpoint = attr.ib(type=Optional[str])
+
+ # the oauth2 token endpoint. Required if discovery is disabled.
+ token_endpoint = attr.ib(type=Optional[str])
+
+ # the OIDC userinfo endpoint. Required if discovery is disabled and the
+ # "openid" scope is not requested.
+ userinfo_endpoint = attr.ib(type=Optional[str])
+
+ # URI where to fetch the JWKS. Required if discovery is disabled and the
+ # "openid" scope is used.
+ jwks_uri = attr.ib(type=Optional[str])
+
+ # Whether to skip metadata verification
+ skip_verification = attr.ib(type=bool)
+
+ # Whether to fetch the user profile from the userinfo endpoint. Valid
+ # values are: "auto" or "userinfo_endpoint".
+ user_profile_method = attr.ib(type=str)
+
+ # whether to allow a user logging in via OIDC to match a pre-existing account
+ # instead of failing
+ allow_existing_users = attr.ib(type=bool)
- # Whether to fetch the user profile from the userinfo endpoint. Valid
- # values are: "auto" or "userinfo_endpoint".
- #
- # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
- # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
- #
- #user_profile_method: "userinfo_endpoint"
+ # the class of the user mapping provider
+ user_mapping_provider_class = attr.ib(type=Type)
- # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
- # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
- #
- #allow_existing_users: true
+ # the config of the user mapping provider
+ user_mapping_provider_config = attr.ib()
- # An external module can be provided here as a custom solution to mapping
- # attributes returned from a OIDC provider onto a matrix user.
- #
- user_mapping_provider:
- # The custom module's class. Uncomment to use a custom module.
- # Default is {mapping_provider!r}.
- #
- # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
- # for information on implementing a custom mapping provider.
- #
- #module: mapping_provider.OidcMappingProvider
-
- # Custom configuration values for the module. This section will be passed as
- # a Python dictionary to the user mapping provider module's `parse_config`
- # method.
- #
- # The examples below are intended for the default provider: they should be
- # changed if using a custom provider.
- #
- config:
- # name of the claim containing a unique identifier for the user.
- # Defaults to `sub`, which OpenID Connect compliant providers should provide.
- #
- #subject_claim: "sub"
-
- # Jinja2 template for the localpart of the MXID.
- #
- # When rendering, this template is given the following variables:
- # * user: The claims returned by the UserInfo Endpoint and/or in the ID
- # Token
- #
- # This must be configured if using the default mapping provider.
- #
- localpart_template: "{{{{ user.preferred_username }}}}"
-
- # Jinja2 template for the display name to set on first login.
- #
- # If unset, no displayname will be set.
- #
- #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
-
- # Jinja2 templates for extra attributes to send back to the client during
- # login.
- #
- # Note that these are non-standard and clients will ignore them without modifications.
- #
- #extra_attributes:
- #birthdate: "{{{{ user.birthdate }}}}"
- """.format(
- mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
- )
+ # attributes that must be present in userinfo to allow login/registration
+ attribute_requirements = attr.ib(type=List[SsoAttributeRequirement])
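For illustration (not part of the diff): a provider mapping of roughly the following shape, after YAML parsing, is what the reading code above turns into an `OidcProviderConfig`. All identifiers and URLs here are placeholders, and only a few of the attributes listed above are shown; the rest fall back to the defaults read via `oidc_config.get(...)`.

```python
# Hypothetical provider entry; values are placeholders, not project defaults.
example_oidc_provider = {
    "idp_id": "example",
    "idp_name": "Example IdP",
    "issuer": "https://accounts.example.com/",
    "client_id": "synapse-client-id",
    "client_secret": "secret",
    "client_auth_method": "client_secret_basic",
    "scopes": ["openid", "profile"],
    # With discovery enabled, the authorization/token/userinfo endpoints and
    # jwks_uri can be omitted and are fetched from the issuer instead.
    "discover": True,
}
```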
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 4fda8ae987..85d07c4f8f 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config):
providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
providers.extend(config.get("password_providers") or [])
- for provider in providers:
+ for i, provider in enumerate(providers):
mod_name = provider["module"]
# This is for backwards compat when the ldap auth provider resided
@@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config):
mod_name = LDAP_PROVIDER
(provider_class, provider_config) = load_module(
- {"module": mod_name, "config": provider["config"]}
+ {"module": mod_name, "config": provider["config"]},
+ ("password_providers", "<item %i>" % i),
)
self.password_providers.append((provider_class, provider_config))
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 3adbfb73e6..7831a2ef79 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config, ShardedWorkerHandlingConfig
+from ._base import Config
class PushConfig(Config):
@@ -27,9 +27,6 @@ class PushConfig(Config):
"group_unread_count_by_room", True
)
- pusher_instances = config.get("pusher_instances") or []
- self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
-
# There was a a 'redact_content' setting but mistakenly read from the
# 'email'section'. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade.
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4fca5b6d96..19322372a9 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -24,7 +24,7 @@ class RateLimitConfig:
defaults={"per_second": 0.17, "burst_count": 3.0},
):
self.per_second = config.get("per_second", defaults["per_second"])
- self.burst_count = config.get("burst_count", defaults["burst_count"])
+ self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
class FederationRateLimitConfig:
@@ -98,11 +98,35 @@ class RatelimitConfig(Config):
self.rc_joins_local = RateLimitConfig(
config.get("rc_joins", {}).get("local", {}),
- defaults={"per_second": 0.1, "burst_count": 3},
+ defaults={"per_second": 0.1, "burst_count": 10},
)
self.rc_joins_remote = RateLimitConfig(
config.get("rc_joins", {}).get("remote", {}),
- defaults={"per_second": 0.01, "burst_count": 3},
+ defaults={"per_second": 0.01, "burst_count": 10},
+ )
+
+ # Ratelimit cross-user key requests:
+ # * For local requests this is keyed by the sending device.
+ # * For requests received over federation this is keyed by the origin.
+ #
+ # Note that this isn't exposed in the configuration as it is obscure.
+ self.rc_key_requests = RateLimitConfig(
+ config.get("rc_key_requests", {}),
+ defaults={"per_second": 20, "burst_count": 100},
+ )
+
+ self.rc_3pid_validation = RateLimitConfig(
+ config.get("rc_3pid_validation") or {},
+ defaults={"per_second": 0.003, "burst_count": 5},
+ )
+
+ self.rc_invites_per_room = RateLimitConfig(
+ config.get("rc_invites", {}).get("per_room", {}),
+ defaults={"per_second": 0.3, "burst_count": 10},
+ )
+ self.rc_invites_per_user = RateLimitConfig(
+ config.get("rc_invites", {}).get("per_user", {}),
+ defaults={"per_second": 0.003, "burst_count": 5},
)
def generate_config_section(self, **kwargs):
@@ -136,6 +160,9 @@ class RatelimitConfig(Config):
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
+ # - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+ # - two for ratelimiting how often invites can be sent in a room or to a
+ # specific user.
#
# The defaults are as shown below.
#
@@ -169,11 +196,22 @@ class RatelimitConfig(Config):
#rc_joins:
# local:
# per_second: 0.1
- # burst_count: 3
+ # burst_count: 10
# remote:
# per_second: 0.01
- # burst_count: 3
-
+ # burst_count: 10
+ #
+ #rc_3pid_validation:
+ # per_second: 0.003
+ # burst_count: 5
+ #
+ #rc_invites:
+ # per_room:
+ # per_second: 0.3
+ # burst_count: 10
+ # per_user:
+ # per_second: 0.003
+ # burst_count: 5
# Ratelimiting settings for incoming federation
#
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 770b921aff..b49e6609ce 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,24 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distutils.util import strtobool
-
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias, UserID
-from synapse.util.stringutils import random_string_with_symbols
+from synapse.util.stringutils import random_string_with_symbols, strtobool
class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs):
- self.enable_registration = bool(
- strtobool(str(config.get("enable_registration", False)))
+ self.enable_registration = strtobool(
+ str(config.get("enable_registration", False))
)
if "disable_registration" in config:
- self.enable_registration = not bool(
- strtobool(str(config["disable_registration"]))
+ self.enable_registration = not strtobool(
+ str(config["disable_registration"])
)
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
@@ -157,9 +155,7 @@ class RegistrationConfig(Config):
self.session_lifetime = session_lifetime
# The success template used during fallback auth.
- self.fallback_success_template = self.read_templates(
- ["auth_success.html"], autoescape=True
- )[0]
+ self.fallback_success_template = self.read_template("auth_success.html")
self.bind_new_user_emails_to_sydent = config.get(
"bind_new_user_emails_to_sydent"
@@ -174,8 +170,8 @@ class RegistrationConfig(Config):
)
# Remove trailing slashes
- self.bind_new_user_emails_to_sydent = self.bind_new_user_emails_to_sydent.strip(
- "/"
+ self.bind_new_user_emails_to_sydent = (
+ self.bind_new_user_emails_to_sydent.strip("/")
)
def generate_config_section(self, generate_secrets=False, **kwargs):
@@ -369,6 +365,8 @@ class RegistrationConfig(Config):
# By default, any room aliases included in this list will be created
# as a publicly joinable room when the first user registers for the
# homeserver. This behaviour can be customised with the settings below.
+ # If the room already exists, make certain it is a publicly joinable
+ # room. The join rule of the room must be set to 'public'.
#
#auto_join_rooms:
# - "#example:example.com"
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index bc02754a33..6cade50acc 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -17,6 +17,7 @@ import os
from collections import namedtuple
from typing import Dict, List
+from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
@@ -51,7 +52,7 @@ MediaStorageProviderConfig = namedtuple(
def parse_thumbnail_requirements(thumbnail_sizes):
- """ Takes a list of dictionaries with "width", "height", and "method" keys
+ """Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumbnailing
method, and thumbnail media type to precalculate
@@ -148,7 +149,7 @@ class ContentRepositoryConfig(Config):
# them to be started.
self.media_storage_providers = [] # type: List[tuple]
- for provider_config in storage_providers:
+ for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend
if provider_config["module"] == "file_system":
@@ -157,7 +158,9 @@ class ContentRepositoryConfig(Config):
".FileStorageProviderBackend"
)
- provider_class, parsed_config = load_module(provider_config)
+ provider_class, parsed_config = load_module(
+ provider_config, ("media_storage_providers", "<item %i>" % i)
+ )
wrapper_config = MediaStorageProviderConfig(
provider_config.get("store_local", False),
@@ -179,7 +182,9 @@ class ContentRepositoryConfig(Config):
check_requirements("url_preview")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
if "url_preview_ip_range_blacklist" not in config:
raise ConfigError(
@@ -188,19 +193,17 @@ class ContentRepositoryConfig(Config):
"to work"
)
- # netaddr is a dependency for url_preview
- from netaddr import IPSet
-
- self.url_preview_ip_range_blacklist = IPSet(
- config["url_preview_ip_range_blacklist"]
- )
-
# we always blacklist '0.0.0.0' and '::', which are supposed to be
# unroutable addresses.
- self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"])
+ self.url_preview_ip_range_blacklist = generate_ip_set(
+ config["url_preview_ip_range_blacklist"],
+ ["0.0.0.0", "::"],
+ config_path=("url_preview_ip_range_blacklist",),
+ )
- self.url_preview_ip_range_whitelist = IPSet(
- config.get("url_preview_ip_range_whitelist", ())
+ self.url_preview_ip_range_whitelist = generate_ip_set(
+ config.get("url_preview_ip_range_whitelist", ()),
+ config_path=("url_preview_ip_range_whitelist",),
)
self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
@@ -211,7 +214,6 @@ class ContentRepositoryConfig(Config):
def generate_config_section(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store")
- uploads_path = os.path.join(data_dir_path, "uploads")
formatted_thumbnail_sizes = "".join(
THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES
@@ -219,6 +221,10 @@ class ContentRepositoryConfig(Config):
# strip final NL
formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
+ ip_range_blacklist = "\n".join(
+ " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
+ )
+
return (
r"""
## Media Store ##
@@ -313,15 +319,7 @@ class ContentRepositoryConfig(Config):
# you uncomment the following list as a starting point.
#
#url_preview_ip_range_blacklist:
- # - '127.0.0.0/8'
- # - '10.0.0.0/8'
- # - '172.16.0.0/12'
- # - '192.168.0.0/16'
- # - '100.64.0.0/10'
- # - '169.254.0.0/16'
- # - '::1/128'
- # - 'fe80::/64'
- # - 'fc00::/7'
+%(ip_range_blacklist)s
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 92e1b67528..2dd719c388 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -123,7 +123,7 @@ class RoomDirectoryConfig(Config):
alias (str)
Returns:
- boolean: True if user is allowed to crate the alias
+ boolean: True if user is allowed to create the alias
"""
for rule in self._alias_creation_rules:
if rule.matches(user_id, room_id, [alias]):
@@ -180,7 +180,7 @@ class _RoomDirectoryRule:
self._alias_regex = glob_to_regex(alias)
self._room_id_regex = glob_to_regex(room_id)
except Exception as e:
- raise ConfigError("Failed to parse glob into regex: %s", e)
+ raise ConfigError("Failed to parse glob into regex") from e
def matches(self, user_id, room_id, aliases):
"""Tests if this rule matches the given user_id, room_id and aliases.
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c1b8e98ae0..6db9cb5ced 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -17,8 +17,7 @@
import logging
from typing import Any, List
-import attr
-
+from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@@ -77,7 +76,9 @@ class SAML2Config(Config):
try:
check_requirements("saml2")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
self.saml2_enabled = True
@@ -125,7 +126,7 @@ class SAML2Config(Config):
(
self.saml2_user_mapping_provider_class,
self.saml2_user_mapping_provider_config,
- ) = load_module(ump_dict)
+ ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
# Note parse_config() is already checked during the call to load_module
@@ -196,8 +197,8 @@ class SAML2Config(Config):
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
- metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
- response_url = public_baseurl + "_matrix/saml2/authn_response"
+ metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
+ response_url = public_baseurl + "_synapse/client/saml2/authn_response"
return {
"entityid": metadata_url,
"service": {
@@ -235,10 +236,10 @@ class SAML2Config(Config):
# enable SAML login.
#
# Once SAML support is enabled, a metadata file will be exposed at
- # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
+ # https://<server>:<port>/_synapse/client/saml2/metadata.xml, which you may be able to
# use to configure your SAML IdP with. Alternatively, you can manually configure
# the IdP to use an ACS location of
- # https://<server>:<port>/_matrix/saml2/authn_response.
+ # https://<server>:<port>/_synapse/client/saml2/authn_response.
#
saml2_config:
# `sp_config` is the configuration for the pysaml2 Service Provider.
@@ -398,32 +399,18 @@ class SAML2Config(Config):
}
-@attr.s(frozen=True)
-class SamlAttributeRequirement:
- """Object describing a single requirement for SAML attributes."""
-
- attribute = attr.ib(type=str)
- value = attr.ib(type=str)
-
- JSON_SCHEMA = {
- "type": "object",
- "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
- "required": ["attribute", "value"],
- }
-
-
ATTRIBUTE_REQUIREMENTS_SCHEMA = {
"type": "array",
- "items": SamlAttributeRequirement.JSON_SCHEMA,
+ "items": SsoAttributeRequirement.JSON_SCHEMA,
}
def _parse_attribute_requirements_def(
attribute_requirements: Any,
-) -> List[SamlAttributeRequirement]:
+) -> List[SsoAttributeRequirement]:
validate_config(
ATTRIBUTE_REQUIREMENTS_SCHEMA,
attribute_requirements,
- config_path=["saml2_config", "attribute_requirements"],
+ config_path=("saml2_config", "attribute_requirements"),
)
- return [SamlAttributeRequirement(**x) for x in attribute_requirements]
+ return [SsoAttributeRequirement(**x) for x in attribute_requirements]
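A short illustration (attribute names are made up) of the list shape accepted by `_parse_attribute_requirements_def` now that entries are validated against `SsoAttributeRequirement.JSON_SCHEMA`:

```python
from synapse.config.saml2_config import _parse_attribute_requirements_def

# Hypothetical requirements; each entry must carry both "attribute" and
# "value", as required by SsoAttributeRequirement.JSON_SCHEMA.
requirements = _parse_attribute_requirements_def(
    [
        {"attribute": "userGroup", "value": "staff"},
        {"attribute": "department", "value": "sales"},
    ]
)
# requirements is a list of SsoAttributeRequirement instances.
```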
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 61335595ff..c8b1a25004 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
import os.path
import re
@@ -23,9 +24,10 @@ from typing import Any, Dict, Iterable, List, Optional, Set
import attr
import yaml
+from netaddr import AddrFormatError, IPNetwork, IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.http.endpoint import parse_and_validate_server_name
+from synapse.util.stringutils import parse_and_validate_server_name
from ._base import Config, ConfigError
@@ -39,6 +41,107 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
+
+def _6to4(network: IPNetwork) -> IPNetwork:
+ """Convert an IPv4 network into a 6to4 IPv6 network per RFC 3056."""
+
+ # 6to4 networks consist of:
+ # * 2002 as the first 16 bits
+ # * The first IPv4 address in the network hex-encoded as the next 32 bits
+ # * The new prefix length needs to include the bits from the 2002 prefix.
+ hex_network = hex(network.first)[2:]
+ hex_network = ("0" * (8 - len(hex_network))) + hex_network
+ return IPNetwork(
+ "2002:%s:%s::/%d"
+ % (
+ hex_network[:4],
+ hex_network[4:],
+ 16 + network.prefixlen,
+ )
+ )
+
+
+def generate_ip_set(
+ ip_addresses: Optional[Iterable[str]],
+ extra_addresses: Optional[Iterable[str]] = None,
+ config_path: Optional[Iterable[str]] = None,
+) -> IPSet:
+ """
+ Generate an IPSet from a list of IP addresses or CIDRs.
+
+ Additionally, for each IPv4 network in the list of IP addresses, also
+ includes the corresponding IPv6 networks.
+
+ This includes:
+
+ * IPv4-Compatible IPv6 Address (see RFC 4291, section 2.5.5.1)
+ * IPv4-Mapped IPv6 Address (see RFC 4291, section 2.5.5.2)
+ * 6to4 Address (see RFC 3056, section 2)
+
+ Args:
+ ip_addresses: An iterable of IP addresses or CIDRs.
+ extra_addresses: An iterable of IP addresses or CIDRs.
+ config_path: The path in the configuration for error messages.
+
+ Returns:
+ A new IP set.
+ """
+ result = IPSet()
+ for ip in itertools.chain(ip_addresses or (), extra_addresses or ()):
+ try:
+ network = IPNetwork(ip)
+ except AddrFormatError as e:
+ raise ConfigError(
+ "Invalid IP range provided: %s." % (ip,), config_path
+ ) from e
+ result.add(network)
+
+ # It is possible that these already exist in the set, but that's OK.
+ if ":" not in str(network):
+ result.add(IPNetwork(network).ipv6(ipv4_compatible=True))
+ result.add(IPNetwork(network).ipv6(ipv4_compatible=False))
+ result.add(_6to4(network))
+
+ return result
+
+
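As a worked example of the expansion performed by `generate_ip_set` (using `netaddr` directly, outside the diff): a single private IPv4 range contributes an IPv4-compatible, an IPv4-mapped and a 6to4 IPv6 network in addition to itself.

```python
from netaddr import IPNetwork

v4 = IPNetwork("192.168.0.0/16")

# IPv4-compatible and IPv4-mapped forms (RFC 4291, sections 2.5.5.1/2.5.5.2):
compat = v4.ipv6(ipv4_compatible=True)   # ::192.168.0.0/112
mapped = v4.ipv6(ipv4_compatible=False)  # ::ffff:192.168.0.0/112

# 6to4 form (RFC 3056): 2002 prefix + first address in hex, prefix length + 16.
hex_net = format(v4.first, "08x")        # "c0a80000"
sixto4 = IPNetwork(
    "2002:%s:%s::/%d" % (hex_net[:4], hex_net[4:], 16 + v4.prefixlen)
)                                        # 2002:c0a8::/32
```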
+# IP ranges that are considered private / unroutable / don't make sense.
+DEFAULT_IP_RANGE_BLACKLIST = [
+ # Localhost
+ "127.0.0.0/8",
+ # Private networks.
+ "10.0.0.0/8",
+ "172.16.0.0/12",
+ "192.168.0.0/16",
+ # Carrier grade NAT.
+ "100.64.0.0/10",
+ # Address registry.
+ "192.0.0.0/24",
+ # Link-local networks.
+ "169.254.0.0/16",
+ # Formerly used for 6to4 relay.
+ "192.88.99.0/24",
+ # Testing networks.
+ "198.18.0.0/15",
+ "192.0.2.0/24",
+ "198.51.100.0/24",
+ "203.0.113.0/24",
+ # Multicast.
+ "224.0.0.0/4",
+ # Localhost
+ "::1/128",
+ # Link-local addresses.
+ "fe80::/10",
+ # Unique local addresses.
+ "fc00::/7",
+ # Testing networks.
+ "2001:db8::/32",
+ # Multicast.
+ "ff00::/8",
+ # Site-local addresses
+ "fec0::/10",
+]
+
DEFAULT_ROOM_VERSION = "6"
ROOM_COMPLEXITY_TOO_GREAT = (
@@ -156,7 +259,14 @@ class ServerConfig(Config):
# Whether to require sharing a room with a user to retrieve their
# profile data
self.limit_profile_requests_to_users_who_share_rooms = config.get(
- "limit_profile_requests_to_users_who_share_rooms", False,
+ "limit_profile_requests_to_users_who_share_rooms",
+ False,
+ )
+
+ # Whether to retrieve and display profile data for a user when they
+ # are invited to a room
+ self.include_profile_data_on_invite = config.get(
+ "include_profile_data_on_invite", True
)
if "restrict_public_rooms_to_local_users" in config and (
@@ -256,10 +366,37 @@ class ServerConfig(Config):
# due to resource constraints
self.admin_contact = config.get("admin_contact", None)
+ ip_range_blacklist = config.get(
+ "ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST
+ )
+
+ # Attempt to create an IPSet from the given ranges
+
+ # Always blacklist 0.0.0.0, ::
+ self.ip_range_blacklist = generate_ip_set(
+ ip_range_blacklist, ["0.0.0.0", "::"], config_path=("ip_range_blacklist",)
+ )
+
+ self.ip_range_whitelist = generate_ip_set(
+ config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",)
+ )
+
+ # The federation_ip_range_blacklist is used for backwards-compatibility
+ # and only applies to federation and identity servers. If it is not given,
+ # default to ip_range_blacklist.
+ federation_ip_range_blacklist = config.get(
+ "federation_ip_range_blacklist", ip_range_blacklist
+ )
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist = generate_ip_set(
+ federation_ip_range_blacklist,
+ ["0.0.0.0", "::"],
+ config_path=("federation_ip_range_blacklist",),
+ )
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
- self.start_pushers = config.get("start_pushers", True)
# (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before
@@ -494,7 +631,9 @@ class ServerConfig(Config):
if manhole:
self.listeners.append(
ListenerConfig(
- port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+ port=manhole,
+ bind_addresses=["127.0.0.1"],
+ type="manhole",
)
)
@@ -530,7 +669,8 @@ class ServerConfig(Config):
# and letting the client know which email address is bound to an account and
# which one isn't.
self.request_token_inhibit_3pid_errors = config.get(
- "request_token_inhibit_3pid_errors", False,
+ "request_token_inhibit_3pid_errors",
+ False,
)
# List of users trialing the new experimental default push rules. This setting is
@@ -567,6 +707,10 @@ class ServerConfig(Config):
def generate_config_section(
self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
):
+ ip_range_blacklist = "\n".join(
+ " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
+ )
+
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
@@ -681,11 +825,12 @@ class ServerConfig(Config):
#
#web_client_location: https://riot.example.com/
- # The public-facing base URL that clients use to access this HS
- # (not including _matrix/...). This is the same URL a user would
- # enter into the 'custom HS URL' field on their client. If you
- # use synapse with a reverse proxy, this should be the URL to reach
- # synapse via the proxy.
+ # The public-facing base URL that clients use to access this Homeserver (not
+ # including _matrix/...). This is the same URL a user might enter into the
+ # 'Custom Homeserver URL' field on their client. If you use Synapse with a
+ # reverse proxy, this should be the URL to reach Synapse via the proxy.
+ # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
+ # 'listeners' below).
#
#public_baseurl: https://example.com/
@@ -702,8 +847,7 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
- # API, so this setting is of limited value if federation is enabled on
- # the server.
+ # API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true
@@ -714,6 +858,14 @@ class ServerConfig(Config):
#
#limit_profile_requests_to_users_who_share_rooms: true
+ # Uncomment to prevent a user's profile data from being retrieved and
+ # displayed in a room until they have joined it. By default, a user's
+ # profile data is included in an invite event, regardless of the values
+ # of the above two settings, and whether or not the users share a server.
+ # Defaults to 'true'.
+ #
+ #include_profile_data_on_invite: false
+
# If set to 'true', removes the need for authentication to access the server's
# public rooms directory through the client API, meaning that anyone can
# query the room directory. Defaults to 'false'.
@@ -758,6 +910,33 @@ class ServerConfig(Config):
#
#enable_search: false
+ # Prevent outgoing requests from being sent to the following blacklisted IP address
+ # CIDR ranges. If this option is not specified then it defaults to private IP
+ # address ranges (see the example below).
+ #
+ # The blacklist applies to the outbound requests for federation, identity servers,
+ # push servers, and for checking key validity for third-party invite events.
+ #
+ # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+ # listed here, since they correspond to unroutable addresses.)
+ #
+ # This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
+ #
+ #ip_range_blacklist:
+%(ip_range_blacklist)s
+
+ # List of IP address CIDR ranges that should be allowed for federation,
+ # identity servers, push servers, and for checking key validity for
+ # third-party invite events. This is useful for specifying exceptions to
+ # wide-ranging blacklisted target IP ranges - e.g. for communication with
+ # a push server only visible in your network.
+ #
+ # This whitelist overrides ip_range_blacklist and defaults to an empty
+ # list.
+ #
+ #ip_range_whitelist:
+ # - '192.168.1.1'
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 3d067d29db..3d05abc158 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -33,13 +33,14 @@ class SpamCheckerConfig(Config):
# spam checker, and thus was simply a dictionary with module
# and config keys. Support this old behaviour by checking
# to see if the option resolves to a dictionary
- self.spam_checkers.append(load_module(spam_checkers))
+ self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",)))
elif isinstance(spam_checkers, list):
- for spam_checker in spam_checkers:
+ for i, spam_checker in enumerate(spam_checkers):
+ config_path = ("spam_checker", "<item %i>" % i)
if not isinstance(spam_checker, dict):
- raise ConfigError("spam_checker syntax is incorrect")
+ raise ConfigError("expected a mapping", config_path)
- self.spam_checkers.append(load_module(spam_checker))
+ self.spam_checkers.append(load_module(spam_checker, config_path))
else:
raise ConfigError("spam_checker syntax is incorrect")
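For reference, the two accepted shapes of the `spam_checker` option (module paths are made up); the list form is what gains per-item config paths in the change above:

```python
# Legacy form: a single mapping.
spam_checker = {"module": "my_module.MySpamChecker", "config": {}}

# Current form: a list of mappings; errors are now reported against
# ("spam_checker", "<item 0>"), ("spam_checker", "<item 1>"), and so on.
spam_checker = [
    {"module": "my_module.MySpamChecker", "config": {}},
    {"module": "other_module.OtherChecker", "config": {"threshold": 3}},
]
```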
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 4427676167..243cc681e8 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -12,14 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+
+import attr
from ._base import Config
+@attr.s(frozen=True)
+class SsoAttributeRequirement:
+ """Object describing a single requirement for SSO attributes."""
+
+ attribute = attr.ib(type=str)
+ # If a value is not given, then the attribute must simply exist.
+ value = attr.ib(type=Optional[str])
+
+ JSON_SCHEMA = {
+ "type": "object",
+ "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
+ "required": ["attribute", "value"],
+ }
+
+
class SSOConfig(Config):
- """SSO Configuration
- """
+ """SSO Configuration"""
section = "sso"
@@ -27,24 +43,28 @@ class SSOConfig(Config):
sso_config = config.get("sso") or {} # type: Dict[str, Any]
# The sso-specific template_dir
- template_dir = sso_config.get("template_dir")
+ self.sso_template_dir = sso_config.get("template_dir")
# Read templates from disk
(
+ self.sso_login_idp_picker_template,
self.sso_redirect_confirm_template,
self.sso_auth_confirm_template,
self.sso_error_template,
sso_account_deactivated_template,
sso_auth_success_template,
+ self.sso_auth_bad_user_template,
) = self.read_templates(
[
+ "sso_login_idp_picker.html",
"sso_redirect_confirm.html",
"sso_auth_confirm.html",
"sso_error.html",
"sso_account_deactivated.html",
"sso_auth_success.html",
+ "sso_auth_bad_user.html",
],
- template_dir,
+ self.sso_template_dir,
)
# These templates have no placeholders, so render them here
@@ -93,41 +113,135 @@ class SSOConfig(Config):
# - https://my.custom.client/
# Directory in which Synapse will try to find the template files below.
- # If not set, default templates from within the Synapse package will be used.
- #
- # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
- # If you *do* uncomment it, you will need to make sure that all the templates
- # below are in the directory.
+ # If not set, or the files named below are not found within the template
+ # directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
+ # * HTML page to prompt the user to choose an Identity Provider during
+ # login: 'sso_login_idp_picker.html'.
+ #
+ # This is only used if multiple SSO Identity Providers are configured.
+ #
+ # When rendering, this template is given the following variables:
+ # * redirect_url: the URL that the user will be redirected to after
+ # login.
+ #
+ # * server_name: the homeserver's name.
+ #
+ # * providers: a list of available Identity Providers. Each element is
+ # an object with the following attributes:
+ #
+ # * idp_id: unique identifier for the IdP
+ # * idp_name: user-facing name for the IdP
+ # * idp_icon: if specified in the IdP config, an MXC URI for an icon
+ # for the IdP
+ # * idp_brand: if specified in the IdP config, a textual identifier
+ # for the brand of the IdP
+ #
+ # The rendered HTML page should contain a form which submits its results
+ # back as a GET request, with the following query parameters:
+ #
+ # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+ # to the template)
+ #
+ # * idp: the 'idp_id' of the chosen IDP.
+ #
+ # * HTML page to prompt new users to enter a userid and confirm other
+ # details: 'sso_auth_account_details.html'. This is only shown if the
+ # SSO implementation (with any user_mapping_provider) does not return
+ # a localpart.
+ #
+ # When rendering, this template is given the following variables:
+ #
+ # * server_name: the homeserver's name.
+ #
+ # * idp: details of the SSO Identity Provider that the user logged in
+ # with: an object with the following attributes:
+ #
+ # * idp_id: unique identifier for the IdP
+ # * idp_name: user-facing name for the IdP
+ # * idp_icon: if specified in the IdP config, an MXC URI for an icon
+ # for the IdP
+ # * idp_brand: if specified in the IdP config, a textual identifier
+ # for the brand of the IdP
+ #
+ # * user_attributes: an object containing details about the user that
+ # we received from the IdP. May have the following attributes:
+ #
+ # * display_name: the user's display_name
+ # * emails: a list of email addresses
+ #
+ # The template should render a form which submits the following fields:
+ #
+ # * username: the localpart of the user's chosen user id
+ #
+ # * HTML page allowing the user to consent to the server's terms and
+ # conditions. This is only shown for new users, and only if
+ # `user_consent.require_at_registration` is set.
+ #
+ # When rendering, this template is given the following variables:
+ #
+ # * server_name: the homeserver's name.
+ #
+ # * user_id: the user's proposed Matrix ID.
+ #
+ # * user_profile.display_name: the user's proposed display name, if any.
+ #
+ # * consent_version: the version of the terms that the user will be
+ # shown
+ #
+ # * terms_url: a link to the page showing the terms.
+ #
+ # The template should render a form which submits the following fields:
+ #
+ # * accepted_version: the version of the terms accepted by the user
+ # (ie, 'consent_version' from the input variables).
+ #
# * HTML page for a confirmation step before redirecting back to the client
# with the login token: 'sso_redirect_confirm.html'.
#
- # When rendering, this template is given three variables:
- # * redirect_url: the URL the user is about to be redirected to. Needs
- # manual escaping (see
- # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ # When rendering, this template is given the following variables:
+ #
+ # * redirect_url: the URL the user is about to be redirected to.
#
# * display_url: the same as `redirect_url`, but with the query
# parameters stripped. The intention is to have a
# human-readable URL to show to users, not to use it as
- # the final address to redirect to. Needs manual escaping
- # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ # the final address to redirect to.
#
# * server_name: the homeserver's name.
#
+ # * new_user: a boolean indicating whether this is the user's first time
+ # logging in.
+ #
+ # * user_id: the user's matrix ID.
+ #
+ # * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
+ # None if the user has not set an avatar.
+ #
+ # * user_profile.display_name: the user's display name. None if the user
+ # has not set a display name.
+ #
# * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'.
#
# When rendering, this template is given the following variables:
- # * redirect_url: the URL the user is about to be redirected to. Needs
- # manual escaping (see
- # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ # * redirect_url: the URL the user is about to be redirected to.
#
# * description: the operation which the user is being asked to confirm
#
+ # * idp: details of the Identity Provider that we will use to confirm
+ # the user's identity: an object with the following attributes:
+ #
+ # * idp_id: unique identifier for the IdP
+ # * idp_name: user-facing name for the IdP
+ # * idp_icon: if specified in the IdP config, an MXC URI for an icon
+ # for the IdP
+ # * idp_brand: if specified in the IdP config, a textual identifier
+ # for the brand of the IdP
+ #
# * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'.
#
@@ -136,6 +250,14 @@ class SSOConfig(Config):
#
# This template has no additional variables.
#
+ # * HTML page shown after a user-interactive authentication session which
+ # does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+ #
+ # When rendering, this template is given the following variables:
+ # * server_name: the homeserver's name.
+ # * user_id_to_verify: the MXID of the user that we are trying to
+ # validate.
+ #
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index b559bfa411..2258329a52 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -13,10 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
+import logging
from ._base import Config
+ROOM_STATS_DISABLED_WARN = """\
+WARNING: room/user statistics have been disabled via the stats.enabled
+configuration setting. This means that certain features (such as the room
+directory) will not operate correctly. Future versions of Synapse may ignore
+this setting.
+
+To fix this warning, remove the stats.enabled setting from your configuration
+file.
+--------------------------------------------------------------------------------"""
+
+logger = logging.getLogger(__name__)
+
class StatsConfig(Config):
"""Stats Configuration
@@ -28,30 +40,29 @@ class StatsConfig(Config):
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000
- self.stats_retention = sys.maxsize
stats_config = config.get("stats", None)
if stats_config:
self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
self.stats_bucket_size = self.parse_duration(
stats_config.get("bucket_size", "1d")
)
- self.stats_retention = self.parse_duration(
- stats_config.get("retention", "%ds" % (sys.maxsize,))
- )
+ if not self.stats_enabled:
+ logger.warning(ROOM_STATS_DISABLED_WARN)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
- # Local statistics collection. Used in populating the room directory.
+ # Settings for local room and user statistics collection. See
+ # docs/room_and_user_statistics.md.
#
- # 'bucket_size' controls how large each statistics timeslice is. It can
- # be defined in a human readable short form -- e.g. "1d", "1y".
- #
- # 'retention' controls how long historical statistics will be kept for.
- # It can be defined in a human readable short form -- e.g. "1d", "1y".
- #
- #
- #stats:
- # enabled: true
- # bucket_size: 1d
- # retention: 1y
+ stats:
+ # Uncomment the following to disable room and user statistics. Note that doing
+ # so may cause certain features (such as the room directory) not to work
+ # correctly.
+ #
+ #enabled: false
+
+ # The size of each timeslice in the room_stats_historical and
+ # user_stats_historical tables, as a time period. Defaults to "1d".
+ #
+ #bucket_size: 1h
"""
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index 10a99c792e..c04e1c4e07 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config):
provider = config.get("third_party_event_rules", None)
if provider is not None:
- self.third_party_event_rules = load_module(provider)
+ self.third_party_event_rules = load_module(
+ provider, ("third_party_event_rules",)
+ )
def generate_config_section(self, **kwargs):
return """\
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 0c1a854f09..727a1e7008 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -39,7 +39,9 @@ class TracerConfig(Config):
try:
check_requirements("opentracing")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
# The tracer is enabled so sanitize the config
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 306e0cc8a4..ab7770e472 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -24,50 +24,55 @@ class UserDirectoryConfig(Config):
section = "userdirectory"
def read_config(self, config, **kwargs):
- self.user_directory_search_enabled = True
- self.user_directory_search_all_users = False
- self.user_directory_defer_to_id_server = None
- self.user_directory_search_prefer_local_users = False
- user_directory_config = config.get("user_directory", None)
- if user_directory_config:
- self.user_directory_search_enabled = user_directory_config.get(
- "enabled", True
- )
- self.user_directory_search_all_users = user_directory_config.get(
- "search_all_users", False
- )
- self.user_directory_defer_to_id_server = user_directory_config.get(
- "defer_to_id_server", None
- )
- self.user_directory_search_prefer_local_users = user_directory_config.get(
- "prefer_local_users", False
- )
+ user_directory_config = config.get("user_directory") or {}
+ self.user_directory_search_enabled = user_directory_config.get("enabled", True)
+ self.user_directory_search_all_users = user_directory_config.get(
+ "search_all_users", False
+ )
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
+ self.user_directory_search_prefer_local_users = user_directory_config.get(
+ "prefer_local_users", False
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# User Directory configuration
#
- # 'enabled' defines whether users can search the user directory. If
- # false then empty responses are returned to all queries. Defaults to
- # true.
- #
- # 'search_all_users' defines whether to search all users visible to your HS
- # when searching the user directory, rather than limiting to users visible
- # in public rooms. Defaults to false. If you set it True, you'll have to
- # rebuild the user_directory search indexes, see
- # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
- #
- # 'prefer_local_users' defines whether to prioritise local users in
- # search query results. If True, local users are more likely to appear above
- # remote users when searching the user directory. Defaults to false.
- #
- #user_directory:
- # enabled: true
- # search_all_users: false
- # prefer_local_users: false
- #
- # # If this is set, user search will be delegated to this ID server instead
- # # of synapse performing the search itself.
- # # This is an experimental API.
- # defer_to_id_server: https://id.example.com
+ user_directory:
+ # Defines whether users can search the user directory. If false then
+ # empty responses are returned to all queries. Defaults to true.
+ #
+ # Uncomment to disable the user directory.
+ #
+ #enabled: false
+
+ # Defines whether to search all users visible to your HS when searching
+ # the user directory, rather than limiting to users visible in public
+ # rooms. Defaults to false.
+ #
+ # If you set it true, you'll have to rebuild the user_directory search
+ # indexes, see:
+ # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
+ #
+ # Uncomment to return search results containing all known users, even if that
+ # user does not share a room with the requester.
+ #
+ #search_all_users: true
+
+ # If this is set, user search will be delegated to this ID server instead
+ # of Synapse performing the search itself.
+ # This is an experimental API.
+ #
+ #defer_to_id_server: https://id.example.com
+
+ # Defines whether to prefer local users in search query results.
+ # If True, local users are more likely to appear above remote users
+ # when searching the user directory. Defaults to false.
+ #
+ # Uncomment to prefer local over remote users in user directory search
+ # results.
+ #
+ #prefer_local_users: true
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 57ab097eba..ac92375a85 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -17,9 +17,28 @@ from typing import List, Union
import attr
-from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
+from ._base import (
+ Config,
+ ConfigError,
+ RoutableShardedWorkerHandlingConfig,
+ ShardedWorkerHandlingConfig,
+)
from .server import ListenerConfig, parse_listener_def
+_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """
+The send_federation config option must be disabled in the main
+synapse process before federation senders can be run in a separate worker.
+
+Please add ``send_federation: false`` to the main config
+"""
+
+_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """
+The start_pushers config option must be disabled in the main
+synapse process before pushers can be run in a separate worker.
+
+Please add ``start_pushers: false`` to the main config
+"""
+
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config
@@ -33,8 +52,7 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
@attr.s
class InstanceLocationConfig:
- """The host and port to talk to an instance via HTTP replication.
- """
+ """The host and port to talk to an instance via HTTP replication."""
host = attr.ib(type=str)
port = attr.ib(type=int)
@@ -53,6 +71,21 @@ class WriterLocations:
default=["master"], type=List[str], converter=_instance_to_list_converter
)
typing = attr.ib(default="master", type=str)
+ to_device = attr.ib(
+ default=["master"],
+ type=List[str],
+ converter=_instance_to_list_converter,
+ )
+ account_data = attr.ib(
+ default=["master"],
+ type=List[str],
+ converter=_instance_to_list_converter,
+ )
+ receipts = attr.ib(
+ default=["master"],
+ type=List[str],
+ converter=_instance_to_list_converter,
+ )
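For context, a hedged sketch of the shared worker configuration that feeds `WriterLocations`; the `stream_writers` and `instance_map` keys are not shown in this diff, and the worker names are invented. The new `to_device`, `account_data` and `receipts` streams each take exactly one writer, as enforced later in this file.

```python
# Hypothetical shared config (after YAML parsing). `events` may list several
# writers; the new streams must each name exactly one. Every writer other
# than "master" must also appear in instance_map.
config = {
    "stream_writers": {
        "events": ["event_persister1", "event_persister2"],
        "typing": "worker1",
        "to_device": "worker1",
        "account_data": "worker1",
        "receipts": "worker1",
    },
    "instance_map": {
        "event_persister1": {"host": "localhost", "port": 8034},
        "event_persister2": {"host": "localhost", "port": 8035},
        "worker1": {"host": "localhost", "port": 8036},
    },
}
```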
class WorkerConfig(Config):
@@ -85,7 +118,11 @@ class WorkerConfig(Config):
# The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port")
+ # The shared secret used for authentication when connecting to the main synapse.
+ self.worker_replication_secret = config.get("worker_replication_secret", None)
+
self.worker_name = config.get("worker_name", self.worker_app)
+ self.instance_name = self.worker_name or "master"
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@@ -95,16 +132,47 @@ class WorkerConfig(Config):
if manhole:
self.worker_listeners.append(
ListenerConfig(
- port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+ port=manhole,
+ bind_addresses=["127.0.0.1"],
+ type="manhole",
)
)
- # Whether to send federation traffic out in this process. This only
- # applies to some federation traffic, and so shouldn't be used to
- # "disable" federation
- self.send_federation = config.get("send_federation", True)
+ # Handle federation sender configuration.
+ #
+ # There are two ways of configuring which instances handle federation
+ # sending:
+ # 1. The old way where "send_federation" is set to false and running a
+ # `synapse.app.federation_sender` worker app.
+ # 2. Specifying the workers sending federation in
+ # `federation_sender_instances`.
+ #
+
+ send_federation = config.get("send_federation", True)
+
+ federation_sender_instances = config.get("federation_sender_instances")
+ if federation_sender_instances is None:
+ # Default to an empty list, which means "another, unknown, worker is
+ # responsible for it".
+ federation_sender_instances = []
+
+ # If no federation sender instances are set we check if
+ # `send_federation` is set, which means use master
+ if send_federation:
+ federation_sender_instances = ["master"]
+
+ if self.worker_app == "synapse.app.federation_sender":
+ if send_federation:
+ # If we're running federation senders, and not using
+ # `federation_sender_instances`, then we should have
+ # explicitly set `send_federation` to false.
+ raise ConfigError(
+ _FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR
+ )
+
+ federation_sender_instances = [self.worker_name]
- federation_sender_instances = config.get("federation_sender_instances") or []
+ self.send_federation = self.instance_name in federation_sender_instances
self.federation_shard_config = ShardedWorkerHandlingConfig(
federation_sender_instances
)
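A hedged sketch of option 2 above (worker names invented): the workers that should send federation are listed explicitly, and each instance then checks whether its own name appears in the list.

```python
# Hypothetical shared config (after YAML parsing): shard federation sending
# across two dedicated workers. Only instances named in the list end up with
# self.send_federation == True; the main process does not send federation.
config = {
    "send_federation": False,
    "federation_sender_instances": [
        "federation_sender1",
        "federation_sender2",
    ],
}
```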
@@ -121,7 +189,7 @@ class WorkerConfig(Config):
# Check that the configured writers for events and typing also appears in
# `instance_map`.
- for stream in ("events", "typing"):
+ for stream in ("events", "typing", "to_device", "account_data", "receipts"):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
@@ -130,7 +198,52 @@ class WorkerConfig(Config):
% (instance, stream)
)
- self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
+ if len(self.writers.to_device) != 1:
+ raise ConfigError(
+ "Must only specify one instance to handle `to_device` messages."
+ )
+
+ if len(self.writers.account_data) != 1:
+ raise ConfigError(
+ "Must only specify one instance to handle `account_data` messages."
+ )
+
+ if len(self.writers.receipts) != 1:
+ raise ConfigError(
+ "Must only specify one instance to handle `receipts` messages."
+ )
+
+ if len(self.writers.events) == 0:
+ raise ConfigError("Must specify at least one instance to handle `events`.")
+
+ self.events_shard_config = RoutableShardedWorkerHandlingConfig(
+ self.writers.events
+ )
+
+ # Handle sharded push
+ start_pushers = config.get("start_pushers", True)
+ pusher_instances = config.get("pusher_instances")
+ if pusher_instances is None:
+ # Default to an empty list, which means "another, unknown, worker is
+ # responsible for it".
+ pusher_instances = []
+
+ # If no pushers instances are set we check if `start_pushers` is
+ # set, which means use master
+ if start_pushers:
+ pusher_instances = ["master"]
+
+ if self.worker_app == "synapse.app.pusher":
+ if start_pushers:
+ # If we're running pushers, and not using
+ # `pusher_instances`, then we should have explicitly set
+ # `start_pushers` to false.
+ raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR)
+
+ pusher_instances = [self.instance_name]
+
+ self.start_pushers = self.instance_name in pusher_instances
+ self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
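The pusher handling mirrors the federation sender logic; a similarly hedged sketch (worker names invented):

```python
# Hypothetical shared config: shard push notifications across two pusher
# workers. Only instances named here end up with self.start_pushers == True,
# and pusher_shard_config routes individual pushers between them.
config = {
    "start_pushers": False,
    "pusher_instances": ["pusher1", "pusher2"],
}
```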
# Whether this worker should run background tasks or not.
#
@@ -185,6 +298,13 @@ class WorkerConfig(Config):
# data). If not provided this defaults to the main process.
#
#run_background_tasks_on: worker1
+
+ # A shared secret used by the replication APIs to authenticate HTTP requests
+ # from workers.
+ #
+ # By default this is unused and traffic is not authenticated.
+ #
+ #worker_replication_secret: ""
"""
def read_arguments(self, args):
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 57fd426e87..c644b4dfc5 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -125,19 +125,24 @@ class FederationPolicyForHTTPS:
self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
- def get_options(self, host: bytes):
+ self._should_verify = self._config.federation_verify_certificates
+
+ self._federation_certificate_verification_whitelist = (
+ self._config.federation_certificate_verification_whitelist
+ )
+ def get_options(self, host: bytes):
# IPolicyForHTTPS.get_options takes bytes, but we want to compare
# against the str whitelist. The hostnames in the whitelist are already
# IDNA-encoded like the hosts will be here.
ascii_host = host.decode("ascii")
# Check if certificate verification has been enabled
- should_verify = self._config.federation_verify_certificates
+ should_verify = self._should_verify
# Check if we've disabled certificate verification for this host
- if should_verify:
- for regex in self._config.federation_certificate_verification_whitelist:
+ if self._should_verify:
+ for regex in self._federation_certificate_verification_whitelist:
if regex.match(ascii_host):
should_verify = False
break
@@ -186,7 +191,7 @@ def _context_info_cb(ssl_connection, where, ret):
# ... we further assume that SSLClientConnectionCreator has set the
# '_synapse_tls_verifier' attribute to a ConnectionVerifier object.
tls_protocol._synapse_tls_verifier.verify_context_info_cb(ssl_connection, where)
- except: # noqa: E722, taken from the twisted implementation
+ except BaseException: # taken from the twisted implementation
logger.exception("Error during info_callback")
f = Failure()
tls_protocol.failVerification(f)
@@ -214,7 +219,7 @@ class SSLClientConnectionCreator:
# ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the
# tls_protocol so that the SSL context's info callback has something to
# call to do the cert verification.
- setattr(tls_protocol, "_synapse_tls_verifier", self._verifier)
+ tls_protocol._synapse_tls_verifier = self._verifier
return connection
@@ -227,7 +232,7 @@ class ConnectionVerifier:
# This code is based on twisted.internet.ssl.ClientTLSOptions.
- def __init__(self, hostname: bytes, verify_certs):
+ def __init__(self, hostname: bytes, verify_certs: bool):
self._verify_certs = verify_certs
_decoded = hostname.decode("ascii")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 0422c43fab..8fb116ae18 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -18,7 +18,7 @@
import collections.abc
import hashlib
import logging
-from typing import Dict
+from typing import Any, Callable, Dict, Tuple
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
@@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase
from synapse.events.utils import prune_event, prune_event_dict
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+Hasher = Callable[[bytes], "hashlib._Hash"]
-def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+
+def check_event_content_hash(
+ event: EventBase, hash_algorithm: Hasher = hashlib.sha256
+) -> bool:
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
logger.debug(
@@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
return message_hash_bytes == expected_hash
-def compute_content_hash(event_dict, hash_algorithm):
+def compute_content_hash(
+ event_dict: Dict[str, Any], hash_algorithm: Hasher
+) -> Tuple[str, bytes]:
"""Compute the content hash of an event, which is the hash of the
unredacted event.
Args:
- event_dict (dict): The unredacted event as a dict
+ event_dict: The unredacted event as a dict
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
- tuple[str, bytes]: A tuple of the name of hash and the hash as raw
- bytes.
+ A tuple of the name of hash and the hash as raw bytes.
"""
event_dict = dict(event_dict)
event_dict.pop("age_ts", None)
@@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
return hashed.name, hashed.digest()
-def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+def compute_event_reference_hash(
+ event, hash_algorithm: Hasher = hashlib.sha256
+) -> Tuple[str, bytes]:
"""Computes the event reference hash. This is the hash of the redacted
event.
Args:
- event (FrozenEvent)
+ event
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
- tuple[str, bytes]: A tuple of the name of hash and the hash as raw
- bytes.
+ A tuple of the name of hash and the hash as raw bytes.
"""
tmp_event = prune_event(event)
event_dict = tmp_event.get_pdu_json()
@@ -156,7 +163,7 @@ def add_hashes_and_signatures(
event_dict: JsonDict,
signature_name: str,
signing_key: SigningKey,
-):
+) -> None:
"""Add content hash and sign the event
Args:
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f23eacc0d7..d5fb51513b 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
import urllib
from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from signedjson.key import (
@@ -40,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.config.key import TrustedKeyServer
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
@@ -47,11 +50,15 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -61,16 +68,17 @@ class VerifyJsonRequest:
A request to verify a JSON object.
Attributes:
- server_name(str): The name of the server to verify against.
-
- key_ids(set[str]): The set of key_ids to that could be used to verify the
- JSON object
+ server_name: The name of the server to verify against.
- json_object(dict): The JSON object to verify.
+ json_object: The JSON object to verify.
- minimum_valid_until_ts (int): time at which we require the signing key to
+ minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
+ request_name: The name of the request.
+
+ key_ids: The set of key_ids that could be used to verify the JSON object
+
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
@@ -80,12 +88,12 @@ class VerifyJsonRequest:
errbacks with an M_UNAUTHORIZED SynapseError.
"""
- server_name = attr.ib()
- json_object = attr.ib()
- minimum_valid_until_ts = attr.ib()
- request_name = attr.ib()
- key_ids = attr.ib(init=False)
- key_ready = attr.ib(default=attr.Factory(defer.Deferred))
+ server_name = attr.ib(type=str)
+ json_object = attr.ib(type=JsonDict)
+ minimum_valid_until_ts = attr.ib(type=int)
+ request_name = attr.ib(type=str)
+ key_ids = attr.ib(init=False, type=List[str])
+ key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name)
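The `attr.ib(type=...)` form above attaches type metadata to the existing attrs fields without changing their runtime behaviour. A small illustrative sketch of the same pattern (the class and field values below are made up)::

    from typing import List

    import attr

    @attr.s
    class SignedRequest:
        server_name = attr.ib(type=str)
        minimum_valid_until_ts = attr.ib(type=int)
        # excluded from the generated __init__; populated in __attrs_post_init__
        key_ids = attr.ib(init=False, type=List[str])

        def __attrs_post_init__(self):
            self.key_ids = ["ed25519:a_key_id"]

    req = SignedRequest(server_name="example.org", minimum_valid_until_ts=0)
    assert req.key_ids == ["ed25519:a_key_id"]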
@@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
class Keyring:
- def __init__(self, hs, key_fetchers=None):
+ def __init__(
+ self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
+ ):
self.clock = hs.get_clock()
if key_fetchers is None:
@@ -112,22 +122,26 @@ class Keyring:
# completes.
#
# These are regular, logcontext-agnostic Deferreds.
- self.key_downloads = {}
+ self.key_downloads = {} # type: Dict[str, defer.Deferred]
def verify_json_for_server(
- self, server_name, json_object, validity_time, request_name
- ):
+ self,
+ server_name: str,
+ json_object: JsonDict,
+ validity_time: int,
+ request_name: str,
+ ) -> defer.Deferred:
"""Verify that a JSON object has been signed by a given server
Args:
- server_name (str): name of the server which must have signed this object
+ server_name: name of the server which must have signed this object
- json_object (dict): object to be checked
+ json_object: object to be checked
- validity_time (int): timestamp at which we require the signing key to
+ validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care)
- request_name (str): an identifier for this json object (eg, an event id)
+ request_name: an identifier for this json object (eg, an event id)
for logging.
Returns:
@@ -138,12 +152,14 @@ class Keyring:
requests = (req,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
- def verify_json_objects_for_server(self, server_and_json):
+ def verify_json_objects_for_server(
+ self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+ ) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
- server_and_json (iterable[Tuple[str, dict, int, str]):
+ server_and_json:
Iterable of (server_name, json_object, validity_time, request_name)
tuples.
@@ -164,13 +180,14 @@ class Keyring:
for server_name, json_object, validity_time, request_name in server_and_json
)
- def _verify_objects(self, verify_requests):
+ def _verify_objects(
+ self, verify_requests: Iterable[VerifyJsonRequest]
+ ) -> List[defer.Deferred]:
"""Does the work of verify_json_[objects_]for_server
Args:
- verify_requests (iterable[VerifyJsonRequest]):
- Iterable of verification requests.
+ verify_requests: Iterable of verification requests.
Returns:
List<Deferred[None]>: for each input item, a deferred indicating success
@@ -182,7 +199,7 @@ class Keyring:
key_lookups = []
handle = preserve_fn(_handle_key_deferred)
- def process(verify_request):
+ def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
"""Process an entry in the request list
Adds a key request to key_lookups, and returns a deferred which
@@ -222,18 +239,20 @@ class Keyring:
return results
- async def _start_key_lookups(self, verify_requests):
+ async def _start_key_lookups(
+ self, verify_requests: List[VerifyJsonRequest]
+ ) -> None:
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
Args:
- verify_requests (List[VerifyJsonRequest]):
+ verify_requests:
"""
try:
# map from server name to a set of outstanding request ids
- server_to_request_ids = {}
+ server_to_request_ids = {} # type: Dict[str, Set[int]]
for verify_request in verify_requests:
server_name = verify_request.server_name
@@ -275,11 +294,11 @@ class Keyring:
except Exception:
logger.exception("Error starting key lookups")
- async def wait_for_previous_lookups(self, server_names) -> None:
+ async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_names (Iterable[str]): list of servers which we want to look up
+ server_names: list of servers which we want to look up
Returns:
Resolves once all key lookups for the given servers have
@@ -304,7 +323,7 @@ class Keyring:
loop_count += 1
- def _get_server_verify_keys(self, verify_requests):
+ def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.key_ready is called back with
@@ -312,7 +331,7 @@ class Keyring:
with a SynapseError if none of the keys are found.
Args:
- verify_requests (list[VerifyJsonRequest]): list of verify requests
+ verify_requests: list of verify requests
"""
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@@ -366,17 +385,19 @@ class Keyring:
run_in_background(do_iterations)
- async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ async def _attempt_key_fetches_with_fetcher(
+ self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
+ ):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
- fetcher (KeyFetcher): fetcher to use to fetch the keys
- remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+ fetcher: fetcher to use to fetch the keys
+ remaining_requests: outstanding key requests.
Any successfully-completed requests will be removed from the list.
"""
- # dict[str, dict[str, int]]: keys to fetch.
+ # The keys to fetch.
# server_name -> key_id -> min_valid_ts
- missing_keys = defaultdict(dict)
+ missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
for verify_request in remaining_requests:
# any completed requests should already have been removed
@@ -438,16 +459,18 @@ class Keyring:
remaining_requests.difference_update(completed)
-class KeyFetcher:
- async def get_keys(self, keys_to_fetch):
+class KeyFetcher(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, dict[str, int]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
- Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
- map from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
"""
raise NotImplementedError
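Making `KeyFetcher` an abstract base class means every fetcher must provide `get_keys`, and forgetting to do so fails at instantiation rather than at call time. A toy sketch of the pattern (the in-memory fetcher below is purely illustrative)::

    import abc
    import asyncio
    from typing import Dict

    class Fetcher(metaclass=abc.ABCMeta):
        @abc.abstractmethod
        async def get_keys(
            self, keys_to_fetch: Dict[str, Dict[str, int]]
        ) -> Dict[str, Dict[str, str]]:
            """keys_to_fetch is server_name -> key_id -> min_valid_ts."""
            raise NotImplementedError

    class InMemoryFetcher(Fetcher):
        def __init__(self, known: Dict[str, Dict[str, str]]):
            self._known = known

        async def get_keys(self, keys_to_fetch):
            # return whatever we know about the requested servers
            return {server: dict(self._known.get(server, {})) for server in keys_to_fetch}

    known = {"example.org": {"ed25519:abc": "base64-key"}}
    result = asyncio.run(InMemoryFetcher(known).get_keys({"example.org": {"ed25519:abc": 0}}))
    assert result == known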
@@ -455,31 +478,35 @@ class KeyFetcher:
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
- keys_to_fetch = (
+ key_ids_to_fetch = (
(server_name, key_id)
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
)
- res = await self.store.get_server_verify_keys(keys_to_fetch)
- keys = {}
+ res = await self.store.get_server_verify_keys(key_ids_to_fetch)
+ keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
return keys
-class BaseV2KeyFetcher:
- def __init__(self, hs):
+class BaseV2KeyFetcher(KeyFetcher):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.config = hs.get_config()
- async def process_v2_response(self, from_server, response_json, time_added_ms):
+ async def process_v2_response(
+ self, from_server: str, response_json: JsonDict, time_added_ms: int
+ ) -> Dict[str, FetchKeyResult]:
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
@@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
to /_matrix/key/v2/query.
Args:
- from_server (str): the name of the server producing this result: either
+ from_server: the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
- response_json (dict): the json-decoded Server Keys response object
+ response_json: the json-decoded Server Keys response object
- time_added_ms (int): the timestamp to record in server_keys_json
+ time_added_ms: the timestamp to record in server_keys_json
Returns:
- Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ Map from key_id to result object
"""
ts_valid_until_ms = response_json["valid_until_ts"]
@@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
- async def get_key(key_server):
+ async def get_key(key_server: TrustedKeyServer) -> Dict:
try:
- result = await self.get_server_verify_key_v2_indirect(
+ return await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
- return result
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError)
)
- union_of_keys = {}
+ union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
return union_of_keys
- async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+ async def get_server_verify_key_v2_indirect(
+ self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, dict[str, int]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
- key_server (synapse.config.key.TrustedKeyServer): notary server to query for
- the keys
+ key_server: notary server to query for the keys
Returns:
- dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
- from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
Raises:
KeyLookupError if there was an error processing the entire response from
@@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
- keys = {}
- added_keys = []
+ keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
+ added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]]
time_now_ms = self.clock.time_msec()
+ assert isinstance(query_response, dict)
for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter
server_name = response.get("server_name")
@@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return keys
- def _validate_perspectives_response(self, key_server, response):
+ def _validate_perspectives_response(
+ self, key_server: TrustedKeyServer, response: JsonDict
+ ) -> None:
"""Optionally check the signature on the result of a /key/query request
Args:
- key_server (synapse.config.key.TrustedKeyServer): the notary server that
- produced this result
+ key_server: the notary server that produced this result
- response (dict): the json-decoded Server Keys response object
+ response: the json-decoded Server Keys response object
"""
perspective_name = key_server.server_name
perspective_keys = key_server.verify_keys
@@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, iterable[str]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_ids
Returns:
- dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
- map from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
"""
results = {}
- async def get_key(key_to_fetch_item):
+ async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
server_name, key_ids = key_to_fetch_item
try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
await yieldable_gather_results(get_key, keys_to_fetch.items())
return results
- async def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ async def get_server_verify_key_v2_direct(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, FetchKeyResult]:
"""
Args:
- server_name (str):
- key_ids (iterable[str]):
+ server_name:
+ key_ids:
Returns:
- dict[str, FetchKeyResult]: map from key ID to lookup result
+ Map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
"""
- keys = {} # type: dict[str, FetchKeyResult]
+ keys = {} # type: Dict[str, FetchKeyResult]
for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
@@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
+ assert isinstance(response, dict)
if response["server_name"] != server_name:
raise KeyLookupError(
"Expected a response for server %r not %r"
@@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys
-async def _handle_key_deferred(verify_request) -> None:
+async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
- verify_request (VerifyJsonRequest):
+ verify_request:
Raises:
SynapseError if there was a problem performing the verification
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 498a699290..4e20851d7f 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,7 +42,7 @@ def check(
do_sig_check: bool = True,
do_size_check: bool = True,
) -> None:
- """ Checks if this event is correctly authed.
+ """Checks if this event is correctly authed.
Args:
room_version_obj: the version of the room
@@ -439,7 +439,9 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def check_redaction(
- room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+ room_version_obj: RoomVersion,
+ event: EventBase,
+ auth_events: StateMap[EventBase],
) -> bool:
"""Check whether the event sender is allowed to redact the target event.
@@ -475,7 +477,9 @@ def check_redaction(
def _check_power_levels(
- room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+ room_version_obj: RoomVersion,
+ event: EventBase,
+ auth_events: StateMap[EventBase],
) -> None:
user_list = event.content.get("users", {})
# Validate users
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8028663fa8..8f6b955d17 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -17,7 +17,6 @@
import abc
import os
-from distutils.util import strtobool
from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
+from synapse.util.stringutils import strtobool
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
@@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
# NOTE: This is overridden by the configuration by the Synapse worker apps, but
# for the sake of tests, it is set here while it cannot be configured on the
# homeserver object itself.
+
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
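`distutils.util.strtobool` is deprecated along with distutils, so the import moves to a local helper. A sketch of an equivalent helper, with the accepted values mirroring the distutils behaviour (the exact body of `synapse.util.stringutils.strtobool` may differ)::

    def strtobool(val: str) -> bool:
        # True for y/yes/t/true/on/1, False for n/no/f/false/off/0, else ValueError.
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return True
        if val in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError("invalid truth value %r" % (val,))

    assert strtobool("1") is True
    assert strtobool("off") is False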
@@ -97,7 +98,7 @@ class DefaultDictProperty(DictProperty):
class _EventInternalMetadata:
- __slots__ = ["_dict", "stream_ordering"]
+ __slots__ = ["_dict", "stream_ordering", "outlier"]
def __init__(self, internal_metadata_dict: JsonDict):
# we have to copy the dict, because it turns out that the same dict is
@@ -107,7 +108,10 @@ class _EventInternalMetadata:
# the stream ordering of this event. None, until it has been persisted.
self.stream_ordering = None # type: Optional[int]
- outlier = DictProperty("outlier") # type: bool
+ # whether this event is an outlier (ie, whether we have the state at that point
+ # in the DAG)
+ self.outlier = False
+
out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str
recheck_redaction = DictProperty("recheck_redaction") # type: bool
@@ -128,7 +132,7 @@ class _EventInternalMetadata:
return dict(self._dict)
def is_outlier(self) -> bool:
- return self._dict.get("outlier", False)
+ return self.outlier
def is_out_of_band_membership(self) -> bool:
"""Whether this is an out of band membership, like an invite or an invite
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 07df258e6e..c1c0426f6e 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -98,7 +98,9 @@ class EventBuilder:
return self._state_key is not None
async def build(
- self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
+ self,
+ prev_event_ids: List[str],
+ auth_event_ids: Optional[List[str]],
) -> EventBase:
"""Transform into a fully signed and hashed event
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index afecafe15c..7295df74fe 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -341,8 +341,7 @@ def _encode_state_dict(state_dict):
def _decode_state_dict(input):
- """Decodes a state dict encoded using `_encode_state_dict` above
- """
+ """Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None:
return None
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 4404ffffe9..b5a9c71ee6 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,15 +15,21 @@
# limitations under the License.
import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import synapse.events
import synapse.server
+logger = logging.getLogger(__name__)
+
class SpamChecker:
def __init__(self, hs: "synapse.server.HomeServer"):
@@ -39,7 +45,9 @@ class SpamChecker:
else:
self.spam_checkers.append(module(config=config))
- def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
+ async def check_event_for_spam(
+ self, event: "synapse.events.EventBase"
+ ) -> Union[bool, str]:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
@@ -50,18 +58,19 @@ class SpamChecker:
event: the event to be checked
Returns:
- True if the event is spammy.
+ True or a string if the event is spammy. If a string is returned it
+ will be used as the error message returned to the user.
"""
for spam_checker in self.spam_checkers:
- if spam_checker.check_event_for_spam(event):
+ if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
return True
return False
- def user_may_invite(
+ async def user_may_invite(
self,
inviter_userid: str,
- invitee_userid: str,
+ invitee_userid: Optional[str],
third_party_invite: Optional[Dict],
room_id: str,
new_room: bool,
@@ -91,13 +100,15 @@ class SpamChecker:
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_invite(
- inviter_userid,
- invitee_userid,
- third_party_invite,
- room_id,
- new_room,
- published_room,
+ await maybe_awaitable(
+ spam_checker.user_may_invite(
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
+ )
)
is False
):
@@ -105,7 +116,7 @@ class SpamChecker:
return True
- def user_may_create_room(
+ async def user_may_create_room(
self,
userid: str,
invite_list: List[str],
@@ -130,8 +141,10 @@ class SpamChecker:
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_create_room(
- userid, invite_list, third_party_invite_list, cloning
+ await maybe_awaitable(
+ spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
)
is False
):
@@ -139,7 +152,7 @@ class SpamChecker:
return True
- def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
+ async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
If this method returns false, the association request will be rejected.
@@ -152,12 +165,17 @@ class SpamChecker:
True if the user may create a room alias, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+ if (
+ await maybe_awaitable(
+ spam_checker.user_may_create_room_alias(userid, room_alias)
+ )
+ is False
+ ):
return False
return True
- def user_may_publish_room(self, userid: str, room_id: str) -> bool:
+ async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected.
@@ -170,7 +188,12 @@ class SpamChecker:
True if the user may publish the room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_publish_room(userid, room_id) is False:
+ if (
+ await maybe_awaitable(
+ spam_checker.user_may_publish_room(userid, room_id)
+ )
+ is False
+ ):
return False
return True
@@ -194,7 +217,7 @@ class SpamChecker:
return True
- def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+ async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
@@ -216,16 +239,17 @@ class SpamChecker:
if checker:
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
- if checker(user_profile.copy()):
+ if await maybe_awaitable(checker(user_profile.copy())):
return True
return False
- def check_registration_for_spam(
+ async def check_registration_for_spam(
self,
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str] = None,
) -> RegistrationBehaviour:
"""Checks if we should allow the given registration request.
@@ -234,6 +258,9 @@ class SpamChecker:
username: The request user name, if any
request_info: List of tuples of user agent and IP that
were used during the registration process.
+ auth_provider_id: The SSO IdP the user used, e.g. "oidc", "saml",
+ "cas", if any. Note this does not include users registered
+ via a password provider.
Returns:
Enum for how the request should be handled
@@ -244,9 +271,72 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
- behaviour = checker(email_threepid, username, request_info)
+ # Provide auth_provider_id if the function supports it
+ checker_args = inspect.signature(checker)
+ if len(checker_args.parameters) == 4:
+ d = checker(
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ )
+ elif len(checker_args.parameters) == 3:
+ d = checker(email_threepid, username, request_info)
+ else:
+ logger.error(
+ "Invalid signature for %s.check_registration_for_spam. Denying registration",
+ spam_checker.__module__,
+ )
+ return RegistrationBehaviour.DENY
+
+ behaviour = await maybe_awaitable(d)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
return RegistrationBehaviour.ALLOW
+
+ async def check_media_file_for_spam(
+ self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+ ) -> bool:
+ """Checks if a piece of newly uploaded media should be blocked.
+
+ This will be called for local uploads, downloads of remote media, each
+ thumbnail generated for those, and web pages/images used for URL
+ previews.
+
+ Note that care should be taken to not do blocking IO operations in the
+ main thread. For example, to get the contents of a file a module
+ should do::
+
+ async def check_media_file_for_spam(
+ self, file: ReadableFileWrapper, file_info: FileInfo
+ ) -> bool:
+ buffer = BytesIO()
+ await file.write_chunks_to(buffer.write)
+
+ if buffer.getvalue() == b"Hello World":
+ return True
+
+ return False
+
+
+ Args:
+ file: An object that allows reading the contents of the media.
+ file_info: Metadata about the file.
+
+ Returns:
+ True if the media should be blocked or False if it should be
+ allowed.
+ """
+
+ for spam_checker in self.spam_checkers:
+ # For backwards compatibility, only run if the method exists on the
+ # spam checker
+ checker = getattr(spam_checker, "check_media_file_for_spam", None)
+ if checker:
+ spam = await maybe_awaitable(checker(file_wrapper, file_info))
+ if spam:
+ return True
+
+ return False
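The recurring pattern throughout this file is: call the (possibly synchronous) callback from the loaded spam-checker module, await the result only if it is awaitable, and inspect the callback's signature before passing newer optional arguments. A condensed sketch of both ideas, where `maybe_awaitable` stands in for `synapse.util.async_helpers.maybe_awaitable` and the checker functions are made up::

    import asyncio
    import inspect

    async def maybe_awaitable(value):
        # Await results from async checkers; pass through values from legacy sync checkers.
        if inspect.isawaitable(value):
            return await value
        return value

    def legacy_checker(email_threepid, username, request_info):
        return "ALLOW"

    async def modern_checker(email_threepid, username, request_info, auth_provider_id):
        return "ALLOW"

    async def check_registration(checker, email_threepid, username, request_info, auth_provider_id):
        # Only pass auth_provider_id if the callback declares a fourth parameter.
        if len(inspect.signature(checker).parameters) == 4:
            result = checker(email_threepid, username, request_info, auth_provider_id)
        else:
            result = checker(email_threepid, username, request_info)
        return await maybe_awaitable(result)

    assert asyncio.run(check_registration(legacy_checker, None, "alice", [], "oidc")) == "ALLOW"
    assert asyncio.run(check_registration(modern_checker, None, "alice", [], "oidc")) == "ALLOW"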
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 77fbd3f68a..9767d23940 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, Union
+from typing import TYPE_CHECKING, Union
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester, StateMap
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class ThirdPartyEventRules:
"""Allows server admins to provide a Python module implementing an extra
@@ -28,7 +31,7 @@ class ThirdPartyEventRules:
behaviours.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.third_party_rules = None
self.store = hs.get_datastore()
@@ -40,7 +43,8 @@ class ThirdPartyEventRules:
if module is not None:
self.third_party_rules = module(
- config=config, module_api=hs.get_module_api(),
+ config=config,
+ module_api=hs.get_module_api(),
)
async def check_event_allowed(
@@ -94,10 +98,9 @@ class ThirdPartyEventRules:
if self.third_party_rules is None:
return True
- ret = await self.third_party_rules.on_create_room(
+ return await self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
- return ret
async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str
@@ -118,10 +121,9 @@ class ThirdPartyEventRules:
state_events = await self._get_state_map_for_room(room_id)
- ret = await self.third_party_rules.check_threepid_can_be_invited(
+ return await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
- return ret
async def check_visibility_can_be_modified(
self, room_id: str, new_visibility: str
@@ -142,7 +144,7 @@ class ThirdPartyEventRules:
check_func = getattr(
self.third_party_rules, "check_visibility_can_be_modified", None
)
- if not check_func or not isinstance(check_func, Callable):
+ if not check_func or not callable(check_func):
return True
state_events = await self._get_state_map_for_room(room_id)
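`isinstance(x, Callable)` from `typing` is replaced by the builtin `callable()`, which is the idiomatic runtime check. A tiny sketch of the optional-hook lookup pattern used here (the module class and the hook's argument list are illustrative, not the real interface)::

    import asyncio

    class LegacyRulesModule:
        # an older module that does not implement the optional hook
        pass

    async def visibility_can_be_modified(rules_module, room_id: str, new_visibility: str) -> bool:
        check_func = getattr(rules_module, "check_visibility_can_be_modified", None)
        if not check_func or not callable(check_func):
            # Hook not implemented: default to allowing the change.
            return True
        return await check_func(room_id, new_visibility)

    assert asyncio.run(visibility_can_be_modified(LegacyRulesModule(), "!r:example.org", "public")) is True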
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 72c8ad4fe5..7050f0c885 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.util.async_helpers import yieldable_gather_results
+from synapse.util.frozenutils import unfreeze
from . import EventBase
@@ -34,7 +35,7 @@ SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
def prune_event(event: EventBase) -> EventBase:
- """ Returns a pruned version of the given event, which removes all keys we
+ """Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
This is used when we "redact" an event. We want to remove all fields that
@@ -54,6 +55,8 @@ def prune_event(event: EventBase) -> EventBase:
event.internal_metadata.stream_ordering
)
+ pruned_event.internal_metadata.outlier = event.internal_metadata.outlier
+
# Mark the event as redacted
pruned_event.internal_metadata.redacted = True
@@ -79,13 +82,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
"state_key",
"depth",
"prev_events",
- "prev_state",
"auth_events",
"origin",
"origin_server_ts",
- "membership",
]
+ # Room versions from before MSC2176 had additional allowed keys.
+ if not room_version.msc2176_redaction_rules:
+ allowed_keys.extend(["prev_state", "membership"])
+
event_type = event_dict["type"]
new_content = {}
@@ -98,6 +103,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
if event_type == EventTypes.Member:
add_fields("membership")
elif event_type == EventTypes.Create:
+ # MSC2176 rules state that create events cannot be redacted.
+ if room_version.msc2176_redaction_rules:
+ return event_dict
+
add_fields("creator")
elif event_type == EventTypes.JoinRules:
add_fields("join_rule")
@@ -112,10 +121,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
"kick",
"redact",
)
+
+ if room_version.msc2176_redaction_rules:
+ add_fields("invite")
+
elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
+ elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
+ add_fields("redacts")
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
@@ -393,10 +408,19 @@ class EventClientSerializer:
# If there is an edit replace the content, preserving existing
# relations.
+ # Ensure we take copies of the edit content, otherwise we risk modifying
+ # the original event.
+ edit_content = edit.content.copy()
+
+ # Unfreeze the event content if necessary, so that we may modify it below
+ edit_content = unfreeze(edit_content)
+ serialized_event["content"] = edit_content.get("m.new_content", {})
+
+ # Check for existing relations
relations = event.content.get("m.relates_to")
- serialized_event["content"] = edit.content.get("m.new_content", {})
if relations:
- serialized_event["content"]["m.relates_to"] = relations
+ # Keep the relations, ensuring we use a dict copy of the original
+ serialized_event["content"]["m.relates_to"] = relations.copy()
else:
serialized_event["content"].pop("m.relates_to", None)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 38aa47963f..383737520a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -78,6 +78,7 @@ class FederationBase:
ctx = current_context()
+ @defer.inlineCallbacks
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
@@ -105,7 +106,11 @@ class FederationBase:
)
return redacted_event
- if self.spam_checker.check_event_for_spam(pdu):
+ result = yield defer.ensureDeferred(
+ self.spam_checker.check_event_for_spam(pdu)
+ )
+
+ if result:
logger.warning(
"Event contains spam, redacting %s: %s",
pdu.event_id,
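Because `check_event_for_spam` is now a coroutine, the Deferred callback is wrapped in `@defer.inlineCallbacks` and the coroutine is bridged via `defer.ensureDeferred`. A stripped-down sketch of that bridging (the checker below is a stand-in for the real spam checker)::

    from twisted.internet import defer

    async def check_event_for_spam(pdu) -> bool:
        return False

    @defer.inlineCallbacks
    def callback(_, pdu):
        # yield a Deferred wrapping the coroutine so inlineCallbacks can drive it
        is_spam = yield defer.ensureDeferred(check_event_for_spam(pdu))
        if is_spam:
            return None  # a real implementation would redact the event here
        return pdu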
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 895ac59d36..184096d165 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -20,6 +20,7 @@ import copy
import itertools
import logging
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -34,6 +35,7 @@ from typing import (
Union,
)
+import attr
from prometheus_client import Counter
from twisted.internet import defer
@@ -63,6 +65,9 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -82,10 +87,10 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.pdu_destination_tried = {}
+ self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@@ -118,33 +123,32 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict
@log_function
- def make_query(
+ async def make_query(
self,
- destination,
- query_type,
- args,
- retry_on_dns_fail=False,
- ignore_backoff=False,
- ):
+ destination: str,
+ query_type: str,
+ args: dict,
+ retry_on_dns_fail: bool = False,
+ ignore_backoff: bool = False,
+ ) -> JsonDict:
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
- destination (str): Domain name of the remote homeserver
- query_type (str): Category of the query type; should match the
+ destination: Domain name of the remote homeserver
+ query_type: Category of the query type; should match the
handler name used in register_query_handler().
- args (dict): Mapping of strings to strings containing the details
+ args: Mapping of strings to strings containing the details
of the query request.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
- a Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels(query_type).inc()
- return self.transport_layer.make_query(
+ return await self.transport_layer.make_query(
destination,
query_type,
args,
@@ -153,42 +157,52 @@ class FederationClient(FederationBase):
)
@log_function
- def query_client_keys(self, destination, content, timeout):
+ async def query_client_keys(
+ self, destination: str, content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Query device keys for a device hosted on a remote server.
Args:
- destination (str): Domain name of the remote homeserver
- content (dict): The query content.
+ destination: Domain name of the remote homeserver
+ content: The query content.
Returns:
- an Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels("client_device_keys").inc()
- return self.transport_layer.query_client_keys(destination, content, timeout)
+ return await self.transport_layer.query_client_keys(
+ destination, content, timeout
+ )
@log_function
- def query_user_devices(self, destination, user_id, timeout=30000):
+ async def query_user_devices(
+ self, destination: str, user_id: str, timeout: int = 30000
+ ) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.labels("user_devices").inc()
- return self.transport_layer.query_user_devices(destination, user_id, timeout)
+ return await self.transport_layer.query_user_devices(
+ destination, user_id, timeout
+ )
@log_function
- def claim_client_keys(self, destination, content, timeout):
+ async def claim_client_keys(
+ self, destination: str, content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
Args:
- destination (str): Domain name of the remote homeserver
- content (dict): The query content.
+ destination: Domain name of the remote homeserver
+ content: The query content.
Returns:
- an Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
- return self.transport_layer.claim_client_keys(destination, content, timeout)
+ return await self.transport_layer.claim_client_keys(
+ destination, content, timeout
+ )
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -197,10 +211,10 @@ class FederationClient(FederationBase):
given destination server.
Args:
- dest (str): The remote homeserver to ask.
- room_id (str): The room_id to backfill.
- limit (int): The maximum number of events to return.
- extremities (list): our current backwards extremities, to backfill from
+ dest: The remote homeserver to ask.
+ room_id: The room_id to backfill.
+ limit: The maximum number of events to return.
+ extremities: our current backwards extremities, to backfill from
"""
logger.debug("backfill extrem=%s", extremities)
@@ -372,7 +386,7 @@ class FederationClient(FederationBase):
for events that have failed their checks
Returns:
- Deferred : A list of PDUs that have valid signatures and hashes.
+ A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@@ -420,7 +434,9 @@ class FederationClient(FederationBase):
else:
return [p for p in valid_pdus if p]
- async def get_event_auth(self, destination, room_id, event_id):
+ async def get_event_auth(
+ self, destination: str, room_id: str, event_id: str
+ ) -> List[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version(room_id)
@@ -443,6 +459,7 @@ class FederationClient(FederationBase):
description: str,
destinations: Iterable[str],
callback: Callable[[str], Awaitable[T]],
+ failover_on_unknown_endpoint: bool = False,
) -> T:
"""Try an operation on a series of servers, until it succeeds
@@ -462,6 +479,10 @@ class FederationClient(FederationBase):
next server tried. Normally the stacktrace is logged but this is
suppressed if the exception is an InvalidResponseError.
+ failover_on_unknown_endpoint: if True, we will try other servers if it looks
+ like a server doesn't support the endpoint. This is typically useful
+ if the endpoint in question is new or experimental.
+
Returns:
The result of callback, if it succeeds
@@ -481,16 +502,31 @@ class FederationClient(FederationBase):
except UnsupportedRoomVersionError:
raise
except HttpResponseException as e:
- if not 500 <= e.code < 600:
- raise e.to_synapse_error()
- else:
- logger.warning(
- "Failed to %s via %s: %i %s",
- description,
- destination,
- e.code,
- e.args[0],
- )
+ synapse_error = e.to_synapse_error()
+ failover = False
+
+ if 500 <= e.code < 600:
+ failover = True
+
+ elif failover_on_unknown_endpoint:
+ # there is no good way to detect an "unknown" endpoint. Dendrite
+ # returns a 404 (with no body); synapse returns a 400
+ # with M_UNRECOGNISED.
+ if e.code == 404 or (
+ e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED
+ ):
+ failover = True
+
+ if not failover:
+ raise synapse_error from e
+
+ logger.warning(
+ "Failed to %s via %s: %i %s",
+ description,
+ destination,
+ e.code,
+ e.args[0],
+ )
except Exception:
logger.warning(
"Failed to %s via %s", description, destination, exc_info=True
@@ -702,18 +738,16 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request)
- async def _do_send_join(self, destination: str, pdu: EventBase):
+ async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_join_v2(
+ return await self.transport_layer.send_join_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
-
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -740,7 +774,11 @@ class FederationClient(FederationBase):
return resp[1]
async def send_invite(
- self, destination: str, room_id: str, event_id: str, pdu: EventBase,
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ pdu: EventBase,
) -> EventBase:
room_version = await self.store.get_room_version(room_id)
@@ -771,7 +809,7 @@ class FederationClient(FederationBase):
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_invite_v2(
+ return await self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
@@ -781,7 +819,6 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -801,7 +838,7 @@ class FederationClient(FederationBase):
"User's homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
- elif e.code == 403:
+ elif e.code in (403, 429):
raise e.to_synapse_error()
else:
raise
@@ -844,18 +881,16 @@ class FederationClient(FederationBase):
"send_leave", destinations, send_request
)
- async def _do_send_leave(self, destination, pdu):
+ async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_leave_v2(
+ return await self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
-
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -937,7 +972,7 @@ class FederationClient(FederationBase):
content=pdu.get_pdu_json(time_now),
)
- def get_public_rooms(
+ async def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
@@ -945,7 +980,7 @@ class FederationClient(FederationBase):
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
- ):
+ ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver
Args:
@@ -959,8 +994,7 @@ class FederationClient(FederationBase):
party instance
Returns:
- Awaitable[Dict[str, Any]]: The response from the remote server, or None if
- `remote_server` is the same as the local server_name
+ The response from the remote server.
Raises:
HttpResponseException: There was an exception returned from the remote server
@@ -968,7 +1002,7 @@ class FederationClient(FederationBase):
requests over federation
"""
- return self.transport_layer.get_public_rooms(
+ return await self.transport_layer.get_public_rooms(
remote_server,
limit,
since_token,
@@ -981,7 +1015,7 @@ class FederationClient(FederationBase):
self,
destination: str,
room_id: str,
- earliest_events_ids: Sequence[str],
+ earliest_events_ids: Iterable[str],
latest_events: Iterable[EventBase],
limit: int,
min_depth: int,
@@ -1032,7 +1066,9 @@ class FederationClient(FederationBase):
return signed_events
- async def forward_third_party_invite(self, destinations, room_id, event_dict):
+ async def forward_third_party_invite(
+ self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
+ ) -> None:
for destination in destinations:
if destination == self.server_name:
continue
@@ -1041,7 +1077,7 @@ class FederationClient(FederationBase):
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
- return None
+ return
except CodeMessageException:
raise
except Exception as e:
@@ -1053,7 +1089,7 @@ class FederationClient(FederationBase):
async def get_room_complexity(
self, destination: str, room_id: str
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""
Fetch the complexity of a remote room from another server.
@@ -1066,10 +1102,9 @@ class FederationClient(FederationBase):
could not fetch the complexity.
"""
try:
- complexity = await self.transport_layer.get_room_complexity(
+ return await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
- return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
@@ -1087,3 +1122,141 @@ class FederationClient(FederationBase):
# If we don't manage to find it, return None. It's not an error if a
# server doesn't give it to us.
return None
+
+ async def get_space_summary(
+ self,
+ destinations: Iterable[str],
+ room_id: str,
+ suggested_only: bool,
+ max_rooms_per_space: Optional[int],
+ exclude_rooms: List[str],
+ ) -> "FederationSpaceSummaryResult":
+ """
+ Call other servers to get a summary of the given space
+
+
+ Args:
+ destinations: The remote servers. We will try them in turn, omitting any
+ that have been blacklisted.
+
+ room_id: ID of the space to be queried
+
+ suggested_only: If true, ask the remote server to only return children
+ with the "suggested" flag set
+
+ max_rooms_per_space: A limit on the number of children to return for each
+ space
+
+ exclude_rooms: A list of room IDs to tell the remote server to skip
+
+ Returns:
+ a parsed FederationSpaceSummaryResult
+
+ Raises:
+ SynapseError if we were unable to get a valid summary from any of the
+ remote servers
+ """
+
+ async def send_request(destination: str) -> FederationSpaceSummaryResult:
+ res = await self.transport_layer.get_space_summary(
+ destination=destination,
+ room_id=room_id,
+ suggested_only=suggested_only,
+ max_rooms_per_space=max_rooms_per_space,
+ exclude_rooms=exclude_rooms,
+ )
+
+ try:
+ return FederationSpaceSummaryResult.from_json_dict(res)
+ except ValueError as e:
+ raise InvalidResponseError(str(e))
+
+ return await self._try_destination_list(
+ "fetch space summary",
+ destinations,
+ send_request,
+ failover_on_unknown_endpoint=True,
+ )
+
+
+@attr.s(frozen=True, slots=True)
+class FederationSpaceSummaryEventResult:
+ """Represents a single event in the result of a successful get_space_summary call.
+
+ It's essentially just a serialised event object, but we do a bit of parsing and
+ validation in `from_json_dict` and store some of the validated properties in
+ object attributes.
+ """
+
+ event_type = attr.ib(type=str)
+ state_key = attr.ib(type=str)
+ via = attr.ib(type=Sequence[str])
+
+ # the raw data, including the above keys
+ data = attr.ib(type=JsonDict)
+
+ @classmethod
+ def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult":
+ """Parse an event within the result of a /spaces/ request
+
+ Args:
+ d: json object to be parsed
+
+ Raises:
+ ValueError if d is not a valid event
+ """
+
+ event_type = d.get("type")
+ if not isinstance(event_type, str):
+ raise ValueError("Invalid event: 'event_type' must be a str")
+
+ state_key = d.get("state_key")
+ if not isinstance(state_key, str):
+ raise ValueError("Invalid event: 'state_key' must be a str")
+
+ content = d.get("content")
+ if not isinstance(content, dict):
+ raise ValueError("Invalid event: 'content' must be a dict")
+
+ via = content.get("via")
+ if not isinstance(via, Sequence):
+ raise ValueError("Invalid event: 'via' must be a list")
+ if any(not isinstance(v, str) for v in via):
+ raise ValueError("Invalid event: 'via' must be a list of strings")
+
+ return cls(event_type, state_key, via, d)
+
+
+@attr.s(frozen=True, slots=True)
+class FederationSpaceSummaryResult:
+ """Represents the data returned by a successful get_space_summary call."""
+
+ rooms = attr.ib(type=Sequence[JsonDict])
+ events = attr.ib(type=Sequence[FederationSpaceSummaryEventResult])
+
+ @classmethod
+ def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult":
+ """Parse the result of a /spaces/ request
+
+ Args:
+ d: json object to be parsed
+
+ Raises:
+ ValueError if d is not a valid /spaces/ response
+ """
+ rooms = d.get("rooms")
+ if not isinstance(rooms, Sequence):
+ raise ValueError("'rooms' must be a list")
+ if any(not isinstance(r, dict) for r in rooms):
+ raise ValueError("Invalid room in 'rooms' list")
+
+ events = d.get("events")
+ if not isinstance(events, Sequence):
+ raise ValueError("'events' must be a list")
+ if any(not isinstance(e, dict) for e in events):
+ raise ValueError("Invalid event in 'events' list")
+ parsed_events = [
+ FederationSpaceSummaryEventResult.from_json_dict(e) for e in events
+ ]
+
+ return cls(rooms, parsed_events)
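As a usage note, `from_json_dict` is the only entry point: callers feed it the raw /spaces/ response and get back validated attributes, or a ValueError on malformed data. For example (the payload below is a made-up minimal response)::

    response = {
        "rooms": [{"room_id": "!space:example.org", "name": "A space"}],
        "events": [
            {
                "type": "m.space.child",
                "state_key": "!child:example.org",
                "content": {"via": ["example.org"]},
            }
        ],
    }

    summary = FederationSpaceSummaryResult.from_json_dict(response)
    assert summary.events[0].via == ["example.org"]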
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d8a2caf75f..cb48cc5722 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,12 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
+ Iterable,
List,
Optional,
Tuple,
@@ -33,7 +35,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import (
AuthError,
Codes,
@@ -43,13 +45,13 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.api import DEFAULT_ROOM_STATE_TYPES
from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
-from synapse.http.endpoint import parse_server_name
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@@ -62,10 +64,11 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet,
)
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_server_name
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -85,19 +88,19 @@ received_queries_counter = Counter(
)
pdu_process_time = Histogram(
- "synapse_federation_server_pdu_process_time", "Time taken to process an event",
+ "synapse_federation_server_pdu_process_time",
+ "Time taken to process an event",
)
-
-last_pdu_age_metric = Gauge(
- "synapse_federation_last_received_pdu_age",
- "The age (in seconds) of the last PDU successfully received from the given domain",
+last_pdu_ts_metric = Gauge(
+ "synapse_federation_last_received_pdu_time",
+ "The timestamp of the last PDU which was successfully received from the given domain",
labelnames=("server_name",),
)
class FederationServer(FederationBase):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.auth = hs.get_auth()
@@ -110,14 +113,15 @@ class FederationServer(FederationBase):
# with FederationHandlerRegistry.
hs.get_directory_handler()
- self._federation_ratelimiter = hs.get_federation_ratelimiter()
-
self._server_linearizer = Linearizer("fed_server")
- self._transaction_linearizer = Linearizer("fed_txn_handler")
+
+ # origins that we are currently processing a transaction from.
+ # a dict from origin to txn id.
+ self._active_transactions = {} # type: Dict[str, str]
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
- hs, "fed_txn_handler", timeout_ms=30000
+ hs.get_clock(), "fed_txn_handler", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self.transaction_actions = TransactionActions(self.store)
@@ -127,10 +131,10 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(
- hs, "state_resp", timeout_ms=30000
+ hs.get_clock(), "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache(
- hs, "state_ids_resp", timeout_ms=30000
+ hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._federation_metrics_domains = (
@@ -167,6 +171,33 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id)
+ # Reject malformed transactions early: reject if too many PDUs/EDUs
+ if len(transaction.pdus) > 50 or ( # type: ignore
+ hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
+ ):
+ logger.info("Transaction PDU or EDU count too large. Returning 400")
+ return 400, {}
+
+ # we only process one transaction from each origin at a time. We need to do
+ # this check here, rather than in _on_incoming_transaction_inner so that we
+ # don't cache the rejection in _transaction_resp_cache (so that if the txn
+ # arrives again later, we can process it).
+ current_transaction = self._active_transactions.get(origin)
+ if current_transaction and current_transaction != transaction_id:
+ logger.warning(
+ "Received another txn %s from %s while still processing %s",
+ transaction_id,
+ origin,
+ current_transaction,
+ )
+ return 429, {
+ "errcode": Codes.UNKNOWN,
+ "error": "Too many concurrent transactions",
+ }
+
+ # CRITICAL SECTION: we must now not await until we populate _active_transactions
+ # in _on_incoming_transaction_inner.
+
# We wrap in a ResponseCache so that we de-duplicate retried
# transactions.
return await self._transaction_resp_cache.wrap(
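The linearizer/ratelimiter pair is replaced by a simple per-origin "one transaction in flight" guard: a second, different transaction id from the same origin is rejected with a 429 rather than queued. A standalone sketch of the guard, with the real PDU/EDU handling elided::

    import asyncio
    from typing import Dict, Tuple

    _active_transactions = {}  # type: Dict[str, str]  # origin -> txn id being processed

    async def on_incoming_transaction(origin: str, txn_id: str) -> Tuple[int, dict]:
        current = _active_transactions.get(origin)
        if current and current != txn_id:
            return 429, {"errcode": "M_UNKNOWN", "error": "Too many concurrent transactions"}

        _active_transactions[origin] = txn_id
        try:
            return 200, {}  # real transaction handling elided
        finally:
            del _active_transactions[origin]

    assert asyncio.run(on_incoming_transaction("example.org", "txn1"))[0] == 200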
@@ -180,31 +211,23 @@ class FederationServer(FederationBase):
async def _on_incoming_transaction_inner(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
- # Use a linearizer to ensure that transactions from a remote are
- # processed in order.
- with await self._transaction_linearizer.queue(origin):
- # We rate limit here *after* we've queued up the incoming requests,
- # so that we don't fill up the ratelimiter with blocked requests.
- #
- # This is important as the ratelimiter allows N concurrent requests
- # at a time, and only starts ratelimiting if there are more requests
- # than that being processed at a time. If we queued up requests in
- # the linearizer/response cache *after* the ratelimiting then those
- # queued up requests would count as part of the allowed limit of N
- # concurrent requests.
- with self._federation_ratelimiter.ratelimit(origin) as d:
- await d
-
- result = await self._handle_incoming_transaction(
- origin, transaction, request_time
- )
+ # CRITICAL SECTION: the first thing we must do (before awaiting) is
+ # add an entry to _active_transactions.
+ assert origin not in self._active_transactions
+ self._active_transactions[origin] = transaction.transaction_id # type: ignore
- return result
+ try:
+ result = await self._handle_incoming_transaction(
+ origin, transaction, request_time
+ )
+ return result
+ finally:
+ del self._active_transactions[origin]
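The net effect of the checks above: a retry of the *same* transaction is served from `_transaction_resp_cache`, while a *different* concurrent transaction from the same origin is rejected with 429 before anything is cached, so it can be retried once the first one finishes. A self-contained sketch of the same pattern outside Synapse (the class, the `handle_transaction`/`_process` names and the dict-based response cache are illustrative only):

import asyncio
from typing import Dict, Tuple

class PerOriginTransactionGuard:
    """Illustrative stand-in for the bookkeeping above; not Synapse code."""

    def __init__(self) -> None:
        self._active: Dict[str, str] = {}                   # origin -> txn id in flight
        self._responses: Dict[Tuple[str, str], dict] = {}   # (origin, txn id) -> result

    async def handle_transaction(self, origin: str, txn_id: str) -> Tuple[int, dict]:
        # A retry of the same transaction returns the cached response.
        cached = self._responses.get((origin, txn_id))
        if cached is not None:
            return 200, cached

        # A *different* transaction from an origin we are already serving is
        # rejected, and crucially is not cached, so it can be retried later.
        if self._active.get(origin) not in (None, txn_id):
            return 429, {"error": "Too many concurrent transactions"}

        # Critical section: no await between the check above and this assignment.
        self._active[origin] = txn_id
        try:
            result = await self._process(origin, txn_id)
            self._responses[(origin, txn_id)] = result
            return 200, result
        finally:
            del self._active[origin]

    async def _process(self, origin: str, txn_id: str) -> dict:
        await asyncio.sleep(0)  # placeholder for the real PDU/EDU handling
        return {}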
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
- """ Process an incoming transaction and return the HTTP response
+ """Process an incoming transaction and return the HTTP response
Args:
origin: the server making the request
@@ -225,19 +248,6 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
- # Reject if PDU count > 50 or EDU count > 100
- if len(transaction.pdus) > 50 or ( # type: ignore
- hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
- ):
-
- logger.info("Transaction PDU or EDU count too large. Returning 400")
-
- response = {}
- await self.transaction_actions.set_response(
- origin, transaction, 400, response
- )
- return 400, response
-
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
@@ -333,48 +343,53 @@ class FederationServer(FederationBase):
# impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str):
- logger.debug("Processing PDUs for %s", room_id)
- try:
- await self.check_server_matches_acl(origin_host, room_id)
- except AuthError as e:
- logger.warning("Ignoring PDUs for room %s from banned server", room_id)
+ with nested_logging_context(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
+
+ try:
+ await self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warning(
+ "Ignoring PDUs for room %s from banned server", room_id
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- pdu_results[event_id] = e.error_dict()
- return
-
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- with pdu_process_time.time():
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- pdu_results[event_id] = {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- pdu_results[event_id] = {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- pdu_results[event_id] = {"error": str(e)}
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()),
- )
+ pdu_results[pdu.event_id] = await process_pdu(pdu)
+
+ async def process_pdu(pdu: EventBase) -> JsonDict:
+ event_id = pdu.event_id
+ with pdu_process_time.time():
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ return {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ return {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
+ )
+ return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
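`concurrently_execute` (with `TRANSACTION_CONCURRENCY_LIMIT`) caps how many rooms are processed at once, so a single large transaction cannot monopolise RAM/CPU. A rough asyncio illustration of that kind of bounded fan-out (Synapse itself is Twisted-based and its helper is implemented differently; the function and names below are not its API):

import asyncio
from typing import Awaitable, Callable, Iterable, TypeVar

T = TypeVar("T")

async def bounded_concurrently_execute(
    func: Callable[[T], Awaitable[None]], args: Iterable[T], limit: int
) -> None:
    """Run `func` over `args` with at most `limit` calls in flight at a time."""
    sem = asyncio.Semaphore(limit)

    async def _run(arg: T) -> None:
        async with sem:
            await func(arg)

    await asyncio.gather(*(_run(arg) for arg in args))

# Usage (placeholder names): asyncio.run(bounded_concurrently_execute(process_room, room_ids, limit=10))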
if newest_pdu_ts and origin in self._federation_metrics_domains:
- newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
- last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+ last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
return pdu_results
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
- """Process the EDUs in a received transaction.
- """
+ """Process the EDUs in a received transaction."""
async def _process_edu(edu_dict):
received_edus_counter.inc()
@@ -437,25 +452,32 @@ class FederationServer(FederationBase):
raise AuthError(403, "Host not in room.")
resp = await self._state_ids_resp_cache.wrap(
- (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
+ (room_id, event_id),
+ self._on_state_ids_request_compute,
+ room_id,
+ event_id,
)
return 200, resp
async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
- auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
+ auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
) -> Dict[str, list]:
if event_id:
- pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ pdus = await self.handler.get_state_for_pdu(
+ room_id, event_id
+ ) # type: Iterable[EventBase]
else:
pdus = (await self.state.get_current_state(room_id)).values()
- auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+ auth_chain = await self.store.get_auth_chain(
+ room_id, [pdu.event_id for pdu in pdus]
+ )
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
@@ -599,7 +621,10 @@ class FederationServer(FederationBase):
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
async def on_send_knock_request(
- self, origin: str, content: JsonDict, room_id: str,
+ self,
+ origin: str,
+ content: JsonDict,
+ room_id: str,
) -> Dict[str, List[JsonDict]]:
"""
We have received a knock event for a room. Verify and send the event into the room
@@ -632,8 +657,10 @@ class FederationServer(FederationBase):
# Retrieve stripped state events from the room and send them back to the remote
# server. This will allow the remote server's clients to display information
# related to the room while the knock request is pending.
- stripped_room_state = await self.store.get_stripped_room_state_from_event_context(
- event_context, DEFAULT_ROOM_STATE_TYPES
+ stripped_room_state = (
+ await self.store.get_stripped_room_state_from_event_context(
+ event_context, DEFAULT_ROOM_STATE_TYPES
+ )
)
return {"knock_state_events": stripped_room_state}
@@ -749,7 +776,7 @@ class FederationServer(FederationBase):
)
async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
- """ Process a PDU received in a federation /send/ transaction.
+ """Process a PDU received in a federation /send/ transaction.
If the event is invalid, then this method throws a FederationError.
(The error will then be logged and sent back to the sender (which
@@ -776,27 +803,6 @@ class FederationServer(FederationBase):
if the event was unacceptable for any other reason (eg, too large,
too many prev_events, couldn't find the prev_events)
"""
- # check that it's actually being sent from a valid destination to
- # workaround bug #1753 in 0.18.5 and 0.18.6
- if origin != get_domain_from_id(pdu.sender):
- # We continue to accept join events from any server; this is
- # necessary for the federation join dance to work correctly.
- # (When we join over federation, the "helper" server is
- # responsible for sending out the join event, rather than the
- # origin. See bug #1893. This is also true for some third party
- # invites).
- if not (
- pdu.type == "m.room.member"
- and pdu.content
- and pdu.content.get("membership", None)
- in (Membership.JOIN, Membership.INVITE)
- ):
- logger.info(
- "Discarding PDU %s from invalid origin %s", pdu.event_id, origin
- )
- return
- else:
- logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
room_version = await self.store.get_room_version(pdu.room_id)
@@ -929,10 +935,21 @@ class FederationHandlerRegistry:
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
- self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
-
- # Map from type to instance name that we should route EDU handling to.
- self._edu_type_to_instance = {} # type: Dict[str, str]
+ self.query_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
+
+ # Map from type to instance names that we should route EDU handling to.
+ # We randomly choose one instance from the list to route to for each new
+ # EDU received.
+ self._edu_type_to_instance = {} # type: Dict[str, List[str]]
+
+ # A rate limiter for incoming room key requests per origin.
+ self._room_key_request_rate_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=self.config.rc_key_requests.per_second,
+ burst_count=self.config.rc_key_requests.burst_count,
+ )
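The new per-origin rate limiter drops `m.room_key_request` EDUs from a server once it exceeds the configured `rc_key_requests` rate. For intuition, this is the usual token-bucket shape; the sketch below is illustrative and is *not* Synapse's `Ratelimiter` implementation (class name and config values are made up):

import time
from typing import Dict, Tuple

class TokenBucketLimiter:
    """Allow `burst_count` actions immediately per key, refilling at `rate_hz`/s."""

    def __init__(self, rate_hz: float, burst_count: int) -> None:
        self._rate_hz = rate_hz
        self._burst = burst_count
        self._buckets: Dict[str, Tuple[float, float]] = {}  # key -> (tokens, last seen)

    def can_do_action(self, key: str) -> bool:
        now = time.monotonic()
        tokens, last = self._buckets.get(key, (float(self._burst), now))
        tokens = min(float(self._burst), tokens + (now - last) * self._rate_hz)
        if tokens < 1.0:
            self._buckets[key] = (tokens, now)
            return False
        self._buckets[key] = (tokens - 1.0, now)
        return True

# limiter = TokenBucketLimiter(rate_hz=20, burst_count=100)
# if not limiter.can_do_action(origin): drop the incoming room key request EDU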
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
@@ -954,7 +971,7 @@ class FederationHandlerRegistry:
self.edu_handlers[edu_type] = handler
def register_query_handler(
- self, query_type: str, handler: Callable[[dict], defer.Deferred]
+ self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@@ -974,12 +991,23 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler
def register_instance_for_edu(self, edu_type: str, instance_name: str):
- """Register that the EDU handler is on a different instance than master.
- """
- self._edu_type_to_instance[edu_type] = instance_name
+ """Register that the EDU handler is on a different instance than master."""
+ self._edu_type_to_instance[edu_type] = [instance_name]
+
+ def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+ """Register that the EDU handler is on multiple instances."""
+ self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict):
- if not self.config.use_presence and edu_type == "m.presence":
+ if not self.config.use_presence and edu_type == EduTypes.Presence:
+ return
+
+ # If the incoming room key requests from a particular origin are over
+ # the limit, drop them.
+ if (
+ edu_type == EduTypes.RoomKeyRequest
+ and not self._room_key_request_rate_limiter.can_do_action(origin)
+ ):
return
# Check if we have a handler on this instance
@@ -995,8 +1023,11 @@ class FederationHandlerRegistry:
return
# Check if we can route it somewhere else that isn't us
- route_to = self._edu_type_to_instance.get(edu_type, "master")
- if route_to != self._instance_name:
+ instances = self._edu_type_to_instance.get(edu_type, ["master"])
+ if self._instance_name not in instances:
+ # Pick an instance randomly so that we don't overload one.
+ route_to = random.choice(instances)
+
try:
await self._send_edu(
instance_name=route_to,
@@ -1013,7 +1044,7 @@ class FederationHandlerRegistry:
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)
- async def on_query(self, query_type: str, args: dict):
+ async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 079e2b2fe0..ce5fc758f0 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -30,8 +30,7 @@ logger = logging.getLogger(__name__)
class TransactionActions:
- """ Defines persistence actions that relate to handling Transactions.
- """
+ """Defines persistence actions that relate to handling Transactions."""
def __init__(self, datastore):
self.store = datastore
@@ -57,8 +56,7 @@ class TransactionActions:
async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
- """Persist how we responded to a transaction.
- """
+ """Persist how we responded to a transaction."""
transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5f1bf492c1..0c18c49abb 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,25 +31,39 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
-from typing import Dict, List, Tuple, Type
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Hashable,
+ Iterable,
+ List,
+ Optional,
+ Sized,
+ Tuple,
+ Type,
+)
from sortedcontainers import SortedDict
-from twisted.internet import defer
-
from synapse.api.presence import UserPresenceState
+from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.metrics import LaterGauge
+from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure
from .units import Edu
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-class FederationRemoteSendQueue:
+class FederationRemoteSendQueue(AbstractFederationSender):
"""A drop in replacement for FederationSender"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
@@ -58,7 +72,7 @@ class FederationRemoteSendQueue:
# We may have multiple federation sender instances, so we need to track
# their positions separately.
self._sender_instances = hs.config.worker.federation_shard_config.instances
- self._sender_positions = {}
+ self._sender_positions = {} # type: Dict[str, int]
# Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState]
@@ -71,7 +85,7 @@ class FederationRemoteSendQueue:
# Stream position -> (user_id, destinations)
self.presence_destinations = (
SortedDict()
- ) # type: SortedDict[int, Tuple[str, List[str]]]
+ ) # type: SortedDict[int, Tuple[str, Iterable[str]]]
# (destination, key) -> EDU
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
@@ -94,7 +108,7 @@ class FederationRemoteSendQueue:
# we make a new function, so we need to make a new function so the inner
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
- def register(name, queue):
+ def register(name: str, queue: Sized) -> None:
LaterGauge(
"synapse_federation_send_queue_%s_size" % (queue_name,),
"",
@@ -115,13 +129,13 @@ class FederationRemoteSendQueue:
self.clock.looping_call(self._clear_queue, 30 * 1000)
- def _next_pos(self):
+ def _next_pos(self) -> int:
pos = self.pos
self.pos += 1
self.pos_time[self.clock.time_msec()] = pos
return pos
- def _clear_queue(self):
+ def _clear_queue(self) -> None:
"""Clear the queues for anything older than N minutes"""
FIVE_MINUTES_AGO = 5 * 60 * 1000
@@ -138,7 +152,7 @@ class FederationRemoteSendQueue:
self._clear_queue_before_pos(position_to_delete)
- def _clear_queue_before_pos(self, position_to_delete):
+ def _clear_queue_before_pos(self, position_to_delete: int) -> None:
"""Clear all the queues from before a given position"""
with Measure(self.clock, "send_queue._clear"):
# Delete things out of presence maps
@@ -188,13 +202,18 @@ class FederationRemoteSendQueue:
for key in keys[:i]:
del self.edus[key]
- def notify_new_events(self, max_token):
+ def notify_new_events(self, max_token: RoomStreamToken) -> None:
"""As per FederationSender"""
- # We don't need to replicate this as it gets sent down a different
- # stream.
- pass
+ # This should never get called.
+ raise NotImplementedError()
- def build_and_send_edu(self, destination, edu_type, content, key=None):
+ def build_and_send_edu(
+ self,
+ destination: str,
+ edu_type: str,
+ content: JsonDict,
+ key: Optional[Hashable] = None,
+ ) -> None:
"""As per FederationSender"""
if destination == self.server_name:
logger.info("Not sending EDU to ourselves")
@@ -218,38 +237,39 @@ class FederationRemoteSendQueue:
self.notifier.on_new_replication_data()
- def send_read_receipt(self, receipt):
+ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""As per FederationSender
Args:
- receipt (synapse.types.ReadReceipt):
+ receipt:
"""
# nothing to do here: the replication listener will handle it.
- return defer.succeed(None)
- def send_presence(self, states):
+ def send_presence(self, states: List[UserPresenceState]) -> None:
"""As per FederationSender
Args:
- states (list(UserPresenceState))
+ states
"""
pos = self._next_pos()
# We only want to send presence for our own users, so lets always just
# filter here just in case.
- local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states))
+ local_states = [s for s in states if self.is_mine_id(s.user_id)]
self.presence_map.update({state.user_id: state for state in local_states})
self.presence_changed[pos] = [state.user_id for state in local_states]
self.notifier.on_new_replication_data()
- def send_presence_to_destinations(self, states, destinations):
+ def send_presence_to_destinations(
+ self, states: Iterable[UserPresenceState], destinations: Iterable[str]
+ ) -> None:
"""As per FederationSender
Args:
- states (list[UserPresenceState])
- destinations (list[str])
+ states
+ destinations
"""
for state in states:
pos = self._next_pos()
@@ -258,15 +278,18 @@ class FederationRemoteSendQueue:
self.notifier.on_new_replication_data()
- def send_device_messages(self, destination):
+ def send_device_messages(self, destination: str) -> None:
"""As per FederationSender"""
# We don't need to replicate this as it gets sent down a different
# stream.
- def get_current_token(self):
+ def wake_destination(self, server: str) -> None:
+ pass
+
+ def get_current_token(self) -> int:
return self.pos - 1
- def federation_ack(self, instance_name, token):
+ def federation_ack(self, instance_name: str, token: int) -> None:
if self._sender_instances:
# If we have configured multiple federation sender instances we need
# to track their positions separately, and only clear the queue up
@@ -468,8 +491,7 @@ class KeyedEduRow(
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
- """Streams EDUs that don't have keys. See KeyedEduRow
- """
+ """Streams EDUs that don't have keys. See KeyedEduRow"""
TypeId = "e"
@@ -505,13 +527,16 @@ ParsedFederationStreamData = namedtuple(
)
-def process_rows_for_federation(transaction_queue, rows):
+def process_rows_for_federation(
+ transaction_queue: FederationSender,
+ rows: List[FederationStream.FederationStreamRow],
+) -> None:
"""Parse a list of rows from the federation stream and put them in the
transaction queue ready for sending to the relevant homeservers.
Args:
- transaction_queue (FederationSender)
- rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
+ transaction_queue
+ rows
"""
# The federation stream contains a bunch of different types of
@@ -519,7 +544,10 @@ def process_rows_for_federation(transaction_queue, rows):
# them into the appropriate collection and then send them off.
buff = ParsedFederationStreamData(
- presence=[], presence_destinations=[], keyed_edus={}, edus={},
+ presence=[],
+ presence_destinations=[],
+ keyed_edus={},
+ edus={},
)
# Parse the rows in the stream and add to the buffer
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..8babb1ebbe 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter
from twisted.internet import defer
-import synapse
import synapse.metrics
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
@@ -40,9 +40,12 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import ReadReceipt, RoomStreamToken
+from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure, measure_func
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
sent_pdus_destination_dist_count = Counter(
@@ -65,8 +68,91 @@ CATCH_UP_STARTUP_DELAY_SEC = 15
CATCH_UP_STARTUP_INTERVAL_SEC = 5
-class FederationSender:
- def __init__(self, hs: "synapse.server.HomeServer"):
+class AbstractFederationSender(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def notify_new_events(self, max_token: RoomStreamToken) -> None:
+ """This gets called when we have some new events we might want to
+ send out to other servers.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
+ """Send a RR to any other servers in the room
+
+ Args:
+ receipt: receipt to be sent
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def send_presence(self, states: List[UserPresenceState]) -> None:
+ """Send the new presence states to the appropriate destinations.
+
+ This actually queues up the presence states ready for sending and
+ triggers a background task to process them and send out the transactions.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def send_presence_to_destinations(
+ self, states: Iterable[UserPresenceState], destinations: Iterable[str]
+ ) -> None:
+ """Send the given presence states to the given destinations.
+
+ Args:
+ states: the presence states to send
+ destinations: the servers to send them to
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def build_and_send_edu(
+ self,
+ destination: str,
+ edu_type: str,
+ content: JsonDict,
+ key: Optional[Hashable] = None,
+ ) -> None:
+ """Construct an Edu object, and queue it for sending
+
+ Args:
+ destination: name of server to send to
+ edu_type: type of EDU to send
+ content: content of EDU
+ key: clobbering key for this edu
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def send_device_messages(self, destination: str) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def wake_destination(self, destination: str) -> None:
+ """Called when we want to retry sending transactions to a remote.
+
+ This is mainly useful if the remote server has been down and we think it
+ might have come back.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_current_token(self) -> int:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def federation_ack(self, instance_name: str, token: int) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def get_replication_rows(
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
+ raise NotImplementedError()
+
+
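Extracting `AbstractFederationSender` means both the real sender below and `FederationRemoteSendQueue` (the worker-side stand-in in send_queue.py) now implement one explicit contract, so callers can be typed against the base class. A minimal conforming stub, e.g. for tests, could look like the sketch below; it relies on the imports already present in this module and is not part of the change itself:

class NoopFederationSender(AbstractFederationSender):
    """Illustrative stub that satisfies the interface but sends nothing."""

    def notify_new_events(self, max_token: RoomStreamToken) -> None:
        pass

    async def send_read_receipt(self, receipt: ReadReceipt) -> None:
        pass

    def send_presence(self, states: List[UserPresenceState]) -> None:
        pass

    def send_presence_to_destinations(
        self, states: Iterable[UserPresenceState], destinations: Iterable[str]
    ) -> None:
        pass

    def build_and_send_edu(
        self,
        destination: str,
        edu_type: str,
        content: JsonDict,
        key: Optional[Hashable] = None,
    ) -> None:
        pass

    def send_device_messages(self, destination: str) -> None:
        pass

    def wake_destination(self, destination: str) -> None:
        pass

    def get_current_token(self) -> int:
        return 0

    def federation_ack(self, instance_name: str, token: int) -> None:
        pass

    async def get_replication_rows(
        self, instance_name: str, from_token: int, to_token: int, target_row_count: int
    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
        return [], to_token, False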
+class FederationSender(AbstractFederationSender):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.server_name = hs.hostname
@@ -142,6 +228,8 @@ class FederationSender:
self._wake_destinations_needing_catchup,
)
+ self._external_cache = hs.get_external_cache()
+
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -197,22 +285,40 @@ class FederationSender:
if not event.internal_metadata.should_proactively_send():
return
- try:
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- destinations = await self.state.get_hosts_in_room_at_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
- except Exception:
- logger.exception(
- "Failed to calculate hosts in room for event: %s",
- event.event_id,
+ destinations = None # type: Optional[Set[str]]
+ if not event.prev_event_ids():
+ # If there are no prev event IDs then the state before the event is
+ # empty, and so there are no remote servers in the room.
+ destinations = set()
+ else:
+ # We check the external cache for the destinations, which are
+ # stored per state group.
+
+ sg = await self._external_cache.get(
+ "event_to_prev_state_group", event.event_id
)
- return
+ if sg:
+ destinations = await self._external_cache.get(
+ "get_joined_hosts", str(sg)
+ )
+
+ if destinations is None:
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = await self.state.get_hosts_in_room_at_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
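The fast path above asks an external (e.g. Redis-backed) cache for the destination set before falling back to computing it from room state: the hosts are cached per *state group*, so every event that shares the same prior state reuses a single entry. A condensed sketch of just that read path, using only the `get(namespace, key)` call seen above (the `hosts_for_event` helper is illustrative):

from typing import Optional, Set

async def hosts_for_event(external_cache, event) -> Optional[Set[str]]:
    """Return the cached destinations for `event`, or None on a cache miss."""
    if not event.prev_event_ids():
        # No prev events: the state before the event is empty, so there are
        # no remote servers to send to.
        return set()

    state_group = await external_cache.get("event_to_prev_state_group", event.event_id)
    if state_group is None:
        return None  # miss: caller recomputes via state.get_hosts_in_room_at_events

    hosts = await external_cache.get("get_joined_hosts", str(state_group))
    return set(hosts) if hosts is not None else None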
destinations = {
d
@@ -308,7 +414,9 @@ class FederationSender:
# to allow us to perform catch-up later on if the remote is unreachable
# for a while.
await self.store.store_destination_rooms_entries(
- destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
+ destinations,
+ pdu.room_id,
+ pdu.internal_metadata.stream_ordering,
)
for destination in destinations:
@@ -410,7 +518,7 @@ class FederationSender:
queue.flush_read_receipts_for_room(room_id)
@preserve_fn # the caller should not yield on this
- async def send_presence(self, states: List[UserPresenceState]):
+ async def send_presence(self, states: List[UserPresenceState]) -> None:
"""Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and
@@ -452,10 +560,10 @@ class FederationSender:
self._processing_pending_presence = False
def send_presence_to_destinations(
- self, states: List[UserPresenceState], destinations: List[str]
+ self, states: Iterable[UserPresenceState], destinations: Iterable[str]
) -> None:
"""Send the given presence states to the given destinations.
- destinations (list[str])
+ destinations (list[str])
"""
if not states or not self.hs.config.use_presence:
@@ -472,7 +580,7 @@ class FederationSender:
self._get_per_destination_queue(destination).send_presence(states)
@measure_func("txnqueue._process_presence")
- async def _process_presence_inner(self, states: List[UserPresenceState]):
+ async def _process_presence_inner(self, states: List[UserPresenceState]) -> None:
"""Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination
"""
@@ -494,9 +602,9 @@ class FederationSender:
self,
destination: str,
edu_type: str,
- content: dict,
+ content: JsonDict,
key: Optional[Hashable] = None,
- ):
+ ) -> None:
"""Construct an Edu object, and queue it for sending
Args:
@@ -523,7 +631,7 @@ class FederationSender:
self.send_edu(edu, key)
- def send_edu(self, edu: Edu, key: Optional[Hashable]):
+ def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None:
"""Queue an EDU for sending
Args:
@@ -541,7 +649,7 @@ class FederationSender:
else:
queue.send_edu(edu)
- def send_device_messages(self, destination: str):
+ def send_device_messages(self, destination: str) -> None:
if destination == self.server_name:
logger.warning("Not sending device update to ourselves")
return
@@ -553,7 +661,7 @@ class FederationSender:
self._get_per_destination_queue(destination).attempt_new_transaction()
- def wake_destination(self, destination: str):
+ def wake_destination(self, destination: str) -> None:
"""Called when we want to retry sending transactions to a remote.
This is mainly useful if the remote server has been down and we think it
@@ -577,6 +685,10 @@ class FederationSender:
# to a worker.
return 0
+ def federation_ack(self, instance_name: str, token: int) -> None:
+ # It is not expected that this gets called on FederationSender.
+ raise NotImplementedError()
+
@staticmethod
async def get_replication_rows(
instance_name: str, from_token: int, to_token: int, target_row_count: int
@@ -585,7 +697,7 @@ class FederationSender:
# to a worker.
return [], 0, False
- async def _wake_destinations_needing_catchup(self):
+ async def _wake_destinations_needing_catchup(self) -> None:
"""
Wakes up destinations that need catch-up and are not currently being
backed off from.
@@ -596,8 +708,8 @@ class FederationSender:
last_processed = None # type: Optional[str]
while True:
- destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
- last_processed
+ destinations_to_wake = (
+ await self.store.get_catch_up_outstanding_destinations(last_processed)
)
if not destinations_to_wake:
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index db8e456fe8..89df9a619b 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,8 +15,9 @@
# limitations under the License.
import datetime
import logging
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
+import attr
from prometheus_client import Counter
from synapse.api.errors import (
@@ -76,6 +77,7 @@ class PerDestinationQueue:
self._transaction_manager = transaction_manager
self._instance_name = hs.get_instance_name()
self._federation_shard_config = hs.config.worker.federation_shard_config
+ self._state = hs.get_state_handler()
self._should_send_on_this_instance = True
if not self._federation_shard_config.should_handle(
@@ -85,13 +87,18 @@ class PerDestinationQueue:
# processing. We have a guard in `attempt_new_transaction` that
# ensure we don't start sending stuff.
logger.error(
- "Create a per destination queue for %s on wrong worker", destination,
+ "Create a per destination queue for %s on wrong worker",
+ destination,
)
self._should_send_on_this_instance = False
self._destination = destination
self.transmission_loop_running = False
+ # Flag to signal to any running transmission loop that there is new data
+ # queued up to be sent.
+ self._new_data_to_send = False
+
# True whilst we are sending events that the remote homeserver missed
# because it was unreachable. We start in this state so we can perform
# catch-up at startup.
@@ -107,7 +114,7 @@ class PerDestinationQueue:
# destination (we are the only updater so this is safe)
self._last_successful_stream_ordering = None # type: Optional[int]
- # a list of pending PDUs
+ # a queue of pending PDUs
self._pending_pdus = [] # type: List[EventBase]
# XXX this is never actually used: see
@@ -207,6 +214,10 @@ class PerDestinationQueue:
transaction in the background.
"""
+ # Mark that we (may) have new things to send, so that any running
+ # transmission loop will recheck whether there is stuff to send.
+ self._new_data_to_send = True
+
if self.transmission_loop_running:
# XXX: this can get stuck on by a never-ending
# request at which point pending_pdus just keeps growing.
@@ -249,125 +260,41 @@ class PerDestinationQueue:
pending_pdus = []
while True:
- # We have to keep 2 free slots for presence and rr_edus
- limit = MAX_EDUS_PER_TRANSACTION - 2
-
- device_update_edus, dev_list_id = await self._get_device_update_edus(
- limit
- )
-
- limit -= len(device_update_edus)
-
- (
- to_device_edus,
- device_stream_id,
- ) = await self._get_to_device_message_edus(limit)
+ self._new_data_to_send = False
- pending_edus = device_update_edus + to_device_edus
-
- # BEGIN CRITICAL SECTION
- #
- # In order to avoid a race condition, we need to make sure that
- # the following code (from popping the queues up to the point
- # where we decide if we actually have any pending messages) is
- # atomic - otherwise new PDUs or EDUs might arrive in the
- # meantime, but not get sent because we hold the
- # transmission_loop_running flag.
-
- pending_pdus = self._pending_pdus
-
- # We can only include at most 50 PDUs per transactions
- pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
+ async with _TransactionQueueManager(self) as (
+ pending_pdus,
+ pending_edus,
+ ):
+ if not pending_pdus and not pending_edus:
+ logger.debug("TX [%s] Nothing to send", self._destination)
+
+ # If we were told about new things to send while we were
+ # checking for things to send, look again. Otherwise new
+ # PDUs or EDUs that arrived in the meantime would never get
+ # sent, because we are still holding the
+ # `transmission_loop_running` flag.
+ if self._new_data_to_send:
+ continue
+ else:
+ return
- pending_edus.extend(self._get_rr_edus(force_flush=False))
- pending_presence = self._pending_presence
- self._pending_presence = {}
- if pending_presence:
- pending_edus.append(
- Edu(
- origin=self._server_name,
- destination=self._destination,
- edu_type="m.presence",
- content={
- "push": [
- format_user_presence_state(
- presence, self._clock.time_msec()
- )
- for presence in pending_presence.values()
- ]
- },
+ if pending_pdus:
+ logger.debug(
+ "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ self._destination,
+ len(pending_pdus),
)
- )
- pending_edus.extend(
- self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
- )
- while (
- len(pending_edus) < MAX_EDUS_PER_TRANSACTION
- and self._pending_edus_keyed
- ):
- _, val = self._pending_edus_keyed.popitem()
- pending_edus.append(val)
-
- if pending_pdus:
- logger.debug(
- "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- self._destination,
- len(pending_pdus),
+ await self._transaction_manager.send_new_transaction(
+ self._destination, pending_pdus, pending_edus
)
- if not pending_pdus and not pending_edus:
- logger.debug("TX [%s] Nothing to send", self._destination)
- self._last_device_stream_id = device_stream_id
- return
-
- # if we've decided to send a transaction anyway, and we have room, we
- # may as well send any pending RRs
- if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
- pending_edus.extend(self._get_rr_edus(force_flush=True))
-
- # END CRITICAL SECTION
-
- success = await self._transaction_manager.send_new_transaction(
- self._destination, pending_pdus, pending_edus
- )
- if success:
sent_transactions_counter.inc()
sent_edus_counter.inc(len(pending_edus))
for edu in pending_edus:
sent_edus_by_type.labels(edu.edu_type).inc()
- # Remove the acknowledged device messages from the database
- # Only bother if we actually sent some device messages
- if to_device_edus:
- await self._store.delete_device_msgs_for_remote(
- self._destination, device_stream_id
- )
- # also mark the device updates as sent
- if device_update_edus:
- logger.info(
- "Marking as sent %r %r", self._destination, dev_list_id
- )
- await self._store.mark_as_sent_devices_by_remote(
- self._destination, dev_list_id
- )
-
- self._last_device_stream_id = device_stream_id
- self._last_device_list_stream_id = dev_list_id
-
- if pending_pdus:
- # we sent some PDUs and it was successful, so update our
- # last_successful_stream_ordering in the destinations table.
- final_pdu = pending_pdus[-1]
- last_successful_stream_ordering = (
- final_pdu.internal_metadata.stream_ordering
- )
- assert last_successful_stream_ordering
- await self._store.set_destination_last_successful_stream_ordering(
- self._destination, last_successful_stream_ordering
- )
- else:
- break
except NotRetryingDestination as e:
logger.debug(
"TX [%s] not ready for retry yet (next retry at %s) - "
@@ -400,7 +327,7 @@ class PerDestinationQueue:
self._pending_presence = {}
self._pending_rrs = {}
- self._start_catching_up()
+ self._start_catching_up()
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
@@ -411,7 +338,6 @@ class PerDestinationQueue:
e,
)
- self._start_catching_up()
except RequestSendFailed as e:
logger.warning(
"TX [%s] Failed to send transaction: %s", self._destination, e
@@ -421,16 +347,12 @@ class PerDestinationQueue:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
-
- self._start_catching_up()
except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination)
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
-
- self._start_catching_up()
finally:
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
@@ -440,8 +362,10 @@ class PerDestinationQueue:
if first_catch_up_check:
# first catchup so get last_successful_stream_ordering from database
- self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
- self._destination
+ self._last_successful_stream_ordering = (
+ await self._store.get_destination_last_successful_stream_ordering(
+ self._destination
+ )
)
if self._last_successful_stream_ordering is None:
@@ -457,7 +381,8 @@ class PerDestinationQueue:
# get at most 50 catchup room/PDUs
while True:
event_ids = await self._store.get_catch_up_room_event_ids(
- self._destination, self._last_successful_stream_ordering,
+ self._destination,
+ self._last_successful_stream_ordering,
)
if not event_ids:
@@ -491,25 +416,97 @@ class PerDestinationQueue:
"This should not happen." % event_ids
)
- if logger.isEnabledFor(logging.INFO):
- rooms = [p.room_id for p in catchup_pdus]
- logger.info("Catching up rooms to %s: %r", self._destination, rooms)
+ # We send transactions with events from one room only, as it's likely
+ # that the remote will have to do additional processing, which may
+ # take some time. It's better to give it small amounts of work
+ # rather than risk the request timing out and repeatedly being
+ # retried, and not making any progress.
+ #
+ # Note: `catchup_pdus` will have exactly one PDU per room.
+ for pdu in catchup_pdus:
+ # The PDU from the DB will be the last PDU in the room from
+ # *this server* that wasn't sent to the remote. However, other
+ # servers may have sent lots of events since then, and we want
+ # to try and tell the remote only about the *latest* events in
+ # the room. This is so that it doesn't get inundated by events
+ # from various parts of the DAG, which all need to be processed.
+ #
+ # Note: this does mean that in large rooms a server coming back
+ # online will get sent the same events from all the different
+ # servers, but the remote will correctly deduplicate them and
+ # handle them only once.
+
+ # Step 1, fetch the current extremities
+ extrems = await self._store.get_prev_events_for_room(pdu.room_id)
+
+ if pdu.event_id in extrems:
+ # If the event is in the extremities, then great! We can just
+ # use that without having to do further checks.
+ room_catchup_pdus = [pdu]
+ else:
+ # If not, fetch the extremity events and figure out which of
+ # them we can send.
+ extrem_events = await self._store.get_events_as_list(extrems)
+
+ new_pdus = []
+ for p in extrem_events:
+ # We pulled this from the DB, so it'll be non-null
+ assert p.internal_metadata.stream_ordering
+
+ # Filter out events that happened before the remote went
+ # offline
+ if (
+ p.internal_metadata.stream_ordering
+ < self._last_successful_stream_ordering
+ ):
+ continue
+
+ # Filter out events where the server is not in the room,
+ # e.g. it may have left/been kicked. *Ideally* we'd pull
+ # out the kick and send that, but it's a rare edge case
+ # so we don't bother for now (the server that sent the
+ # kick should send it out if it's online).
+ hosts = await self._state.get_hosts_in_room_at_events(
+ p.room_id, [p.event_id]
+ )
+ if self._destination not in hosts:
+ continue
- success = await self._transaction_manager.send_new_transaction(
- self._destination, catchup_pdus, []
- )
+ new_pdus.append(p)
- if not success:
- return
+ # If we've filtered out all the extremities, fall back to
+ # sending the original event. This should ensure that the
+ # server gets at least some of the missed events (especially if
+ # the other sending servers are up).
+ if new_pdus:
+ room_catchup_pdus = new_pdus
+ else:
+ room_catchup_pdus = [pdu]
- sent_transactions_counter.inc()
- final_pdu = catchup_pdus[-1]
- self._last_successful_stream_ordering = cast(
- int, final_pdu.internal_metadata.stream_ordering
- )
- await self._store.set_destination_last_successful_stream_ordering(
- self._destination, self._last_successful_stream_ordering
- )
+ logger.info(
+ "Catching up rooms to %s: %r", self._destination, pdu.room_id
+ )
+
+ await self._transaction_manager.send_new_transaction(
+ self._destination, room_catchup_pdus, []
+ )
+
+ sent_transactions_counter.inc()
+
+ # We pulled this from the DB, so it'll be non-null
+ assert pdu.internal_metadata.stream_ordering
+
+ # Note that we mark the last successful stream ordering as that
+ # from the *original* PDU, rather than the PDU(s) we actually
+ # send. This is because we use it to mark our position in the
+ # queue of missed PDUs to process.
+ self._last_successful_stream_ordering = (
+ pdu.internal_metadata.stream_ordering
+ )
+
+ await self._store.set_destination_last_successful_stream_ordering(
+ self._destination, self._last_successful_stream_ordering
+ )
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
if not self._pending_rrs:
@@ -580,3 +577,135 @@ class PerDestinationQueue:
"""
self._catching_up = True
self._pending_pdus = []
+
+
+@attr.s(slots=True)
+class _TransactionQueueManager:
+ """A helper async context manager for pulling stuff off the queues and
+ tracking what was last successfully sent, etc.
+ """
+
+ queue = attr.ib(type=PerDestinationQueue)
+
+ _device_stream_id = attr.ib(type=Optional[int], default=None)
+ _device_list_id = attr.ib(type=Optional[int], default=None)
+ _last_stream_ordering = attr.ib(type=Optional[int], default=None)
+ _pdus = attr.ib(type=List[EventBase], factory=list)
+
+ async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
+ # First we calculate the EDUs we want to send, if any.
+
+ # We start by fetching device-related EDUs, i.e. device updates and
+ # to-device messages. We have to keep 2 free slots for presence and rr_edus.
+ limit = MAX_EDUS_PER_TRANSACTION - 2
+
+ device_update_edus, dev_list_id = await self.queue._get_device_update_edus(
+ limit
+ )
+
+ if device_update_edus:
+ self._device_list_id = dev_list_id
+ else:
+ self.queue._last_device_list_stream_id = dev_list_id
+
+ limit -= len(device_update_edus)
+
+ (
+ to_device_edus,
+ device_stream_id,
+ ) = await self.queue._get_to_device_message_edus(limit)
+
+ if to_device_edus:
+ self._device_stream_id = device_stream_id
+ else:
+ self.queue._last_device_stream_id = device_stream_id
+
+ pending_edus = device_update_edus + to_device_edus
+
+ # Now add the read receipt EDU.
+ pending_edus.extend(self.queue._get_rr_edus(force_flush=False))
+
+ # And presence EDU.
+ if self.queue._pending_presence:
+ pending_edus.append(
+ Edu(
+ origin=self.queue._server_name,
+ destination=self.queue._destination,
+ edu_type="m.presence",
+ content={
+ "push": [
+ format_user_presence_state(
+ presence, self.queue._clock.time_msec()
+ )
+ for presence in self.queue._pending_presence.values()
+ ]
+ },
+ )
+ )
+ self.queue._pending_presence = {}
+
+ # Finally add any other types of EDUs if there is room.
+ pending_edus.extend(
+ self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
+ )
+ while (
+ len(pending_edus) < MAX_EDUS_PER_TRANSACTION
+ and self.queue._pending_edus_keyed
+ ):
+ _, val = self.queue._pending_edus_keyed.popitem()
+ pending_edus.append(val)
+
+ # Now we look for any PDUs to send, by getting up to 50 PDUs from the
+ # queue
+ self._pdus = self.queue._pending_pdus[:50]
+
+ if not self._pdus and not pending_edus:
+ return [], []
+
+ # If we've decided to send a transaction anyway, and we have room, we
+ # may as well send any pending RRs.
+ if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
+ pending_edus.extend(self.queue._get_rr_edus(force_flush=True))
+
+ if self._pdus:
+ self._last_stream_ordering = self._pdus[
+ -1
+ ].internal_metadata.stream_ordering
+ assert self._last_stream_ordering
+
+ return self._pdus, pending_edus
+
+ async def __aexit__(self, exc_type, exc, tb):
+ if exc_type is not None:
+ # Failed to send transaction, so we bail out.
+ return
+
+ # The transaction was sent successfully, so remove the sent PDUs from the queue.
+ if self._pdus:
+ self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :]
+
+ # The transaction succeeded, so record how far we have now sent in the
+ # various streams.
+
+ if self._device_stream_id:
+ await self.queue._store.delete_device_msgs_for_remote(
+ self.queue._destination, self._device_stream_id
+ )
+ self.queue._last_device_stream_id = self._device_stream_id
+
+ # also mark the device updates as sent
+ if self._device_list_id:
+ logger.info(
+ "Marking as sent %r %r", self.queue._destination, self._device_list_id
+ )
+ await self.queue._store.mark_as_sent_devices_by_remote(
+ self.queue._destination, self._device_list_id
+ )
+ self.queue._last_device_list_stream_id = self._device_list_id
+
+ if self._last_stream_ordering:
+ # we sent some PDUs and it was successful, so update our
+ # last_successful_stream_ordering in the destinations table.
+ await self.queue._store.set_destination_last_successful_stream_ordering(
+ self.queue._destination, self._last_stream_ordering
+ )
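`_TransactionQueueManager.__aexit__` only acknowledges progress (deleting sent device messages, advancing the device-list and stream-ordering positions) when the body exits cleanly; if `send_new_transaction` raises, nothing is popped and the data stays queued for the next attempt. Together with the `_new_data_to_send` flag this replaces the old in-line critical section. A stripped-down sketch of the same commit-on-success loop (all names below are illustrative, not Synapse APIs):

from typing import Awaitable, Callable, List

class CommitOnSuccess:
    """Take up to 50 queued items on entry; drop them from the queue only if
    the body completed without raising."""

    def __init__(self, queue: List[str]) -> None:
        self._queue = queue
        self._taken: List[str] = []

    async def __aenter__(self) -> List[str]:
        self._taken = self._queue[:50]
        return self._taken

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if exc_type is not None:
            return  # send failed: leave everything queued for the next attempt
        del self._queue[: len(self._taken)]

class DestinationQueue:
    def __init__(self, send: Callable[[List[str]], Awaitable[None]]) -> None:
        self._send = send
        self._pending: List[str] = []
        self._new_data = False

    def enqueue(self, item: str) -> None:
        self._pending.append(item)
        self._new_data = True  # wake a loop that was about to give up

    async def transmission_loop(self) -> None:
        while True:
            self._new_data = False
            async with CommitOnSuccess(self._pending) as batch:
                if not batch:
                    if self._new_data:
                        continue  # something arrived while we were looking
                    return
                await self._send(batch)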
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3e07f925e0..07b740c2f2 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -36,9 +36,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-last_pdu_age_metric = Gauge(
- "synapse_federation_last_sent_pdu_age",
- "The age (in seconds) of the last PDU successfully sent to the given domain",
+last_pdu_ts_metric = Gauge(
+ "synapse_federation_last_sent_pdu_time",
+ "The timestamp of the last PDU which was successfully sent to the given domain",
labelnames=("server_name",),
)
@@ -65,16 +65,16 @@ class TransactionManager:
@measure_func("_send_new_transaction")
async def send_new_transaction(
- self, destination: str, pdus: List[EventBase], edus: List[Edu],
- ) -> bool:
+ self,
+ destination: str,
+ pdus: List[EventBase],
+ edus: List[Edu],
+ ) -> None:
"""
Args:
destination: The destination to send to (e.g. 'example.org')
pdus: In-order list of PDUs to send
edus: List of EDUs to send
-
- Returns:
- True iff the transaction was successful
"""
# Make a transaction-sending opentracing span. This span follows on from
@@ -93,8 +93,6 @@ class TransactionManager:
edu.strip_context()
with start_active_span_follows_from("send_transaction", span_contexts):
- success = True
-
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
@@ -149,45 +147,29 @@ class TransactionManager:
response = await self._transport_layer.send_transaction(
transaction, json_data_cb
)
- code = 200
except HttpResponseException as e:
code = e.code
response = e.response
- if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info(
- "TX [%s] {%s} got %d response", destination, txn_id, code
- )
- raise e
-
- logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
-
- if code == 200:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warning(
- "TX [%s] {%s} Remote returned error for %s: %s",
- destination,
- txn_id,
- e_id,
- r,
- )
- else:
- for p in pdus:
+ set_tag(tags.ERROR, True)
+
+ logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+ raise
+
+ logger.info("TX [%s] {%s} got 200 response", destination, txn_id)
+
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
logger.warning(
- "TX [%s] {%s} Failed to send event %s",
+ "TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
- p.event_id,
+ e_id,
+ r,
)
- success = False
- if success and pdus and destination in self._federation_metrics_domains:
+ if pdus and destination in self._federation_metrics_domains:
last_pdu = pdus[-1]
- last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
- last_pdu_age_metric.labels(server_name=destination).set(
- last_pdu_age / 1000
+ last_pdu_ts_metric.labels(server_name=destination).set(
+ last_pdu.origin_server_ts / 1000
)
-
- set_tag(tags.ERROR, not success)
- return success
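Both federation PDU gauges now export the PDU's own timestamp instead of a pre-computed age: an age is only accurate at the instant the gauge is set, whereas a timestamp lets the monitoring side derive a continuously increasing age at query time (e.g. `time() - synapse_federation_last_sent_pdu_time` in PromQL). A minimal stand-alone sketch of the exporting side (the `record_sent_pdu` helper is illustrative):

from prometheus_client import Gauge

last_pdu_ts_metric = Gauge(
    "synapse_federation_last_sent_pdu_time",
    "The timestamp of the last PDU which was successfully sent to the given domain",
    labelnames=("server_name",),
)

def record_sent_pdu(destination: str, origin_server_ts_ms: int) -> None:
    # Export the event's timestamp in seconds; the age is computed at query
    # time by the monitoring system rather than baked into the sample.
    last_pdu_ts_metric.labels(server_name=destination).set(origin_server_ts_ms / 1000)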
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9c454e5885..df7e9fbbc2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -42,7 +42,7 @@ class TransportLayerClient:
@log_function
def get_room_state_ids(self, destination, room_id, event_id):
- """ Requests all state for a given room from the given server at the
+ """Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
Args:
@@ -66,7 +66,7 @@ class TransportLayerClient:
@log_function
def get_event(self, destination, event_id, timeout=None):
- """ Requests the pdu with give id and origin from the given server.
+ """Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote homeserver we want
@@ -87,7 +87,7 @@ class TransportLayerClient:
@log_function
def backfill(self, destination, room_id, event_tuples, limit):
- """ Requests `limit` previous PDUs in a given context before list of
+ """Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
@@ -121,7 +121,7 @@ class TransportLayerClient:
@log_function
async def send_transaction(self, transaction, json_data_callback=None):
- """ Sends the given Transaction to its destination
+ """Sends the given Transaction to its destination
Args:
transaction (Transaction)
@@ -309,7 +309,11 @@ class TransportLayerClient:
@log_function
async def send_knock_v2(
- self, destination: str, room_id: str, event_id: str, content: JsonDict,
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ content: JsonDict,
) -> JsonDict:
"""
Sends a signed knock membership event to a remote server. This is the second
@@ -600,8 +604,7 @@ class TransportLayerClient:
@log_function
def get_group_profile(self, destination, group_id, requester_user_id):
- """Get a group profile
- """
+ """Get a group profile"""
path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.get_json(
@@ -633,8 +636,7 @@ class TransportLayerClient:
@log_function
def get_group_summary(self, destination, group_id, requester_user_id):
- """Get a group summary
- """
+ """Get a group summary"""
path = _create_v1_path("/groups/%s/summary", group_id)
return self.client.get_json(
@@ -646,8 +648,7 @@ class TransportLayerClient:
@log_function
def get_rooms_in_group(self, destination, group_id, requester_user_id):
- """Get all rooms in a group
- """
+ """Get all rooms in a group"""
path = _create_v1_path("/groups/%s/rooms", group_id)
return self.client.get_json(
@@ -660,8 +661,7 @@ class TransportLayerClient:
def add_room_to_group(
self, destination, group_id, requester_user_id, room_id, content
):
- """Add a room to a group
- """
+ """Add a room to a group"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.post_json(
@@ -675,8 +675,7 @@ class TransportLayerClient:
def update_room_in_group(
self, destination, group_id, requester_user_id, room_id, config_key, content
):
- """Update room in group
- """
+ """Update room in group"""
path = _create_v1_path(
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key
)
@@ -690,8 +689,7 @@ class TransportLayerClient:
)
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
- """Remove a room from a group
- """
+ """Remove a room from a group"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.delete_json(
@@ -703,8 +701,7 @@ class TransportLayerClient:
@log_function
def get_users_in_group(self, destination, group_id, requester_user_id):
- """Get users in a group
- """
+ """Get users in a group"""
path = _create_v1_path("/groups/%s/users", group_id)
return self.client.get_json(
@@ -716,8 +713,7 @@ class TransportLayerClient:
@log_function
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
- """Get users that have been invited to a group
- """
+ """Get users that have been invited to a group"""
path = _create_v1_path("/groups/%s/invited_users", group_id)
return self.client.get_json(
@@ -729,8 +725,7 @@ class TransportLayerClient:
@log_function
def accept_group_invite(self, destination, group_id, user_id, content):
- """Accept a group invite
- """
+ """Accept a group invite"""
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
return self.client.post_json(
@@ -739,8 +734,7 @@ class TransportLayerClient:
@log_function
def join_group(self, destination, group_id, user_id, content):
- """Attempts to join a group
- """
+ """Attempts to join a group"""
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
@@ -751,8 +745,7 @@ class TransportLayerClient:
def invite_to_group(
self, destination, group_id, user_id, requester_user_id, content
):
- """Invite a user to a group
- """
+ """Invite a user to a group"""
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
@@ -779,8 +772,7 @@ class TransportLayerClient:
def remove_user_from_group(
self, destination, group_id, requester_user_id, user_id, content
):
- """Remove a user from a group
- """
+ """Remove a user from a group"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
@@ -821,8 +813,7 @@ class TransportLayerClient:
def update_group_summary_room(
self, destination, group_id, user_id, room_id, category_id, content
):
- """Update a room entry in a group summary
- """
+ """Update a room entry in a group summary"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
@@ -845,8 +836,7 @@ class TransportLayerClient:
def delete_group_summary_room(
self, destination, group_id, user_id, room_id, category_id
):
- """Delete a room entry in a group summary
- """
+ """Delete a room entry in a group summary"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
@@ -866,8 +856,7 @@ class TransportLayerClient:
@log_function
def get_group_categories(self, destination, group_id, requester_user_id):
- """Get all categories in a group
- """
+ """Get all categories in a group"""
path = _create_v1_path("/groups/%s/categories", group_id)
return self.client.get_json(
@@ -879,8 +868,7 @@ class TransportLayerClient:
@log_function
def get_group_category(self, destination, group_id, requester_user_id, category_id):
- """Get category info in a group
- """
+ """Get category info in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.get_json(
@@ -894,8 +882,7 @@ class TransportLayerClient:
def update_group_category(
self, destination, group_id, requester_user_id, category_id, content
):
- """Update a category in a group
- """
+ """Update a category in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.post_json(
@@ -910,8 +897,7 @@ class TransportLayerClient:
def delete_group_category(
self, destination, group_id, requester_user_id, category_id
):
- """Delete a category in a group
- """
+ """Delete a category in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.delete_json(
@@ -923,8 +909,7 @@ class TransportLayerClient:
@log_function
def get_group_roles(self, destination, group_id, requester_user_id):
- """Get all roles in a group
- """
+ """Get all roles in a group"""
path = _create_v1_path("/groups/%s/roles", group_id)
return self.client.get_json(
@@ -936,8 +921,7 @@ class TransportLayerClient:
@log_function
def get_group_role(self, destination, group_id, requester_user_id, role_id):
- """Get a roles info
- """
+ """Get a roles info"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.get_json(
@@ -951,8 +935,7 @@ class TransportLayerClient:
def update_group_role(
self, destination, group_id, requester_user_id, role_id, content
):
- """Update a role in a group
- """
+ """Update a role in a group"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.post_json(
@@ -965,8 +948,7 @@ class TransportLayerClient:
@log_function
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
- """Delete a role in a group
- """
+ """Delete a role in a group"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.delete_json(
@@ -980,8 +962,7 @@ class TransportLayerClient:
def update_group_summary_user(
self, destination, group_id, requester_user_id, user_id, role_id, content
):
- """Update a users entry in a group
- """
+ """Update a users entry in a group"""
if role_id:
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
@@ -999,8 +980,7 @@ class TransportLayerClient:
@log_function
def set_group_join_policy(self, destination, group_id, requester_user_id, content):
- """Sets the join policy for a group
- """
+ """Sets the join policy for a group"""
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
return self.client.put_json(
@@ -1015,8 +995,7 @@ class TransportLayerClient:
def delete_group_summary_user(
self, destination, group_id, requester_user_id, user_id, role_id
):
- """Delete a users entry in a group
- """
+ """Delete a users entry in a group"""
if role_id:
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
@@ -1032,8 +1011,7 @@ class TransportLayerClient:
)
def bulk_get_publicised_groups(self, destination, user_ids):
- """Get the groups a list of users are publicising
- """
+ """Get the groups a list of users are publicising"""
path = _create_v1_path("/get_groups_publicised")
@@ -1053,6 +1031,38 @@ class TransportLayerClient:
return self.client.get_json(destination=destination, path=path)
+ async def get_space_summary(
+ self,
+ destination: str,
+ room_id: str,
+ suggested_only: bool,
+ max_rooms_per_space: Optional[int],
+ exclude_rooms: List[str],
+ ) -> JsonDict:
+ """
+ Args:
+ destination: The remote server
+ room_id: The room ID to ask about.
+ suggested_only: if True, only suggested rooms will be returned
+ max_rooms_per_space: an optional limit to the number of children to be
+ returned per space
+ exclude_rooms: a list of any rooms we can skip
+ """
+ path = _create_path(
+ FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/spaces/%s", room_id
+ )
+
+ params = {
+ "suggested_only": suggested_only,
+ "exclude_rooms": exclude_rooms,
+ }
+ if max_rooms_per_space is not None:
+ params["max_rooms_per_space"] = max_rooms_per_space
+
+ return await self.client.post_json(
+ destination=destination, path=path, data=params
+ )
+
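As a quick orientation (editorial sketch, not part of the patch), a caller inside the federation client would drive the new method roughly as follows; the `transport` variable, destination and room ID are invented for illustration:

    # Illustrative only: `transport` is assumed to be the TransportLayerClient
    # instance owned by the federation client.
    summary = await transport.get_space_summary(
        destination="remote.example.org",
        room_id="!space:remote.example.org",
        suggested_only=True,
        max_rooms_per_space=10,
        exclude_rooms=[],
    )
    # `summary` is the raw JsonDict returned by the remote server.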
def get_info_of_users(self, destination: str, user_ids: List[str]):
"""
Args:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 961d16ac5a..294031e2a0 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,9 +18,10 @@
import functools
import logging
import re
-from typing import Optional, Tuple, Type
+from typing import Container, Mapping, Optional, Sequence, Tuple, Type
import synapse
+from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import (
@@ -28,8 +29,7 @@ from synapse.api.urls import (
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
-from synapse.http.endpoint import parse_and_validate_server_name
-from synapse.http.server import JsonResource
+from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
assert_params_in_dict,
parse_boolean_from_args,
@@ -46,7 +46,9 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.server import HomeServer
-from synapse.types import ThirdPartyInstanceID, get_domain_from_id
+from synapse.types import JsonDict, ThirdPartyInstanceID, get_domain_from_id
+from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.stringutils import parse_and_validate_server_name
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
@@ -146,7 +148,7 @@ class Authenticator:
):
raise FederationDeniedError(origin)
- if not json_request["signatures"]:
+ if origin is None or not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
@@ -366,7 +368,10 @@ class BaseFederationServlet:
continue
server.register_paths(
- method, (pattern,), self._wrap(code), self.__class__.__name__,
+ method,
+ (pattern,),
+ self._wrap(code),
+ self.__class__.__name__,
)
@@ -383,7 +388,7 @@ class FederationSendServlet(BaseFederationServlet):
# This is when someone is trying to send us a bunch of data.
async def on_PUT(self, origin, content, query, transaction_id):
- """ Called on PUT /send/<transaction_id>/
+ """Called on PUT /send/<transaction_id>/
Args:
request (twisted.web.http.Request): The HTTP request.
@@ -482,10 +487,9 @@ class FederationQueryServlet(BaseFederationServlet):
# This is when we receive a server-server Query
async def on_GET(self, origin, content, query, query_type):
- return await self.handler.on_query_request(
- query_type,
- {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
- )
+ args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
+ args["origin"] = origin
+ return await self.handler.on_query_request(query_type, args)
class FederationMakeJoinServlet(BaseFederationServlet):
@@ -936,8 +940,7 @@ class FederationVersionServlet(BaseFederationServlet):
class FederationGroupsProfileServlet(BaseFederationServlet):
- """Get/set the basic profile of a group on behalf of a user
- """
+ """Get/set the basic profile of a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/profile"
@@ -976,8 +979,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
class FederationGroupsRoomsServlet(BaseFederationServlet):
- """Get the rooms in a group on behalf of a user
- """
+ """Get the rooms in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@@ -992,8 +994,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
- """Add/remove room from group
- """
+ """Add/remove room from group"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@@ -1021,8 +1022,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
- """Update room config in group
- """
+ """Update room config in group"""
PATH = (
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@@ -1042,8 +1042,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
class FederationGroupsUsersServlet(BaseFederationServlet):
- """Get the users in a group on behalf of a user
- """
+ """Get the users in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/users"
@@ -1058,8 +1057,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
- """Get the users that have been invited to a group
- """
+ """Get the users that have been invited to a group"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@@ -1076,8 +1074,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
class FederationGroupsInviteServlet(BaseFederationServlet):
- """Ask a group server to invite someone to the group
- """
+ """Ask a group server to invite someone to the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@@ -1094,8 +1091,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
- """Accept an invitation from the group server
- """
+ """Accept an invitation from the group server"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@@ -1109,8 +1105,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
class FederationGroupsJoinServlet(BaseFederationServlet):
- """Attempt to join a group
- """
+ """Attempt to join a group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@@ -1124,8 +1119,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
- """Leave or kick a user from the group
- """
+ """Leave or kick a user from the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@@ -1142,8 +1136,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
- """A group server has invited a local user
- """
+ """A group server has invited a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@@ -1157,8 +1150,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
- """A group server has removed a local user
- """
+ """A group server has removed a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@@ -1174,8 +1166,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
- """A group or user's server renews their attestation
- """
+ """A group or user's server renews their attestation"""
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
@@ -1209,7 +1200,17 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
if category_id == "":
- raise SynapseError(400, "category_id cannot be empty string")
+ raise SynapseError(
+ 400, "category_id cannot be empty string", Codes.INVALID_PARAM
+ )
+
+ if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+ raise SynapseError(
+ 400,
+ "category_id may not be longer than %s characters"
+ % (MAX_GROUP_CATEGORYID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
resp = await self.handler.update_group_summary_room(
group_id,
@@ -1237,8 +1238,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
class FederationGroupsCategoriesServlet(BaseFederationServlet):
- """Get all categories for a group
- """
+ """Get all categories for a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@@ -1253,8 +1253,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
class FederationGroupsCategoryServlet(BaseFederationServlet):
- """Add/remove/get a category in a group
- """
+ """Add/remove/get a category in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@@ -1277,6 +1276,14 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
+ if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+ raise SynapseError(
+ 400,
+ "category_id may not be longer than %s characters"
+ % (MAX_GROUP_CATEGORYID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
resp = await self.handler.upsert_group_category(
group_id, requester_user_id, category_id, content
)
@@ -1299,8 +1306,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
class FederationGroupsRolesServlet(BaseFederationServlet):
- """Get roles in a group
- """
+ """Get roles in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@@ -1315,8 +1321,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
class FederationGroupsRoleServlet(BaseFederationServlet):
- """Add/remove/get a role in a group
- """
+ """Add/remove/get a role in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@@ -1335,7 +1340,17 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
if role_id == "":
- raise SynapseError(400, "role_id cannot be empty string")
+ raise SynapseError(
+ 400, "role_id cannot be empty string", Codes.INVALID_PARAM
+ )
+
+ if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+ raise SynapseError(
+ 400,
+ "role_id may not be longer than %s characters"
+ % (MAX_GROUP_ROLEID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
resp = await self.handler.update_group_role(
group_id, requester_user_id, role_id, content
@@ -1380,6 +1395,14 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
+ if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+ raise SynapseError(
+ 400,
+ "role_id may not be longer than %s characters"
+ % (MAX_GROUP_ROLEID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
resp = await self.handler.update_group_summary_user(
group_id,
requester_user_id,
@@ -1406,8 +1429,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
- """Get roles in a group
- """
+ """Get roles in a group"""
PATH = "/get_groups_publicised"
@@ -1420,8 +1442,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
- """Sets whether a group is joinable without an invite or knock
- """
+ """Sets whether a group is joinable without an invite or knock"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
@@ -1437,6 +1458,40 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
return 200, new_content
+class FederationSpaceSummaryServlet(BaseFederationServlet):
+ PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+ PATH = "/spaces/(?P<room_id>[^/]*)"
+
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Mapping[bytes, Sequence[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
+ suggested_only = content.get("suggested_only", False)
+ if not isinstance(suggested_only, bool):
+ raise SynapseError(
+ 400, "'suggested_only' must be a boolean", Codes.BAD_JSON
+ )
+
+ exclude_rooms = content.get("exclude_rooms", [])
+ if not isinstance(exclude_rooms, list) or any(
+ not isinstance(x, str) for x in exclude_rooms
+ ):
+ raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON)
+
+ max_rooms_per_space = content.get("max_rooms_per_space")
+ if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
+ raise SynapseError(
+ 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON
+ )
+
+ return 200, await self.handler.federation_space_summary(
+ room_id, suggested_only, max_rooms_per_space, exclude_rooms
+ )
+
+
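To make the validation above concrete, a well-formed body for the new endpoint could look like the following Python dict (values invented for illustration); anything else trips the 400 / M_BAD_JSON paths shown above:

    # Hypothetical body for
    # POST /_matrix/federation/unstable/org.matrix.msc2946/spaces/{roomId}
    content = {
        "suggested_only": True,  # must be a bool
        "max_rooms_per_space": 10,  # optional, must be an int when present
        "exclude_rooms": ["!seen:example.org"],  # must be a list of strings
    }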
class RoomComplexityServlet(BaseFederationServlet):
"""
Indicates to other servers how complex (and therefore likely
@@ -1538,18 +1593,24 @@ DEFAULT_SERVLET_GROUPS = (
)
-def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None):
+def register_servlets(
+ hs: HomeServer,
+ resource: HttpServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ servlet_groups: Optional[Container[str]] = None,
+):
"""Initialize and register servlet classes.
Will by default register all servlets. For custom behaviour, pass in
a list of servlet_groups to register.
Args:
- hs (synapse.server.HomeServer): homeserver
- resource (TransportLayerServer): resource class to register to
- authenticator (Authenticator): authenticator to use
- ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
- servlet_groups (list[str], optional): List of servlet groups to register.
+ hs: homeserver
+ resource: resource class to register to
+ authenticator: authenticator to use
+ ratelimiter: ratelimiter to use
+ servlet_groups: List of servlet groups to register.
Defaults to ``DEFAULT_SERVLET_GROUPS``.
"""
if not servlet_groups:
@@ -1564,6 +1625,14 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
server_name=hs.hostname,
).register(resource)
+ if hs.config.experimental.spaces_enabled:
+ FederationSpaceSummaryServlet(
+ handler=hs.get_space_summary_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
if "openid" in servlet_groups:
for servletclass in OPENID_SERVLET_CLASSES:
servletclass(
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 64d98fc8f6..b662c42621 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True)
class Edu(JsonEncodedObject):
- """ An Edu represents a piece of data sent from one homeserver to another.
+ """An Edu represents a piece of data sent from one homeserver to another.
In comparison to Pdus, Edus are not persisted for a long time on disk, are
not meaningful beyond a given pair of homeservers, and don't have an
@@ -63,7 +63,7 @@ class Edu(JsonEncodedObject):
class Transaction(JsonEncodedObject):
- """ A transaction is a list of Pdus and Edus to be sent to a remote home
+ """A transaction is a list of Pdus and Edus to be sent to a remote home
server with some extra metadata.
Example transaction::
@@ -99,7 +99,7 @@ class Transaction(JsonEncodedObject):
]
def __init__(self, transaction_id=None, pdus=[], **kwargs):
- """ If we include a list of pdus then we decode then as PDU's
+ """If we include a list of pdus then we decode then as PDU's
automatically.
"""
@@ -111,7 +111,7 @@ class Transaction(JsonEncodedObject):
@staticmethod
def create_new(pdus, **kwargs):
- """ Used to create a new transaction. Will auto fill out
+ """Used to create a new transaction. Will auto fill out
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 41cf07cc88..368c44708d 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -37,13 +37,16 @@ An attestation is a signed blob of json that looks like:
import logging
import random
-from typing import Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from signedjson.sign import sign_json
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -61,18 +64,21 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning:
- """Creates and verifies group attestations.
- """
+ """Creates and verifies group attestations."""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.signing_key = hs.signing_key
async def verify_attestation(
- self, attestation, group_id, user_id, server_name=None
- ):
+ self,
+ attestation: JsonDict,
+ group_id: str,
+ user_id: str,
+ server_name: Optional[str] = None,
+ ) -> None:
"""Verifies that the given attestation matches the given parameters.
An optional server_name can be supplied to explicitly set which server's
@@ -101,16 +107,18 @@ class GroupAttestationSigning:
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")
+ assert server_name is not None
await self.keyring.verify_json_for_server(
server_name, attestation, now, "Group attestation"
)
- def create_attestation(self, group_id, user_id):
+ def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
"""Create an attestation for the group_id and user_id with default
validity length.
"""
- validity_period = DEFAULT_ATTESTATION_LENGTH_MS
- validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
+ validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
+ *DEFAULT_ATTESTATION_JITTER
+ )
valid_until_ms = int(self.clock.time_msec() + validity_period)
return sign_json(
@@ -125,10 +133,9 @@ class GroupAttestationSigning:
class GroupAttestionRenewer:
- """Responsible for sending and receiving attestation updates.
- """
+ """Responsible for sending and receiving attestation updates."""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.assestations = hs.get_groups_attestation_signing()
@@ -141,9 +148,10 @@ class GroupAttestionRenewer:
self._start_renew_attestations, 30 * 60 * 1000
)
- async def on_renew_attestation(self, group_id, user_id, content):
- """When a remote updates an attestation
- """
+ async def on_renew_attestation(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
+ """When a remote updates an attestation"""
attestation = content["attestation"]
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
@@ -157,12 +165,11 @@ class GroupAttestionRenewer:
return {}
- def _start_renew_attestations(self):
+ def _start_renew_attestations(self) -> None:
return run_as_background_process("renew_attestations", self._renew_attestations)
- async def _renew_attestations(self):
- """Called periodically to check if we need to update any of our attestations
- """
+ async def _renew_attestations(self) -> None:
+ """Called periodically to check if we need to update any of our attestations"""
now = self.clock.time_msec()
@@ -170,7 +177,7 @@ class GroupAttestionRenewer:
now + UPDATE_ATTESTATION_TIME_MS
)
- async def _renew_attestation(group_user: Tuple[str, str]):
+ async def _renew_attestation(group_user: Tuple[str, str]) -> None:
group_id, user_id = group_user
try:
if not self.is_mine_id(group_id):
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 0d042cbfac..4b16a4ac29 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -16,11 +16,17 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, SynapseError
-from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -32,8 +38,13 @@ logger = logging.getLogger(__name__)
# TODO: Flairs
+# Note that the maximum lengths are somewhat arbitrary.
+MAX_SHORT_DESC_LEN = 1000
+MAX_LONG_DESC_LEN = 10000
+
+
class GroupsServerWorkerHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
@@ -48,16 +59,21 @@ class GroupsServerWorkerHandler:
self.profile_handler = hs.get_profile_handler()
async def check_group_is_ours(
- self, group_id, requester_user_id, and_exists=False, and_is_admin=None
- ):
+ self,
+ group_id: str,
+ requester_user_id: str,
+ and_exists: bool = False,
+ and_is_admin: Optional[str] = None,
+ ) -> Optional[dict]:
"""Check that the group is ours, and optionally if it exists.
If group does exist then return group.
Args:
- group_id (str)
- and_exists (bool): whether to also check if group exists
- and_is_admin (str): whether to also check if given str is a user_id
+ group_id: The group ID to check.
+ requester_user_id: The user ID of the requester.
+ and_exists: whether to also check if group exists
+ and_is_admin: whether to also check if given str is a user_id
that is an admin
"""
if not self.is_mine_id(group_id):
@@ -80,7 +96,9 @@ class GroupsServerWorkerHandler:
return group
- async def get_group_summary(self, group_id, requester_user_id):
+ async def get_group_summary(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get the summary for a group as seen by requester_user_id.
The group summary consists of the profile of the room, and a curated
@@ -113,6 +131,8 @@ class GroupsServerWorkerHandler:
entry = await self.room_list_handler.generate_room_entry(
room_id, len(joined_users), with_alias=False, allow_private=True
)
+ if entry is None:
+ continue
entry = dict(entry) # so we don't change what's cached
entry.pop("room_id", None)
@@ -120,22 +140,22 @@ class GroupsServerWorkerHandler:
rooms.sort(key=lambda e: e.get("order", 0))
- for entry in users:
- user_id = entry["user_id"]
+ for user in users:
+ user_id = user["user_id"]
if not self.is_mine_id(requester_user_id):
attestation = await self.store.get_remote_attestation(group_id, user_id)
if not attestation:
continue
- entry["attestation"] = attestation
+ user["attestation"] = attestation
else:
- entry["attestation"] = self.attestations.create_attestation(
+ user["attestation"] = self.attestations.create_attestation(
group_id, user_id
)
user_profile = await self.profile_handler.get_profile_from_cache(user_id)
- entry.update(user_profile)
+ user.update(user_profile)
users.sort(key=lambda e: e.get("order", 0))
@@ -158,46 +178,44 @@ class GroupsServerWorkerHandler:
"user": membership_info,
}
- async def get_group_categories(self, group_id, requester_user_id):
- """Get all categories in a group (as seen by user)
- """
+ async def get_group_categories(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
+ """Get all categories in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
categories = await self.store.get_group_categories(group_id=group_id)
return {"categories": categories}
- async def get_group_category(self, group_id, requester_user_id, category_id):
- """Get a specific category in a group (as seen by user)
- """
+ async def get_group_category(
+ self, group_id: str, requester_user_id: str, category_id: str
+ ) -> JsonDict:
+ """Get a specific category in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- res = await self.store.get_group_category(
+ return await self.store.get_group_category(
group_id=group_id, category_id=category_id
)
- logger.info("group %s", res)
-
- return res
-
- async def get_group_roles(self, group_id, requester_user_id):
- """Get all roles in a group (as seen by user)
- """
+ async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict:
+ """Get all roles in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
roles = await self.store.get_group_roles(group_id=group_id)
return {"roles": roles}
- async def get_group_role(self, group_id, requester_user_id, role_id):
- """Get a specific role in a group (as seen by user)
- """
+ async def get_group_role(
+ self, group_id: str, requester_user_id: str, role_id: str
+ ) -> JsonDict:
+ """Get a specific role in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
- return res
+ return await self.store.get_group_role(group_id=group_id, role_id=role_id)
- async def get_group_profile(self, group_id, requester_user_id):
- """Get the group profile as seen by requester_user_id
- """
+ async def get_group_profile(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
+ """Get the group profile as seen by requester_user_id"""
await self.check_group_is_ours(group_id, requester_user_id)
@@ -218,7 +236,9 @@ class GroupsServerWorkerHandler:
else:
raise SynapseError(404, "Unknown group")
- async def get_users_in_group(self, group_id, requester_user_id):
+ async def get_users_in_group(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get the users in group as seen by requester_user_id.
The ordering is arbitrary at the moment
@@ -267,7 +287,9 @@ class GroupsServerWorkerHandler:
return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
- async def get_invited_users_in_group(self, group_id, requester_user_id):
+ async def get_invited_users_in_group(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get the users that have been invited to a group as seen by requester_user_id.
The ordering is arbitrary at the moment
@@ -297,7 +319,9 @@ class GroupsServerWorkerHandler:
return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
- async def get_rooms_in_group(self, group_id, requester_user_id):
+ async def get_rooms_in_group(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get the rooms in group as seen by requester_user_id
This returns rooms in order of decreasing number of joined users
@@ -335,17 +359,21 @@ class GroupsServerWorkerHandler:
class GroupsServerHandler(GroupsServerWorkerHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
# Ensure attestations get renewed
hs.get_groups_attestation_renewer()
async def update_group_summary_room(
- self, group_id, requester_user_id, room_id, category_id, content
- ):
- """Add/update a room to the group summary
- """
+ self,
+ group_id: str,
+ requester_user_id: str,
+ room_id: str,
+ category_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
+ """Add/update a room to the group summary"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -367,10 +395,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def delete_group_summary_room(
- self, group_id, requester_user_id, room_id, category_id
- ):
- """Remove a room from the summary
- """
+ self, group_id: str, requester_user_id: str, room_id: str, category_id: str
+ ) -> JsonDict:
+ """Remove a room from the summary"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -381,7 +408,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- async def set_group_join_policy(self, group_id, requester_user_id, content):
+ async def set_group_join_policy(
+ self, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sets the group join policy.
Currently supported policies are:
@@ -401,10 +430,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def update_group_category(
- self, group_id, requester_user_id, category_id, content
- ):
- """Add/Update a group category
- """
+ self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict
+ ) -> JsonDict:
+ """Add/Update a group category"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -421,9 +449,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- async def delete_group_category(self, group_id, requester_user_id, category_id):
- """Delete a group category
- """
+ async def delete_group_category(
+ self, group_id: str, requester_user_id: str, category_id: str
+ ) -> JsonDict:
+ """Delete a group category"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -434,9 +463,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- async def update_group_role(self, group_id, requester_user_id, role_id, content):
- """Add/update a role in a group
- """
+ async def update_group_role(
+ self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict
+ ) -> JsonDict:
+ """Add/update a role in a group"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -451,9 +481,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- async def delete_group_role(self, group_id, requester_user_id, role_id):
- """Remove role from group
- """
+ async def delete_group_role(
+ self, group_id: str, requester_user_id: str, role_id: str
+ ) -> JsonDict:
+ """Remove role from group"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -463,10 +494,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def update_group_summary_user(
- self, group_id, requester_user_id, user_id, role_id, content
- ):
- """Add/update a users entry in the group summary
- """
+ self,
+ group_id: str,
+ requester_user_id: str,
+ user_id: str,
+ role_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
+ """Add/update a users entry in the group summary"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -486,10 +521,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def delete_group_summary_user(
- self, group_id, requester_user_id, user_id, role_id
- ):
- """Remove a user from the group summary
- """
+ self, group_id: str, requester_user_id: str, user_id: str, role_id: str
+ ) -> JsonDict:
+ """Remove a user from the group summary"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -500,26 +534,43 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- async def update_group_profile(self, group_id, requester_user_id, content):
- """Update the group profile
- """
+ async def update_group_profile(
+ self, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> None:
+ """Update the group profile"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
profile = {}
- for keyname in ("name", "avatar_url", "short_description", "long_description"):
+ for keyname, max_length in (
+ ("name", MAX_DISPLAYNAME_LEN),
+ ("avatar_url", MAX_AVATAR_URL_LEN),
+ ("short_description", MAX_SHORT_DESC_LEN),
+ ("long_description", MAX_LONG_DESC_LEN),
+ ):
if keyname in content:
value = content[keyname]
if not isinstance(value, str):
- raise SynapseError(400, "%r value is not a string" % (keyname,))
+ raise SynapseError(
+ 400,
+ "%r value is not a string" % (keyname,),
+ errcode=Codes.INVALID_PARAM,
+ )
+ if len(value) > max_length:
+ raise SynapseError(
+ 400,
+ "Invalid %s parameter" % (keyname,),
+ errcode=Codes.INVALID_PARAM,
+ )
profile[keyname] = value
await self.store.update_group_profile(group_id, profile)
- async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
- """Add room to group
- """
+ async def add_room_to_group(
+ self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict
+ ) -> JsonDict:
+ """Add room to group"""
RoomID.from_string(room_id) # Ensure valid room id
await self.check_group_is_ours(
@@ -533,10 +584,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def update_room_in_group(
- self, group_id, requester_user_id, room_id, config_key, content
- ):
- """Update room in group
- """
+ self,
+ group_id: str,
+ requester_user_id: str,
+ room_id: str,
+ config_key: str,
+ content: JsonDict,
+ ) -> JsonDict:
+ """Update room in group"""
RoomID.from_string(room_id) # Ensure valid room id
await self.check_group_is_ours(
@@ -554,9 +609,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- async def remove_room_from_group(self, group_id, requester_user_id, room_id):
- """Remove room from group
- """
+ async def remove_room_from_group(
+ self, group_id: str, requester_user_id: str, room_id: str
+ ) -> JsonDict:
+ """Remove room from group"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -565,13 +621,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- async def invite_to_group(self, group_id, user_id, requester_user_id, content):
- """Invite user to group
- """
+ async def invite_to_group(
+ self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
+ """Invite user to group"""
group = await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
+ if not group:
+ raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE)
# TODO: Check if user knocked
@@ -594,6 +653,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
+ assert isinstance(
+ groups_local, GroupsLocalHandler
+ ), "Workers cannot invites users to groups."
res = await groups_local.on_invite(group_id, user_id, content)
local_attestation = None
else:
@@ -629,6 +691,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
+ return {"state": "join"}
elif res["state"] == "invite":
await self.store.add_group_invite(group_id, user_id)
return {"state": "invite"}
@@ -637,13 +700,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
else:
raise SynapseError(502, "Unknown state returned by HS")
- async def _add_user(self, group_id, user_id, content):
+ async def _add_user(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> Optional[JsonDict]:
"""Add a user to a group based on a content dict.
See accept_invite, join_group.
"""
if not self.hs.is_mine_id(user_id):
- local_attestation = self.attestations.create_attestation(group_id, user_id)
+ local_attestation = self.attestations.create_attestation(
+ group_id, user_id
+ ) # type: Optional[JsonDict]
remote_attestation = content["attestation"]
@@ -667,7 +734,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return local_attestation
- async def accept_invite(self, group_id, requester_user_id, content):
+ async def accept_invite(
+ self, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""User tries to accept an invite to the group.
This is different from them asking to join, and so should error if no
@@ -686,7 +755,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"state": "join", "attestation": local_attestation}
- async def join_group(self, group_id, requester_user_id, content):
+ async def join_group(
+ self, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""User tries to join the group.
This will error if the group requires an invite/knock to join
@@ -695,6 +766,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
group_info = await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True
)
+ if not group_info:
+ raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND)
if group_info["join_policy"] != "open":
raise SynapseError(403, "Group is not publicly joinable")
@@ -702,26 +775,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"state": "join", "attestation": local_attestation}
- async def knock(self, group_id, requester_user_id, content):
- """A user requests becoming a member of the group
- """
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- raise NotImplementedError()
-
- async def accept_knock(self, group_id, requester_user_id, content):
- """Accept a users knock to the room.
-
- Errors if the user hasn't knocked, rather than inviting them.
- """
-
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- raise NotImplementedError()
-
async def remove_user_from_group(
- self, group_id, user_id, requester_user_id, content
- ):
+ self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Remove a user from the group; either a user is leaving or an admin
kicked them.
"""
@@ -743,6 +799,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if is_kick:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
+ assert isinstance(
+ groups_local, GroupsLocalHandler
+ ), "Workers cannot remove users from groups."
await groups_local.user_removed_from_group(group_id, user_id, {})
else:
await self.transport_client.remove_user_from_group_notification(
@@ -759,14 +818,15 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- async def create_group(self, group_id, requester_user_id, content):
- group = await self.check_group_is_ours(group_id, requester_user_id)
-
+ async def create_group(
+ self, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
logger.info("Attempting to create group with ID: %r", group_id)
# parsing the id into a GroupID validates it.
group_id_obj = GroupID.from_string(group_id)
+ group = await self.check_group_is_ours(group_id, requester_user_id)
if group:
raise SynapseError(400, "Group already exists")
@@ -811,7 +871,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation = self.attestations.create_attestation(
group_id, requester_user_id
- )
+ ) # type: Optional[JsonDict]
else:
local_attestation = None
remote_attestation = None
@@ -834,15 +894,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"group_id": group_id}
- async def delete_group(self, group_id, requester_user_id):
+ async def delete_group(self, group_id: str, requester_user_id: str) -> None:
"""Deletes a group, kicking out all current members.
Only group admins or server admins can call this request
Args:
- group_id (str)
- request_user_id (str)
-
+ group_id: The group ID to delete.
+ requester_user_id: The user requesting to delete the group.
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
@@ -865,6 +924,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def _kick_user_from_group(user_id):
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
+ assert isinstance(
+ groups_local, GroupsLocalHandler
+ ), "Workers cannot kick users from groups."
await groups_local.user_removed_from_group(group_id, user_id, {})
else:
await self.transport_client.remove_user_from_group_notification(
@@ -896,9 +958,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
await self.store.delete_group(group_id)
-def _parse_join_policy_from_contents(content):
- """Given a content for a request, return the specified join policy or None
- """
+def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]:
+ """Given a content for a request, return the specified join policy or None"""
join_policy_dict = content.get("m.join_policy")
if join_policy_dict:
@@ -907,9 +968,8 @@ def _parse_join_policy_from_contents(content):
return None
-def _parse_join_policy_dict(join_policy_dict):
- """Given a dict for the "m.join_policy" config return the join policy specified
- """
+def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str:
+ """Given a dict for the "m.join_policy" config return the join policy specified"""
join_policy_type = join_policy_dict.get("type")
if not join_policy_type:
return "invite"
@@ -919,7 +979,7 @@ def _parse_join_policy_dict(join_policy_dict):
return join_policy_type
-def _parse_visibility_from_contents(content):
+def _parse_visibility_from_contents(content: JsonDict) -> bool:
"""Given a content for a request parse out whether the entity should be
public or not
"""
@@ -933,7 +993,7 @@ def _parse_visibility_from_contents(content):
return is_public
-def _parse_visibility_dict(visibility):
+def _parse_visibility_dict(visibility: JsonDict) -> bool:
"""Given a dict for the "m.visibility" config return if the entity should
be public or not
"""
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e81d..aade2c4a3a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -24,7 +24,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
class BaseHandler:
"""
Common base class for the event handlers.
+
+ Deprecated: new code should not use this. Instead, Handler classes should define the
+ fields they actually need. The utility methods should either be factored out to
+ standalone helper functions, or to different Handler classes.
"""
def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 341135822e..1ce6d697ed 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +13,155 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import random
from typing import TYPE_CHECKING, List, Tuple
+from synapse.replication.http.account_data import (
+ ReplicationAddTagRestServlet,
+ ReplicationRemoveTagRestServlet,
+ ReplicationRoomAccountDataRestServlet,
+ ReplicationUserAccountDataRestServlet,
+)
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
+
+
+class AccountDataHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastore()
+ self._instance_name = hs.get_instance_name()
+ self._notifier = hs.get_notifier()
+
+ self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+ self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+ self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+ self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+ self._account_data_writers = hs.config.worker.writers.account_data
+
+ async def add_account_data_to_room(
+ self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
+ """Add some account_data to a room for a user.
+
+ Args:
+ user_id: The user to add account_data for.
+ room_id: The room to add account_data to.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the account_data.
+
+ Returns:
+ The maximum stream ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_account_data_to_room(
+ user_id, room_id, account_data_type, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+
+ return max_stream_id
+ else:
+ response = await self._room_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ account_data_type=account_data_type,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def add_account_data_for_user(
+ self, user_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
+ """Add some account_data to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The maximum stream ID.
+ """
+
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_account_data_for_user(
+ user_id, account_data_type, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._user_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ account_data_type=account_data_type,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def add_tag_to_room(
+ self, user_id: str, room_id: str, tag: str, content: JsonDict
+ ) -> int:
+ """Add a tag to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ room_id: The room to add a tag for.
+ tag: The tag name to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The next account data ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_tag_to_room(
+ user_id, room_id, tag, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._add_tag_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ tag=tag,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+ """Remove a tag from a room for a user.
+
+ Returns:
+ The next account data ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.remove_tag_from_room(
+ user_id, room_id, tag
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._remove_tag_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ tag=tag,
+ )
+ return response["max_stream_id"]
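A rough usage sketch for the new handler (not part of the patch; it assumes the HomeServer exposes a get_account_data_handler() accessor, which this changeset would need to provide elsewhere):

    # Illustrative only.
    account_data_handler = hs.get_account_data_handler()  # assumed accessor
    max_stream_id = await account_data_handler.add_tag_to_room(
        user_id="@alice:example.org",
        room_id="!room:example.org",
        tag="m.favourite",
        content={"order": 0.5},
    )
    # On a worker that is not an account_data writer, the call is proxied via
    # the replication clients set up in __init__ to a randomly chosen writer.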
class AccountDataEventSource:
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index ce97fa70d7..125902ac17 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -30,7 +30,7 @@ from synapse.types import UserID
from synapse.util import stringutils
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..2a25af6288 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
import twisted
import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
from synapse.app import check_bind_error
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
class AcmeHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
- async def start_listening(self):
+ async def start_listening(self) -> None:
from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
@@ -69,7 +73,9 @@ class AcmeHandler:
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
- self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
+ self.reactor.listenTCP(
+ self.hs.config.acme_port, srv, backlog=50, interface=host
+ )
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
@@ -85,7 +91,7 @@ class AcmeHandler:
logger.error(ACME_REGISTER_FAIL_ERROR)
raise
- async def provision_certificate(self):
+ async def provision_certificate(self) -> None:
logger.warning("Reprovisioning %s", self._acme_domain)
@@ -110,5 +116,3 @@ class AcmeHandler:
except Exception:
logger.exception("Failed saving!")
raise
-
- return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
imported conditionally.
"""
import logging
+from typing import Dict, Iterable, List
import attr
+import pem
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
from zope.interface import implementer
from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
from twisted.python.filepath import FilePath
from twisted.python.url import URL
+from twisted.web.resource import IResource
logger = logging.getLogger(__name__)
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+ reactor: IReactorTCP,
+ acme_url: str,
+ account_key_file: str,
+ well_known_resource: IResource,
+) -> AcmeIssuingService:
"""Create an ACME issuing service, and attach it to a web Resource
Args:
reactor: twisted reactor
- acme_url (str): URL to use to request certificates
- account_key_file (str): where to store the account key
- well_known_resource (twisted.web.IResource): web resource for .well-known.
+ acme_url: URL to use to request certificates
+ account_key_file: where to store the account key
+ well_known_resource: web resource for .well-known.
we will attach a child resource for "acme-challenge".
Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
A store that only stores in memory.
"""
- certs = attr.ib(default=attr.Factory(dict))
+ certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
- def store(self, server_name, pem_objects):
+ def store(
+ self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+ ) -> defer.Deferred:
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
"""Load the ACME account key from a file, creating it if it does not exist.
Args:
- key_file (str): name of the file to use as the account key
+ key_file: name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a703944543..c494de49a3 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,27 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from synapse.api.constants import Membership
-from synapse.events import FrozenEvent
-from synapse.types import RoomStreamToken, StateMap
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.storage = hs.get_storage()
self.state_store = self.storage.state
- async def get_whois(self, user):
+ async def get_whois(self, user: UserID) -> JsonDict:
connections = []
sessions = await self.store.get_user_ip_and_agents(user)
@@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
return ret
- async def get_user(self, user):
+ async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
ret = await self.store.get_user_by_id(user.to_string())
if ret:
@@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
ret["threepids"] = threepids
return ret
- async def export_user_data(self, user_id, writer):
+ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
"""Write all data we have on the user to the given writer.
Args:
- user_id (str)
- writer (ExfiltrationWriter)
+ user_id: The user ID to fetch data of.
+ writer: The writer to write to.
Returns:
Resolves when all data for a user has been written.
@@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
- written_events = set() # Events that we've processed in this room
+ # Events that we've processed in this room
+ written_events = set() # type: Set[str]
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
@@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
- # events "children". dict[str, set[str]]
- unseen_to_child_events = {}
+ # events "children".
+ unseen_to_child_events = {} # type: Dict[str, Set[str]]
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
@@ -197,38 +202,44 @@ class AdminHandler(BaseHandler):
return writer.finished()
-class ExfiltrationWriter:
- """Interface used to specify how to write exported data.
- """
+class ExfiltrationWriter(metaclass=abc.ABCMeta):
+ """Interface used to specify how to write exported data."""
- def write_events(self, room_id: str, events: List[FrozenEvent]):
- """Write a batch of events for a room.
- """
- pass
+ @abc.abstractmethod
+ def write_events(self, room_id: str, events: List[EventBase]) -> None:
+ """Write a batch of events for a room."""
+ raise NotImplementedError()
- def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
+ @abc.abstractmethod
+ def write_state(
+ self, room_id: str, event_id: str, state: StateMap[EventBase]
+ ) -> None:
"""Write the state at the given event in the room.
This only gets called for backward extremities rather than for each
event.
"""
- pass
+ raise NotImplementedError()
- def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
+ @abc.abstractmethod
+ def write_invite(
+ self, room_id: str, event: EventBase, state: StateMap[dict]
+ ) -> None:
"""Write an invite for the room, with associated invite state.
Args:
- room_id
- event
- state: A subset of the state at the
- invite, with a subset of the event keys (type, state_key
- content and sender)
+ room_id: The room ID the invite is for.
+ event: The invite event.
+ state: A subset of the state at the invite, with a subset of the
+ event keys (type, state_key, content and sender).
"""
+ raise NotImplementedError()
- def finished(self):
+ @abc.abstractmethod
+ def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
This functions return value is passed to the caller of
`export_user_data`.
"""
- pass
+ raise NotImplementedError()
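Since `ExfiltrationWriter` is now an abstract base class, callers of `export_user_data` must supply a concrete subclass overriding all four methods. A minimal sketch of such a writer follows; the JSON-lines file layout is purely illustrative:

```python
import json
import os
from typing import Any, List

from synapse.events import EventBase
from synapse.types import StateMap


class FileExfiltrationWriter(ExfiltrationWriter):
    """Writes exported data as JSON lines under a base directory."""

    def __init__(self, base_directory: str):
        self.base_directory = base_directory

    def _append(self, name: str, data: Any) -> None:
        # Append one JSON document per line to the named file.
        with open(os.path.join(self.base_directory, name), "a") as f:
            f.write(json.dumps(data) + "\n")

    def write_events(self, room_id: str, events: List[EventBase]) -> None:
        for event in events:
            self._append("events_%s" % (room_id,), event.get_pdu_json())

    def write_state(
        self, room_id: str, event_id: str, state: StateMap[EventBase]
    ) -> None:
        # StateMap keys are (event type, state key) tuples; flatten for JSON.
        self._append(
            "state_%s_%s" % (room_id, event_id),
            {"%s/%s" % k: e.get_pdu_json() for k, e in state.items()},
        )

    def write_invite(
        self, room_id: str, event: EventBase, state: StateMap[dict]
    ) -> None:
        self._append("invites_%s" % (room_id,), event.get_pdu_json())

    def finished(self) -> Any:
        # Returned to the caller of export_user_data.
        return self.base_directory
```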
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5c6458eb52..996f9e5deb 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -38,7 +38,7 @@ from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, User
from synapse.util.metrics import Measure
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -290,7 +290,9 @@ class ApplicationServicesHandler:
if not interested:
continue
presence_events, _ = await presence_source.get_new_events(
- user=user, service=service, from_key=from_key,
+ user=user,
+ service=service,
+ from_key=from_key,
)
time_now = self.clock.time_msec()
events.extend(
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 143d12b48d..d537ea8137 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import time
import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Dict,
Iterable,
@@ -36,6 +36,8 @@ import attr
import bcrypt
import pymacaroons
+from twisted.web.server import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -47,22 +49,28 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+ INTERACTIVE_AUTH_CHECKERS,
+ UIAuthSessionDataConstants,
+)
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
+from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
-from ._base import BaseHandler
-
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -113,7 +121,9 @@ def convert_client_dict_legacy_fields_to_identifier(
# Ensure the identifier has a type
if "type" not in identifier:
raise SynapseError(
- 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
+ 400,
+ "'identifier' dict has no key 'type'",
+ errcode=Codes.MISSING_PARAM,
)
return identifier
@@ -161,6 +171,16 @@ class SsoLoginExtraAttributes:
extra_attributes = attr.ib(type=JsonDict)
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+ """Data we store in a short-term login token"""
+
+ user_id = attr.ib(type=str)
+
+ # the SSO Identity Provider that the user authenticated with, to get this token
+ auth_provider_id = attr.ib(type=str)
+
+
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -193,39 +213,27 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = (
- hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
- )
-
- # we keep this as a list despite the O(N^2) implication so that we can
- # keep PASSWORD first and avoid confusing clients which pick the first
- # type in the list. (NB that the spec doesn't require us to do so and
- # clients which favour types that they don't understand over those that
- # they do are technically broken)
+ self._password_localdb_enabled = hs.config.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
- login_types = []
- if hs.config.password_localdb_enabled:
- login_types.append(LoginType.PASSWORD)
+ login_types = set()
+ if self._password_localdb_enabled:
+ login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
- if hasattr(provider, "get_supported_login_types"):
- for t in provider.get_supported_login_types().keys():
- if t not in login_types:
- login_types.append(t)
+ login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
+ login_types.discard(LoginType.PASSWORD)
+
+ # Some clients just pick the first type in the list. In this case, we want
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
+ # that comes first, where it's present.
+ self._supported_login_types = []
+ if LoginType.PASSWORD in login_types:
+ self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
-
- self._supported_login_types = login_types
-
- # Login types and UI Auth types have a heavy overlap, but are not
- # necessarily identical. Login types have SSO (and other login types)
- # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
- ui_auth_types = login_types.copy()
- if self._sso_enabled:
- ui_auth_types.append(LoginType.SSO)
- self._supported_ui_auth_types = ui_auth_types
+ self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@@ -235,6 +243,9 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # The number of seconds to keep a UI auth session active.
+ self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -266,10 +277,6 @@ class AuthHandler(BaseHandler):
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
- # The following template is shown after a successful user interactive
- # authentication session. It tells the user they can close the window.
- self._sso_auth_success_template = hs.config.sso_auth_success_template
-
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
@@ -290,9 +297,8 @@ class AuthHandler(BaseHandler):
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
- clientip: str,
description: str,
- ) -> Tuple[dict, str]:
+ ) -> Tuple[dict, Optional[str]]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -307,8 +313,6 @@ class AuthHandler(BaseHandler):
request_body: The body of the request sent by the client
- clientip: The IP address of the client.
-
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
@@ -319,7 +323,8 @@ class AuthHandler(BaseHandler):
have been given only in a previous call).
'session_id' is the ID of this session, either passed in by the
- client or assigned by this call
+ client or assigned by this call. This is None if UI auth was
+ skipped (by re-using a previous validation).
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
@@ -332,41 +337,95 @@ class AuthHandler(BaseHandler):
user is too high to proceed
"""
+ if not requester.access_token_id:
+ raise ValueError("Cannot validate a user without an access token")
+ if self._ui_auth_session_timeout:
+ last_validated = await self.store.get_access_token_last_validated(
+ requester.access_token_id
+ )
+ if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout:
+ # Return the input parameters, minus the auth key, which matches
+ # the logic in check_ui_auth.
+ request_body.pop("auth", None)
+ return request_body, None
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
- self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+ self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_ui_auth_types]
+ supported_ui_auth_types = await self._get_available_ui_auth_types(
+ requester.user
+ )
+ flows = [[login_type] for login_type in supported_ui_auth_types]
+
+ def get_new_session_data() -> JsonDict:
+ return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
try:
result, params, session_id = await self.check_ui_auth(
- flows, request, request_body, clientip, description
+ flows,
+ request,
+ request_body,
+ description,
+ get_new_session_data,
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
- self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+ self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
raise
# find the completed login type
- for login_type in self._supported_ui_auth_types:
+ for login_type in supported_ui_auth_types:
if login_type not in result:
continue
- user_id = result[login_type]
+ validated_user_id = result[login_type]
break
else:
# this can't happen
raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
- if user_id != requester.user.to_string():
+ if validated_user_id != requester_user_id:
raise AuthError(403, "Invalid auth")
+ # Note that the access token has been validated.
+ await self.store.update_access_token_last_validated(requester.access_token_id)
+
return params, session_id
+ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+ """Get a list of the authentication types this user can use"""
+
+ ui_auth_types = set()
+
+ # if the HS supports password auth, and the user has a non-null password, we
+ # support password auth
+ if self._password_localdb_enabled and self._password_enabled:
+ lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+ if lookupres:
+ _, password_hash = lookupres
+ if password_hash:
+ ui_auth_types.add(LoginType.PASSWORD)
+
+ # also allow auth from password providers
+ for provider in self.password_providers:
+ for t in provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
+
+ # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+ # from sso to mxid.
+ if await self.hs.get_sso_handler().get_identity_providers_for_user(
+ user.to_string()
+ ):
+ ui_auth_types.add(LoginType.SSO)
+
+ return ui_auth_types
+
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -380,8 +439,8 @@ class AuthHandler(BaseHandler):
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
- clientip: str,
description: str,
+ get_new_session_data: Optional[Callable[[], JsonDict]] = None,
) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
@@ -402,11 +461,16 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
- clientip: The IP address of the client.
-
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
+ get_new_session_data:
+ an optional callback which will be called when starting a new session.
+ It should return data to be stored as part of the session.
+
+ The keys of the returned data should be entries in
+ UIAuthSessionDataConstants.
+
Returns:
A tuple of (creds, params, session_id).
@@ -423,24 +487,26 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows.
"""
- authdict = None
sid = None # type: Optional[str]
- if clientdict and "auth" in clientdict:
- authdict = clientdict["auth"]
- del clientdict["auth"]
- if "session" in authdict:
- sid = authdict["session"]
+ authdict = clientdict.pop("auth", {})
+ if "session" in authdict:
+ sid = authdict["session"]
# Convert the URI and method to strings.
- uri = request.uri.decode("utf-8")
+ uri = request.uri.decode("utf-8") # type: ignore
method = request.method.decode("utf-8")
# If there's no session ID, create a new session.
if not sid:
+ new_session_data = get_new_session_data() if get_new_session_data else {}
+
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
+ for k, v in new_session_data.items():
+ await self.set_session_data(session.session_id, k, v)
+
else:
try:
session = await self.store.get_ui_auth_session(sid)
@@ -496,7 +562,8 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
- user_agent = request.get_user_agent("")
+ user_agent = get_request_user_agent(request)
+ clientip = request.getClientIP()
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
@@ -524,6 +591,8 @@ class AuthHandler(BaseHandler):
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
+ # If all the required credentials have been supplied, the user has
+ # successfully completed the UI auth process!
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log
@@ -589,7 +658,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
- key: The key to store the data under
+ key: The key to store the data under. An entry from
+ UIAuthSessionDataConstants.
value: The data to store
"""
try:
@@ -605,7 +675,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
- key: The key to store the data under
+ key: The key the data was stored under. An entry from
+ UIAuthSessionDataConstants.
default: Value to return if the key has not been set
"""
try:
@@ -669,7 +740,9 @@ class AuthHandler(BaseHandler):
}
def _auth_dict_for_flows(
- self, flows: List[List[str]], session_id: str,
+ self,
+ flows: List[List[str]],
+ session_id: str,
) -> Dict[str, Any]:
public_flows = []
for f in flows:
@@ -699,6 +772,7 @@ class AuthHandler(BaseHandler):
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
+ is_appservice_ghost: bool = False,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -715,6 +789,7 @@ class AuthHandler(BaseHandler):
we should always have a device ID)
valid_until_ms: when the token is valid until. None for
no expiry.
+ is_appservice_ghost: Whether the user is an application service ghost user
Returns:
The access token for the user's session.
Raises:
@@ -735,7 +810,11 @@ class AuthHandler(BaseHandler):
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
)
- await self.auth.check_auth_blocking(user_id)
+ if (
+ not is_appservice_ghost
+ or self.hs.config.appservice.track_appservice_user_ips
+ ):
+ await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
@@ -807,6 +886,19 @@ class AuthHandler(BaseHandler):
)
return result
+ def can_change_password(self) -> bool:
+ """Get whether users on this server are allowed to change or set a password.
+
+ Both `config.password_enabled` and `config.password_localdb_enabled` must be true.
+
+ Note that any account (even an SSO account) is allowed to add a password if the above
+ is true.
+
+ Returns:
+ Whether users on this server are allowed to change or set a password
+ """
+ return self._password_enabled and self._password_localdb_enabled
+
def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
@@ -820,8 +912,10 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
- self, login_submission: Dict[str, Any], ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ self,
+ login_submission: Dict[str, Any],
+ ratelimit: bool = False,
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -963,8 +1057,10 @@ class AuthHandler(BaseHandler):
raise
async def _validate_userid_login(
- self, username: str, login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ self,
+ username: str,
+ login_submission: Dict[str, Any],
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1019,7 +1115,7 @@ class AuthHandler(BaseHandler):
if result:
return result
- if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+ if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
@@ -1042,7 +1138,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1093,18 +1189,16 @@ class AuthHandler(BaseHandler):
return None
return user_id
- async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
- auth_api = self.hs.get_auth()
- user_id = None
+ async def validate_short_term_login_token(
+ self, login_token: str
+ ) -> LoginTokenAttributes:
try:
- macaroon = pymacaroons.Macaroon.deserialize(login_token)
- user_id = auth_api.get_user_id_from_macaroon(macaroon)
- auth_api.validate_macaroon(macaroon, "login", user_id)
+ res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- await self.auth.check_auth_blocking(user_id)
- return user_id
+ await self.auth.check_auth_blocking(res.user_id)
+ return res
async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
@@ -1133,7 +1227,7 @@ class AuthHandler(BaseHandler):
async def delete_access_tokens_for_user(
self,
user_id: str,
- except_token_id: Optional[str] = None,
+ except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
@@ -1273,12 +1367,12 @@ class AuthHandler(BaseHandler):
else:
return False
- async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
Args:
- redirect_url: The URL to redirect to the SSO provider.
+ request: The incoming HTTP request
session_id: The user interactive authentication session ID.
Returns:
@@ -1288,48 +1382,64 @@ class AuthHandler(BaseHandler):
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- return self._sso_auth_confirm_template.render(
- description=session.description, redirect_url=redirect_url,
+
+ user_id_to_verify = await self.get_session_data(
+ session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+ ) # type: str
+
+ idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+ user_id_to_verify
)
- async def complete_sso_ui_auth(
- self, registered_user_id: str, session_id: str, request: SynapseRequest,
- ):
- """Having figured out a mxid for this user, complete the HTTP request
+ if not idps:
+ # we checked that the user had some remote identities before offering an SSO
+ # flow, so either it's been deleted or the client has requested SSO despite
+ # it not being offered.
+ raise SynapseError(400, "User has no SSO identities")
- Args:
- registered_user_id: The registered user ID to complete SSO login for.
- request: The request to complete.
- client_redirect_url: The URL to which to redirect the user at the end of the
- process.
- """
- # Mark the stage of the authentication as successful.
- # Save the user who authenticated with SSO, this will be used to ensure
- # that the account be modified is also the person who logged in.
- await self.store.mark_ui_auth_stage_complete(
- session_id, LoginType.SSO, registered_user_id
+ # for now, just pick one
+ idp_id, sso_auth_provider = next(iter(idps.items()))
+ if len(idps) > 1:
+ logger.warning(
+ "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+ "picking %r",
+ user_id_to_verify,
+ idp_id,
+ )
+
+ redirect_url = await sso_auth_provider.handle_redirect_request(
+ request, None, session_id
)
- # Render the HTML and return.
- html = self._sso_auth_success_template
- respond_with_html(request, 200, html)
+ return self._sso_auth_confirm_template.render(
+ description=session.description,
+ redirect_url=redirect_url,
+ idp=sso_auth_provider,
+ )
async def complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ auth_provider_id: str,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
+ new_user: bool = False,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ auth_provider_id: The id of the SSO Identity provider that was used for
+ login. This will be stored in the login token for future tracking in
+ prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
extra_attributes: Extra attributes which will be passed to the client
during successful login. Must be JSON serializable.
+ new_user: True if we should use wording appropriate to a user who has just
+ registered.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@@ -1338,33 +1448,51 @@ class AuthHandler(BaseHandler):
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
+ profile = await self.store.get_profileinfo(
+ UserID.from_string(registered_user_id).localpart
+ )
+
self._complete_sso_login(
- registered_user_id, request, client_redirect_url, extra_attributes
+ registered_user_id,
+ auth_provider_id,
+ request,
+ client_redirect_url,
+ extra_attributes,
+ new_user=new_user,
+ user_profile_data=profile,
)
def _complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ auth_provider_id: str,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
+ new_user: bool = False,
+ user_profile_data: Optional[ProfileInfo] = None,
):
"""
The synchronous portion of complete_sso_login.
This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
"""
+
+ if user_profile_data is None:
+ user_profile_data = ProfileInfo(None, None)
+
# Store any extra attributes which will be passed in the login response.
# Note that this is per-user so it may overwrite a previous value, this
# is considered OK since the newest SSO attributes should be most valid.
if extra_attributes:
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
- self._clock.time_msec(), extra_attributes,
+ self._clock.time_msec(),
+ extra_attributes,
)
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ registered_user_id, auth_provider_id=auth_provider_id
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1385,12 +1513,27 @@ class AuthHandler(BaseHandler):
# Remove the query parameters from the redirect URL to get a shorter version of
# it. This is only to display a human-readable URL in the template, but not the
# URL we redirect users to.
- redirect_url_no_params = client_redirect_url.split("?")[0]
+ url_parts = urllib.parse.urlsplit(client_redirect_url)
+
+ if url_parts.scheme == "https":
+ # for an https uri, just show the netloc (ie, the hostname. Specifically,
+ # the bit between "//" and "/"; this includes any potential
+ # "username:password@" prefix.)
+ display_url = url_parts.netloc
+ else:
+ # for other uris, strip the query-params (including the login token) and
+ # fragment.
+ display_url = urllib.parse.urlunsplit(
+ (url_parts.scheme, url_parts.netloc, url_parts.path, "", "")
+ )
html = self._sso_redirect_confirm_template.render(
- display_url=redirect_url_no_params,
+ display_url=display_url,
redirect_url=redirect_url,
server_name=self._server_name,
+ new_user=new_user,
+ user_id=registered_user_id,
+ user_profile=user_profile_data,
)
respond_with_html(request, 200, html)
@@ -1428,8 +1571,8 @@ class AuthHandler(BaseHandler):
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({param_name: param})
+ query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+ query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
@@ -1455,15 +1598,48 @@ class MacaroonGenerator:
return macaroon.serialize()
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ auth_provider_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
return macaroon.serialize()
+ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+ """Verify a short-term-login macaroon
+
+ Checks that the given token is a valid, unexpired short-term-login token
+ minted by this server.
+
+ Args:
+ token: the login token to verify
+
+ Returns:
+ a LoginTokenAttributes giving the user_id and auth_provider_id the token is valid for
+
+ Raises:
+ MacaroonVerificationFailedException if the verification failed
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
+ auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = login")
+ v.satisfy_general(lambda c: c.startswith("user_id = "))
+ v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+ satisfy_expiry(v, self.hs.get_clock().time_msec)
+ v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+ return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
@@ -1599,6 +1775,10 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
- result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
- if inspect.isawaitable(result):
- await result
+ await maybe_awaitable(
+ g(
+ user_id=user_id,
+ device_id=device_id,
+ access_token=access_token,
+ )
+ )
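The `verify_short_term_login_token` change above centralises the macaroon checks in `MacaroonGenerator`. As a self-contained sketch of the pymacaroons pattern involved (the secret key and helper names here are made up for illustration; Synapse's real helpers live in `synapse.util.macaroons`):

```python
import time

import pymacaroons

SECRET_KEY = "macaroon-secret"  # hypothetical; Synapse uses config.key.macaroon_secret_key


def generate_login_token(user_id: str, auth_provider_id: str, duration_ms: int = 120000) -> str:
    macaroon = pymacaroons.Macaroon(location="example.com", identifier="key", key=SECRET_KEY)
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
    macaroon.add_first_party_caveat("type = login")
    expiry = int(time.time() * 1000) + duration_ms
    macaroon.add_first_party_caveat("time < %d" % (expiry,))
    macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
    return macaroon.serialize()


def verify_login_token(token: str) -> str:
    macaroon = pymacaroons.Macaroon.deserialize(token)
    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("type = login")
    v.satisfy_general(lambda c: c.startswith("user_id = "))
    v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
    # accept "time < N" caveats whose deadline has not yet passed
    v.satisfy_general(
        lambda c: c.startswith("time < ")
        and int(time.time() * 1000) < int(c[len("time < "):])
    )
    v.verify(macaroon, SECRET_KEY)  # raises if any caveat is unsatisfied
    # extract the user_id caveat only after verification succeeds
    for caveat in macaroon.caveats:
        if caveat.caveat_id.startswith("user_id = "):
            return caveat.caveat_id[len("user_id = "):]
    raise ValueError("no user_id caveat")
```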
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f4ea0a9767..5060936f94 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -13,22 +13,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree import ElementTree as ET
+import attr
+
from twisted.web.client import PartialDownloadError
-from synapse.api.errors import Codes, LoginError
+from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
+class CasError(Exception):
+ """Used to catch errors when validating the CAS ticket."""
+
+ def __init__(self, error, error_description=None):
+ self.error = error
+ self.error_description = error_description
+
+ def __str__(self):
+ if self.error_description:
+ return "{}: {}".format(self.error, self.error_description)
+ return self.error
+
+
+@attr.s(slots=True, frozen=True)
+class CasResponse:
+ username = attr.ib(type=str)
+ attributes = attr.ib(type=Dict[str, List[Optional[str]]])
+
+
class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.
@@ -40,6 +62,7 @@ class CasHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
+ self._store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -50,6 +73,22 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client()
+ # identifier for the external_ids table
+ self.idp_id = "cas"
+
+ # user-facing name of this auth provider
+ self.idp_name = "CAS"
+
+ # we do not currently support brands/icons for CAS auth, but this is required by
+ # the SsoIdentityProvider protocol type.
+ self.idp_icon = None
+ self.idp_brand = None
+ self.unstable_idp_brand = None
+
+ self._sso_handler = hs.get_sso_handler()
+
+ self._sso_handler.register_identity_provider(self)
+
def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
@@ -61,22 +100,27 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
- return "%s%s?%s" % (
+ return "%s?%s" % (
self._cas_service_url,
- "/_matrix/client/r0/login/cas/ticket",
urllib.parse.urlencode(args),
)
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
- ) -> Tuple[str, Optional[str]]:
+ ) -> CasResponse:
"""
- Validate a CAS ticket with the server, parse the response, and return the user and display name.
+ Validate a CAS ticket with the server, and return the parsed response.
Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
- Should be the same as those passed to `get_redirect_url`.
+ Should be the same as those passed to `handle_redirect_request`.
+
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
+ Returns:
+ The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
@@ -89,77 +133,91 @@ class CasHandler:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
+ except HttpResponseException as e:
+ description = (
+ 'Authorization server responded with a "{status}" error '
+ "while validating the CAS ticket."
+ ).format(status=e.code)
+ raise CasError("server_error", description) from e
- user, attributes = self._parse_cas_response(body)
- displayname = attributes.pop(self._cas_displayname_attribute, None)
-
- for required_attribute, required_value in self._cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- return user, displayname
+ return self._parse_cas_response(body)
- def _parse_cas_response(
- self, cas_response_body: bytes
- ) -> Tuple[str, Dict[str, Optional[str]]]:
+ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
Returns:
- A tuple of the user and a mapping of other attributes.
+ The parsed CAS response.
"""
- user = None
- attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.exception("Error parsing CAS response")
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+
+ # Ensure the response is valid.
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise CasError(
+ "missing_service_response",
+ "root of CAS response is not serviceResponse",
)
- return user, attributes
- def get_redirect_url(self, service_args: Dict[str, str]) -> str:
- """
- Generates a URL for the CAS server where the client should be redirected.
+ success = root[0].tag.endswith("authenticationSuccess")
+ if not success:
+ raise CasError("unsucessful_response", "Unsuccessful CAS response")
+
+ # Iterate through the nodes and pull out the user and any extra attributes.
+ user = None
+ attributes = {} # type: Dict[str, List[Optional[str]]]
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes.setdefault(tag, []).append(attribute.text)
+
+ # Ensure a user was found.
+ if user is None:
+ raise CasError("no_user", "CAS response does not contain user")
+
+ return CasResponse(user, attributes)
+
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: Optional[bytes],
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
+ """Generates a URL for the CAS server where the client should be redirected.
Args:
- service_args: Additional arguments to include in the final redirect URL.
+ request: the incoming HTTP request
+ client_redirect_url: the URL that we should redirect the
+ client to after login (or None for UI Auth).
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
Returns:
- The URL to redirect the client to.
+ URL to redirect to
"""
+
+ if ui_auth_session_id:
+ service_args = {"session": ui_auth_session_id}
+ else:
+ assert client_redirect_url
+ service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
args = urllib.parse.urlencode(
{"service": self._build_service_param(service_args)}
)
@@ -201,59 +259,136 @@ class CasHandler:
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
- username, user_display_name = await self._validate_ticket(ticket, args)
-
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
- # Get the matrix ID from the CAS username.
- user_id = await self._map_cas_user_to_matrix_user(
- username, user_display_name, user_agent, ip_address
+ try:
+ cas_response = await self._validate_ticket(ticket, args)
+ except CasError as e:
+ logger.exception("Could not validate ticket")
+ self._sso_handler.render_error(request, e.error, e.error_description, 401)
+ return
+
+ await self._handle_cas_response(
+ request, cas_response, client_redirect_url, session
)
+ async def _handle_cas_response(
+ self,
+ request: SynapseRequest,
+ cas_response: CasResponse,
+ client_redirect_url: Optional[str],
+ session: Optional[str],
+ ) -> None:
+ """Handle a CAS response to a ticket request.
+
+ Assumes that the response has been validated. Maps the user onto an MXID,
+ registering them if necessary, and returns a response to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+
+ cas_response: The parsed CAS response.
+
+ client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+ session: The session parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the UI Auth session id.
+ """
+
+ # first check if we're doing a UIA
if session:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, session, request,
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self.idp_id,
+ cas_response.username,
+ session,
+ request,
)
- else:
- # If this not a UI auth request than there must be a redirect URL.
- assert client_redirect_url
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
- )
+ # otherwise, we're handling a login request.
- async def _map_cas_user_to_matrix_user(
- self,
- remote_user_id: str,
- display_name: Optional[str],
- user_agent: str,
- ip_address: str,
- ) -> str:
- """
- Given a CAS username, retrieve the user ID for it and possibly register the user.
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ if not self._sso_handler.check_required_attributes(
+ request, cas_response.attributes, self._cas_required_attributes
+ ):
+ return
- Args:
- remote_user_id: The username from the CAS response.
- display_name: The display name from the CAS response.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+ # Call the mapper to register/login the user
- Returns:
- The user ID associated with this response.
+ # If this is not a UI auth request then there must be a redirect URL.
+ assert client_redirect_url is not None
+
+ try:
+ await self._complete_cas_login(cas_response, request, client_redirect_url)
+ except MappingException as e:
+ logger.exception("Could not map user")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+
+ async def _complete_cas_login(
+ self,
+ cas_response: CasResponse,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ) -> None:
"""
+ Given a CAS response, complete the login flow
- localpart = map_username_to_mxid_localpart(remote_user_id)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
- # If the user does not exist, register it.
- if not registered_user_id:
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=display_name,
- user_agent_ips=[(user_agent, ip_address)],
+ Args:
+ cas_response: The parsed CAS response.
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
+
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
+ """
+ # Note that CAS does not support a mapping provider, so the logic is hard-coded.
+ localpart = map_username_to_mxid_localpart(cas_response.username)
+
+ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Map from CAS attributes to user attributes.
+ """
+ # Due to the grandfathering logic matching any previously registered
+ # mxids it isn't expected for there to be any failures.
+ if failures:
+ raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+ # Arbitrarily use the first attribute found.
+ display_name = cas_response.attributes.get(
+ self._cas_displayname_attribute, [None]
+ )[0]
+
+ return UserAttributes(localpart=localpart, display_name=display_name)
+
+ async def grandfather_existing_users() -> Optional[str]:
+ # Since CAS did not always use the user_external_ids table, always
+ # attempt to map to existing users.
+ user_id = UserID(localpart, self._hostname).to_string()
+
+ logger.debug(
+ "Looking for existing account based on mapped %s",
+ user_id,
)
- return registered_user_id
+ users = await self._store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ registered_user_id = list(users.keys())[0]
+ logger.info("Grandfathering mapping to %s", registered_user_id)
+ return registered_user_id
+
+ return None
+
+ await self._sso_handler.complete_sso_login_request(
+ self.idp_id,
+ cas_response.username,
+ request,
+ client_redirect_url,
+ cas_response_to_user_attributes,
+ grandfather_existing_users,
+ )
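For reference, the kind of CAS `serviceResponse` body that `_parse_cas_response` handles, and the namespace-stripping it performs, can be illustrated with a small standalone snippet; the attribute names and values are invented:

```python
from xml.etree import ElementTree as ET

# An illustrative (not normative) successful CAS proxyValidate response.
SAMPLE_RESPONSE = b"""<?xml version="1.0"?>
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
  <cas:authenticationSuccess>
    <cas:user>alice</cas:user>
    <cas:attributes>
      <cas:displayName>Alice</cas:displayName>
      <cas:group>staff</cas:group>
      <cas:group>admins</cas:group>
    </cas:attributes>
  </cas:authenticationSuccess>
</cas:serviceResponse>"""

root = ET.fromstring(SAMPLE_RESPONSE)
assert root.tag.endswith("serviceResponse")
assert root[0].tag.endswith("authenticationSuccess")

user = None
attributes = {}
for child in root[0]:
    if child.tag.endswith("user"):
        user = child.text
    if child.tag.endswith("attributes"):
        for attribute in child:
            # strip the "{namespace}" prefix that ElementTree adds to tags
            tag = attribute.tag.split("}")[-1]
            attributes.setdefault(tag, []).append(attribute.text)

print(user)        # alice
print(attributes)  # {'displayName': ['Alice'], 'group': ['staff', 'admins']}
```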
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 366787ca33..9167918257 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -18,12 +18,12 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
from ._base import BaseHandler
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -53,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity_enabled
async def deactivate_account(
- self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+ self,
+ user_id: str,
+ erase_data: bool,
+ requester: Requester,
+ id_server: Optional[str] = None,
+ by_admin: bool = False,
) -> bool:
"""Deactivate a user's account
Args:
user_id: ID of user to be deactivated
erase_data: whether to GDPR-erase the user's data
+ requester: The user attempting to make this change.
id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
+ by_admin: Whether this change was made by an administrator.
Returns:
True if identity server supports removing threepids, otherwise False.
@@ -113,6 +120,13 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
+ # Most of the pushers will have been deleted when we logged out the
+ # associated devices above, but we still need to delete pushers not
+ # associated with devices, e.g. email pushers.
+ await self.store.delete_all_pushers_for_user(user_id)
+
+ # Set the user as no longer active. This will prevent their profile
+ # from being replicated.
user = UserID.from_string(user_id)
await self._profile_handler.set_active([user], False, False)
@@ -125,6 +139,12 @@ class DeactivateAccountHandler(BaseHandler):
# Mark the user as erased, if they asked for that
if erase_data:
+ user = UserID.from_string(user_id)
+ # Remove avatar URL from this user
+ await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+ # Remove displayname from this user
+ await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
logger.info("Marking %s as erased", user_id)
await self.store.mark_user_erased(user_id)
@@ -186,8 +206,7 @@ class DeactivateAccountHandler(BaseHandler):
run_as_background_process("user_parter_loop", self._user_parter_loop)
async def _user_parter_loop(self) -> None:
- """Loop that parts deactivated users from rooms
- """
+ """Loop that parts deactivated users from rooms"""
self._user_parter_running = True
logger.info("Starting user parter")
try:
@@ -204,8 +223,7 @@ class DeactivateAccountHandler(BaseHandler):
self._user_parter_running = False
async def _part_user(self, user_id: str) -> None:
- """Causes the given user_id to leave all the rooms they're joined to
- """
+ """Causes the given user_id to leave all the rooms they're joined to"""
user = UserID.from_string(user_id)
rooms_for_user = await self.store.get_rooms_for_user(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..54293d0b9c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -45,7 +45,7 @@ from synapse.util.retryutils import NotRetryingDestination
from ._base import BaseHandler
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
- async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+ async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
Retrieve the given user's devices
@@ -85,8 +85,8 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
- async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
- """ Retrieve the given device
+ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
+ """Retrieve the given device
Args:
user_id: The user to get the device from
@@ -166,7 +166,7 @@ class DeviceWorkerHandler(BaseHandler):
# Fetch the current state at the time.
try:
- event_ids = await self.store.get_forward_extremeties_for_room(
+ event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
@@ -341,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
async def delete_device(self, user_id: str, device_id: str) -> None:
- """ Delete the given device
+ """Delete the given device
Args:
user_id: The user to delete the device from.
@@ -386,7 +386,7 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, device_ids)
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
- """ Delete several devices
+ """Delete several devices
Args:
user_id: The user to delete devices from.
@@ -417,7 +417,7 @@ class DeviceHandler(DeviceWorkerHandler):
await self.notify_device_update(user_id, device_ids)
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
- """ Update the given device
+ """Update the given device
Args:
user_id: The user to update devices of.
@@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler):
device id of the dehydrated device
"""
device_id = await self.check_device_registered(
- user_id, None, initial_device_display_name,
+ user_id,
+ None,
+ initial_device_display_name,
)
old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
@@ -598,7 +600,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(
- device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+ device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -803,7 +805,8 @@ class DeviceListUpdater:
try:
# Try to resync the current user's devices list.
result = await self.user_device_resync(
- user_id=user_id, mark_failed_as_stale=False,
+ user_id=user_id,
+ mark_failed_as_stale=False,
)
# user_device_resync only returns a result if it managed to
@@ -813,14 +816,17 @@ class DeviceListUpdater:
# self.store.update_remote_device_list_cache).
if result:
logger.debug(
- "Successfully resynced the device list for %s", user_id,
+ "Successfully resynced the device list for %s",
+ user_id,
)
except Exception as e:
# If there was an issue resyncing this user, e.g. if the remote
# server sent a malformed result, just log the error instead of
# aborting all the subsequent resyncs.
logger.debug(
- "Could not resync the device list for %s: %s", user_id, e,
+ "Could not resync the device list for %s: %s",
+ user_id,
+ e,
)
finally:
# Allow future calls to retry resyncing out-of-sync device lists.
@@ -855,7 +861,9 @@ class DeviceListUpdater:
return None
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
- "Failed to handle device list update for %s: %s", user_id, e,
+ "Failed to handle device list update for %s: %s",
+ user_id,
+ e,
)
if mark_failed_as_stale:
@@ -899,6 +907,7 @@ class DeviceListUpdater:
master_key = result.get("master_key")
self_signing_key = result.get("self_signing_key")
+ ignore_devices = False
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
@@ -917,6 +926,12 @@ class DeviceListUpdater:
len(devices),
)
devices = []
+ ignore_devices = True
+ else:
+ cached_devices = await self.store.get_cached_devices_for_user(user_id)
+ if cached_devices == {d["device_id"]: d for d in devices}:
+ devices = []
+ ignore_devices = True
for device in devices:
logger.debug(
@@ -926,16 +941,22 @@ class DeviceListUpdater:
stream_id,
)
- await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+ if not ignore_devices:
+ await self.store.update_remote_device_list_cache(
+ user_id, devices, stream_id
+ )
device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys.
cross_signing_device_ids = await self.process_cross_signing_key_update(
- user_id, master_key, self_signing_key,
+ user_id,
+ master_key,
+ self_signing_key,
)
device_ids = device_ids + cross_signing_device_ids
- await self.device_handler.notify_device_update(user_id, device_ids)
+ if device_ids:
+ await self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
@@ -946,8 +967,8 @@ class DeviceListUpdater:
async def process_cross_signing_key_update(
self,
user_id: str,
- master_key: Optional[Dict[str, Any]],
- self_signing_key: Optional[Dict[str, Any]],
+ master_key: Optional[JsonDict],
+ self_signing_key: Optional[JsonDict],
) -> List[str]:
"""Process the given new master and self-signing key for the given remote user.
@@ -963,14 +984,17 @@ class DeviceListUpdater:
"""
device_ids = []
- if master_key:
+ current_keys_map = await self.store.get_e2e_cross_signing_keys_bulk([user_id])
+ current_keys = current_keys_map.get(user_id) or {}
+
+ if master_key and master_key != current_keys.get("master"):
await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
- if self_signing_key:
+ if self_signing_key and self_signing_key != current_keys.get("self_signing"):
await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
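The resync change above avoids rewriting the device cache and re-notifying clients when the remote server returns exactly what is already cached; the check is a plain dict comparison keyed on device ID, roughly as below (data invented for illustration):

```python
# Hypothetical cached state, keyed by device_id, as returned by the store.
cached_devices = {
    "DEVICE1": {"device_id": "DEVICE1", "keys": {"ed25519:DEVICE1": "abc"}},
    "DEVICE2": {"device_id": "DEVICE2", "keys": {"ed25519:DEVICE2": "def"}},
}

# Hypothetical list of devices fetched from the remote server during a resync.
fetched = [
    {"device_id": "DEVICE1", "keys": {"ed25519:DEVICE1": "abc"}},
    {"device_id": "DEVICE2", "keys": {"ed25519:DEVICE2": "def"}},
]

# If the remote's answer is identical to the cache, there is nothing to write
# and no device-list update needs broadcasting to local clients.
ignore_devices = cached_devices == {d["device_id"]: d for d in fetched}
print(ignore_devices)  # True
```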
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 9cac5a8463..eb547743be 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -16,7 +16,9 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
+from synapse.api.constants import EduTypes
from synapse.api.errors import SynapseError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
get_active_span_text_map,
@@ -24,12 +26,13 @@ from synapse.logging.opentracing import (
set_tag,
start_active_span,
)
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -44,13 +47,44 @@ class DeviceMessageHandler:
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine = hs.is_mine
- self.federation = hs.get_federation_sender()
- hs.get_federation_registry().register_edu_handler(
- "m.direct_to_device", self.on_direct_to_device_edu
- )
+ # We only need to poke the federation sender explicitly if its on the
+ # same instance. Other federation sender instances will get notified by
+ # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+ # in the to-device replication stream.
+ self.federation_sender = None
+ if hs.should_send_federation():
+ self.federation_sender = hs.get_federation_sender()
+
+ # If we can handle the to device EDUs we do so, otherwise we route them
+ # to the appropriate worker.
+ if hs.get_instance_name() in hs.config.worker.writers.to_device:
+ hs.get_federation_registry().register_edu_handler(
+ "m.direct_to_device", self.on_direct_to_device_edu
+ )
+ else:
+ hs.get_federation_registry().register_instances_for_edu(
+ "m.direct_to_device",
+ hs.config.worker.writers.to_device,
+ )
+
+ # The handler to call when we think a user's device list might be out of
+ # sync. We do all device list resyncing on the master instance, so if
+ # we're on a worker we hit the device resync replication API.
+ if hs.config.worker.worker_app is None:
+ self._user_device_resync = (
+ hs.get_device_handler().device_list_updater.user_device_resync
+ )
+ else:
+ self._user_device_resync = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
- self._device_list_updater = hs.get_device_handler().device_list_updater
+ self._ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.rc_key_requests.per_second,
+ burst_count=hs.config.rc_key_requests.burst_count,
+ )
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {}
@@ -138,21 +172,31 @@ class DeviceMessageHandler:
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
- run_in_background(
- self._device_list_updater.user_device_resync, sender_user_id
- )
+ run_in_background(self._user_device_resync, user_id=sender_user_id)
async def send_device_message(
self,
- sender_user_id: str,
+ requester: Requester,
message_type: str,
messages: Dict[str, Dict[str, JsonDict]],
) -> None:
+ sender_user_id = requester.user.to_string()
+
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
for user_id, by_device in messages.items():
+ # Ratelimit local cross-user key requests by the sending device.
+ if (
+ message_type == EduTypes.RoomKeyRequest
+ and user_id != sender_user_id
+ and self._ratelimiter.can_do_action(
+ (sender_user_id, requester.device_id)
+ )
+ ):
+ continue
+
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
@@ -195,7 +239,8 @@ class DeviceMessageHandler:
)
log_kv({"remote_messages": remote_messages})
- for destination in remote_messages.keys():
- # Enqueue a new federation transaction to send the new
- # device messages to each remote destination.
- self.federation.send_device_messages(destination)
+ if self.federation_sender:
+ for destination in remote_messages.keys():
+ # Enqueue a new federation transaction to send the new
+ # device messages to each remote destination.
+ self.federation_sender.send_device_messages(destination)
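
The ratelimiter added above is keyed on the sending (user ID, device ID) pair, so one noisy device cannot exhaust another device's budget. A toy token-bucket sketch of that keying (the real `synapse.api.ratelimiting.Ratelimiter` is more featureful; everything here is illustrative):

import time
from typing import Dict, Hashable, Tuple

class TinyRatelimiter:
    """Allow at most `burst_count` actions per key, refilling at `rate_hz` per second."""

    def __init__(self, rate_hz: float, burst_count: int) -> None:
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        # key -> (tokens remaining, timestamp of last update)
        self._buckets: Dict[Hashable, Tuple[float, float]] = {}

    def can_do_action(self, key: Hashable) -> bool:
        now = time.monotonic()
        tokens, last = self._buckets.get(key, (float(self.burst_count), now))
        tokens = min(float(self.burst_count), tokens + (now - last) * self.rate_hz)
        allowed = tokens >= 1.0
        if allowed:
            tokens -= 1.0
        self._buckets[key] = (tokens, now)
        return allowed

# Keyed per sending user *and* device, mirroring the handler above:
#   limiter.can_do_action((sender_user_id, requester.device_id))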
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad5683d251..abcf86352d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it"
)
- if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ if not await self.spam_checker.user_may_create_room_alias(
+ user_id, room_alias
+ ):
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_publish_room(user_id, room_id):
+ if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403, "This user is not permitted to publish rooms to the room list"
)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..2ad9b6d930 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
+ JsonDict,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class E2eKeysHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
@@ -57,8 +61,8 @@ class E2eKeysHandler:
self._is_master = hs.config.worker_app is None
if not self._is_master:
- self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
- hs
+ self._user_device_resync_client = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
else:
# Only register this edu handler on master as it requires writing
@@ -78,8 +82,10 @@ class E2eKeysHandler:
)
@trace
- async def query_devices(self, query_body, timeout, from_user_id):
- """ Handle a device key query from a client
+ async def query_devices(
+ self, query_body: JsonDict, timeout: int, from_user_id: str
+ ) -> JsonDict:
+ """Handle a device key query from a client
{
"device_keys": {
@@ -98,12 +104,14 @@ class E2eKeysHandler:
}
Args:
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Iterable[str]]
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
set_tag("remote_key_query", remote_queries)
# First get local devices.
- failures = {}
+ # A map of destination -> failure response.
+ failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
)
# Now attempt to get any remote devices from our local cache.
- remote_queries_not_in_cache = {}
+ # A map of destination -> user ID -> device IDs.
+ remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
- query_list = []
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
return ret
async def get_cross_signing_keys_from_cache(
- self, query, from_user_id
+ self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
- query (Iterable[string]) an iterable of user IDs. A dict whose keys
+ query: an iterable of user IDs. A dict whose keys
are user IDs satisfies this, so the query format used for
query_devices can be used here.
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
@@ -315,14 +325,12 @@ class E2eKeysHandler:
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]
- if (
- from_user_id in keys
- and keys[from_user_id] is not None
- and "user_signing" in keys[from_user_id]
- ):
- # users can see other users' master and self-signing keys, but can
- # only see their own user-signing keys
- user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ if from_user_id:
+ from_user_key = keys.get(from_user_id)
+ if from_user_key and "user_signing" in from_user_key:
+ user_signing_keys[from_user_id] = from_user_key["user_signing"]
return {
"master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
- local_query = []
+ local_query = [] # type: List[Tuple[str, Optional[str]]]
- result_dict = {}
+ result_dict = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,13 @@ class E2eKeysHandler:
log_kv(results)
return result_dict
- async def on_federation_query_client_keys(self, query_body):
- """ Handle a device key query from a federated server
- """
- device_keys_query = query_body.get("device_keys", {})
+ async def on_federation_query_client_keys(
+ self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+ ) -> JsonDict:
+ """Handle a device key query from a federated server"""
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Optional[List[str]]]
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@@ -397,31 +408,34 @@ class E2eKeysHandler:
return ret
@trace
- async def claim_one_time_keys(self, query, timeout):
- local_query = []
- remote_queries = {}
+ async def claim_one_time_keys(
+ self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+ ) -> JsonDict:
+ local_query = [] # type: List[Tuple[str, str, str]]
+ remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
- for user_id, device_keys in query.get("one_time_keys", {}).items():
+ for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
- for device_id, algorithm in device_keys.items():
+ for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_keys
+ remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
results = await self.store.claim_e2e_one_time_keys(local_query)
- json_result = {}
- failures = {}
+ # A map of user ID -> device ID -> key ID -> key.
+ json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+ failures = {} # type: Dict[str, JsonDict]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
+ for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json_decoder.decode(json_bytes)
+ key_id: json_decoder.decode(json_str)
}
@trace
@@ -468,7 +482,9 @@ class E2eKeysHandler:
return {"one_time_keys": json_result, "failures": failures}
@tag_args
- async def upload_keys_for_user(self, user_id, device_id, keys):
+ async def upload_keys_for_user(
+ self, user_id: str, device_id: str, keys: JsonDict
+ ) -> JsonDict:
time_now = self.clock.time_msec()
@@ -543,8 +559,8 @@ class E2eKeysHandler:
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
- self, user_id, device_id, time_now, one_time_keys
- ):
+ self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ ) -> None:
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
@@ -585,12 +601,14 @@ class E2eKeysHandler:
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
- async def upload_signing_keys_for_user(self, user_id, keys):
+ async def upload_signing_keys_for_user(
+ self, user_id: str, keys: JsonDict
+ ) -> JsonDict:
"""Upload signing keys for cross-signing
Args:
- user_id (string): the user uploading the keys
- keys (dict[string, dict]): the signing keys
+ user_id: the user uploading the keys
+ keys: the signing keys
"""
# if a master key is uploaded, then check it. Otherwise, load the
@@ -667,16 +685,17 @@ class E2eKeysHandler:
return {}
- async def upload_signatures_for_device_keys(self, user_id, signatures):
+ async def upload_signatures_for_device_keys(
+ self, user_id: str, signatures: JsonDict
+ ) -> JsonDict:
"""Upload device signatures for cross-signing
Args:
- user_id (string): the user uploading the signatures
- signatures (dict[string, dict[string, dict]]): map of users to
- devices to signed keys. This is the submission from the user; an
- exception will be raised if it is malformed.
+ user_id: the user uploading the signatures
+ signatures: map of users to devices to signed keys. This is the submission
+ from the user; an exception will be raised if it is malformed.
Returns:
- dict: response to be sent back to the client. The response will have
+ The response to be sent back to the client. The response will have
a "failures" key, which will be a dict mapping users to devices
to errors for the signatures that failed.
Raises:
@@ -719,7 +738,9 @@ class E2eKeysHandler:
return {"failures": failures}
- async def _process_self_signatures(self, user_id, signatures):
+ async def _process_self_signatures(
+ self, user_id: str, signatures: JsonDict
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +752,14 @@ class E2eKeysHandler:
signatures (dict[string, dict]): map of devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
- reasons
+ A tuple of a list of signatures to store, and a map of users to
+ devices to failure reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -834,19 +854,24 @@ class E2eKeysHandler:
return signature_list, failures
def _check_master_key_signature(
- self, user_id, master_key_id, signed_master_key, stored_master_key, devices
- ):
+ self,
+ user_id: str,
+ master_key_id: str,
+ signed_master_key: JsonDict,
+ stored_master_key: JsonDict,
+ devices: Dict[str, Dict[str, JsonDict]],
+ ) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
Args:
- user_id (string): the user whose master key is being checked
- master_key_id (string): the ID of the user's master key
- signed_master_key (dict): the user's signed master key that was uploaded
- stored_master_key (dict): our previously-stored copy of the user's master key
- devices (iterable(dict)): the user's devices
+ user_id: the user whose master key is being checked
+ master_key_id: the ID of the user's master key
+ signed_master_key: the user's signed master key that was uploaded
+ stored_master_key: our previously-stored copy of the user's master key
+ devices: the user's devices
Returns:
- list[SignatureListItem]: a list of signatures to store
+ A list of signatures to store
Raises:
SynapseError: if a signature is invalid
@@ -877,25 +902,26 @@ class E2eKeysHandler:
return master_key_signature_list
- async def _process_other_signatures(self, user_id, signatures):
+ async def _process_other_signatures(
+ self, user_id: str, signatures: Dict[str, dict]
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
Args:
- user_id (string): the user uploading the keys
- signatures (dict[string, dict]): map of users to devices to signed keys
+ user_id: the user uploading the keys
+ signatures: map of users to devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
+ A list of signatures to store, and a map of users to devices to failure
reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -983,7 +1009,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
- ):
+ ) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1023,7 @@ class E2eKeysHandler:
This affects what signatures are fetched.
Returns:
- dict, str, VerifyKey: the raw key data, the key ID, and the
- signedjson verify key
+ The raw key data, the key ID, and the signedjson verify key
Raises:
NotFoundError: if the key is not found
@@ -1039,7 +1064,9 @@ class E2eKeysHandler:
return key, key_id, verify_key
async def _retrieve_cross_signing_keys_for_remote_user(
- self, user: UserID, desired_key_type: str,
+ self,
+ user: UserID,
+ desired_key_type: str,
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
@@ -1135,16 +1162,18 @@ class E2eKeysHandler:
return desired_key, desired_key_id, desired_verify_key
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+ key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
"""Check a cross-signing key uploaded by a user. Performs some basic sanity
checking, and ensures that it is signed, if a signature is required.
Args:
- key (dict): the key data to verify
- user_id (str): the user whose key is being checked
- key_type (str): the type of key that the key should be
- signing_key (VerifyKey): (optional) the signing key that the key should
- be signed with. If omitted, signatures will not be checked.
+ key: the key data to verify
+ user_id: the user whose key is being checked
+ key_type: the type of key that the key should be
+ signing_key: the signing key that the key should be signed with. If
+ omitted, signatures will not be checked.
"""
if (
key.get("user_id") != user_id
@@ -1162,16 +1191,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
)
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+ user_id: str,
+ verify_key: VerifyKey,
+ signed_device: JsonDict,
+ stored_device: JsonDict,
+) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
exception if an error is detected.
Args:
- user_id (str): the user ID whose signature is being checked
- verify_key (VerifyKey): the key to verify the device with
- signed_device (dict): the uploaded signed device data
- stored_device (dict): our previously stored copy of the device
+ user_id: the user ID whose signature is being checked
+ verify_key: the key to verify the device with
+ signed_device: the uploaded signed device data
+ stored_device: our previously stored copy of the device
Raises:
SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1235,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
if isinstance(e, SynapseError):
return {"status": e.code, "errcode": e.errcode, "message": str(e)}
@@ -1218,7 +1252,7 @@ def _exception_to_failure(e):
return {"status": 503, "message": str(e)}
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
@@ -1236,19 +1270,18 @@ def _one_time_keys_match(old_key_json, new_key):
@attr.s(slots=True)
class SignatureListItem:
- """An item in the signature list as used by upload_signatures_for_device_keys.
- """
+ """An item in the signature list as used by upload_signatures_for_device_keys."""
- signing_key_id = attr.ib()
- target_user_id = attr.ib()
- target_device_id = attr.ib()
- signature = attr.ib()
+ signing_key_id = attr.ib(type=str)
+ target_user_id = attr.ib(type=str)
+ target_device_id = attr.ib(type=str)
+ signature = attr.ib(type=JsonDict)
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs, e2e_keys_handler):
+ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {}
+ self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
iterable=True,
)
- async def incoming_signing_key_update(self, origin, edu_content):
+ async def incoming_signing_key_update(
+ self, origin: str, edu_content: JsonDict
+ ) -> None:
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
- origin (string): the server that sent the EDU
- edu_content (dict): the contents of the EDU
+ origin: the server that sent the EDU
+ edu_content: the contents of the EDU
"""
user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
await self._handle_signing_key_updates(user_id)
- async def _handle_signing_key_updates(self, user_id):
+ async def _handle_signing_key_updates(self, user_id: str) -> None:
"""Actually handle pending updates.
Args:
- user_id (string): the user whose updates we are processing
+ user_id: the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
@@ -1315,13 +1350,17 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
- device_ids = []
+ device_ids = [] # type: List[str]
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = await device_list_updater.process_cross_signing_key_update(
- user_id, master_key, self_signing_key,
+ new_device_ids = (
+ await device_list_updater.process_cross_signing_key_update(
+ user_id,
+ master_key,
+ self_signing_key,
+ )
)
device_ids = device_ids + new_device_ids
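
As a reminder of the wire format behind the newly typed `Dict[str, Iterable[str]]` and `Dict[str, Optional[List[str]]]` queries above: the client and federation key queries map user IDs to lists of device IDs, with an empty list (or null over federation) meaning "all devices for that user". A representative, illustrative body:

query_body = {
    "device_keys": {
        # empty list: return keys for every device belonging to this user
        "@alice:example.org": [],
        # explicit list: only these devices
        "@bob:example.org": ["DEVICEID1", "DEVICEID2"],
    }
}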
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..a910d246d6 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, List, Optional
from synapse.api.errors import (
Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
The actual payload of the encrypted keys is completely opaque to the handler.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
# Used to lock whenever a client is uploading key data. This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
- async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> List[JsonDict]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
Args:
- user_id(str): the user whose keys we're getting
- version(str): the version ID of the backup we're getting keys from
- room_id(string): room ID to get keys for, for None to get keys for all rooms
- session_id(string): session ID to get keys for, for None to get keys for all
+ user_id: the user whose keys we're getting
+ version: the version ID of the backup we're getting keys from
+ room_id: room ID to get keys for, or None to get keys for all rooms
+ session_id: session ID to get keys for, or None to get keys for all
sessions
Raises:
NotFoundError: if the backup version does not exist
Returns:
- A deferred list of dicts giving the session_data and message metadata for
+ A list of dicts giving the session_data and message metadata for
these room keys.
"""
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
return results
@trace
- async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> JsonDict:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
Args:
- user_id(str): the user whose backup we're deleting
- version(str): the version ID of the backup we're deleting
- room_id(string): room ID to delete keys for, for None to delete keys for all
+ user_id: the user whose backup we're deleting
+ version: the version ID of the backup we're deleting
+ room_id: room ID to delete keys for, or None to delete keys for all
rooms
- session_id(string): session ID to delete keys for, for None to delete keys
+ session_id: session ID to delete keys for, or None to delete keys
for all sessions
Raises:
NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@trace
- async def upload_room_keys(self, user_id, version, room_keys):
+ async def upload_room_keys(
+ self, user_id: str, version: str, room_keys: JsonDict
+ ) -> JsonDict:
"""Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT().
Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_keys(dict): a nested dict describing the room_keys we're setting:
+ user_id: the user whose backup we're setting
+ version: the version ID of the backup we're updating
+ room_keys: a nested dict describing the room_keys we're setting:
{
"rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@staticmethod
- def _should_replace_room_key(current_room_key, room_key):
+ def _should_replace_room_key(
+ current_room_key: Optional[JsonDict], room_key: JsonDict
+ ) -> bool:
"""
Determine whether to replace a given current_room_key (if any)
with a newly uploaded room_key backup
Args:
- current_room_key (dict): Optional, the current room_key dict if any
- room_key (dict): The new room_key dict which may or may not be fit to
+ current_room_key: Optional, the current room_key dict if any
+ room_key: The new room_key dict which may or may not be fit to
replace the current_room_key
Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
return True
@trace
- async def create_version(self, user_id, version_info):
+ async def create_version(self, user_id: str, version_info: JsonDict) -> str:
"""Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be
writeable to.
Args:
- user_id(str): the user whose backup version we're creating
- version_info(dict): metadata about the new version being created
+ user_id: the user whose backup version we're creating
+ version_info: metadata about the new version being created
{
"algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
}
Returns:
- A deferred of a string that gives the new version number.
+ The new version number.
"""
# TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
)
return new_version
- async def get_version_info(self, user_id, version=None):
+ async def get_version_info(
+ self, user_id: str, version: Optional[str] = None
+ ) -> JsonDict:
"""Get the info about a given version of the user's backup
Args:
- user_id(str): the user whose current backup version we're querying
- version(str): Optional; if None gives the most recent version
+ user_id: the user whose current backup version we're querying
+ version: Optional; if None gives the most recent version
otherwise a historical one.
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
- A deferred of a info dict that gives the info about the new version.
+ An info dict that gives the info about the requested version.
{
"version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
return res
@trace
- async def delete_version(self, user_id, version=None):
+ async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
"""Deletes a given version of the user's e2e_room_keys backup
Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
raise
@trace
- async def update_version(self, user_id, version, version_info):
+ async def update_version(
+ self, user_id: str, version: str, version_info: JsonDict
+ ) -> JsonDict:
"""Update the info about a given version of the user's backup
Args:
- user_id(str): the user whose current backup version we're updating
- version(str): the backup version we're updating
- version_info(dict): the new information about the backup
+ user_id: the user whose current backup version we're updating
+ version: the backup version we're updating
+ version_info: the new information about the backup
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
- A deferred of an empty dict.
+ An empty dict.
"""
if "version" not in version_info:
version_info["version"] = version
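
The nested `room_keys` dict referenced in the docstrings above follows the room-keys backup format: rooms map to sessions, and each session carries replacement metadata plus an opaque `session_data` blob that only the client can decrypt. A representative, illustrative payload:

room_keys = {
    "rooms": {
        "!room:example.org": {
            "sessions": {
                "megolm_session_id": {
                    "first_message_index": 1,
                    "forwarded_count": 0,
                    "is_verified": False,
                    # opaque to the server; encrypted by the client
                    "session_data": {"ciphertext": "base64+opaque+blob"},
                }
            }
        }
    }
}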
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 539b4fc32e..f46cab7325 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -17,7 +17,7 @@ import logging
import random
from typing import TYPE_CHECKING, Iterable, List, Optional
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler):
room_id: Optional[str] = None,
is_guest: bool = False,
) -> JsonDict:
- """Fetches the events stream for a given user.
- """
+ """Fetches the events stream for a given user."""
if room_id:
blocked = await self.store.is_room_blocked(room_id)
@@ -114,7 +113,7 @@ class EventStreamHandler(BaseHandler):
states = await presence_handler.get_states(users)
to_add.extend(
{
- "type": EventTypes.Presence,
+ "type": EduTypes.Presence,
"content": format_user_presence_state(state, time_now),
}
for state in states
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7050ae5744..ec2ce679c2 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -112,13 +112,13 @@ class _NewEventInfo:
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
- Responsible for:
- a) handling received Pdus before handing them on as Events to the rest
- of the homeserver (including auth and state conflict resolutions)
- b) converting events that were produced by local clients that may need
- to be sent to remote homeservers.
- c) doing the necessary dances to invite remote users and join remote
- rooms.
+ Responsible for:
+ a) handling received Pdus before handing them on as Events to the rest
+ of the homeserver (including auth and state conflict resolutions)
+ b) converting events that were produced by local clients that may need
+ to be sent to remote homeservers.
+ c) doing the necessary dances to invite remote users and join remote
+ rooms.
"""
def __init__(self, hs: "HomeServer"):
@@ -151,11 +151,11 @@ class FederationHandler(BaseHandler):
)
if hs.config.worker_app:
- self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
- hs
+ self._user_device_resync = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
- self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
- hs
+ self._maybe_store_room_on_outlier_membership = (
+ ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
@@ -173,7 +173,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
- """ Process a PDU received via a federation /send/ transaction, or
+ """Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
Args:
@@ -202,7 +202,7 @@ class FederationHandler(BaseHandler):
or pdu.internal_metadata.is_outlier()
)
if already_seen:
- logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
+ logger.debug("Already seen pdu")
return
# do some initial sanity-checking of the event. In particular, make
@@ -211,18 +211,14 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warning(
- "[%s %s] Received event failed sanity checks", room_id, event_id
- )
+ logger.warning("Received event failed sanity checks")
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
- "[%s %s] Queuing PDU from %s for now: join in progress",
- room_id,
- event_id,
+ "Queuing PDU from %s for now: join in progress",
origin,
)
self.room_queues[room_id].append((pdu, origin))
@@ -237,9 +233,7 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
- "[%s %s] Ignoring PDU from %s as we're not in the room",
- room_id,
- event_id,
+ "Ignoring PDU from %s as we're not in the room",
origin,
)
return None
@@ -251,7 +245,7 @@ class FederationHandler(BaseHandler):
# We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
+ logger.debug("min_depth: %d", min_depth)
prevs = set(pdu.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs)
@@ -268,17 +262,13 @@ class FederationHandler(BaseHandler):
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
- "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
- room_id,
- event_id,
+ "Acquiring room lock to fetch %d missing prev_events: %s",
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
- "[%s %s] Acquired room lock to fetch %d missing prev_events",
- room_id,
- event_id,
+ "Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
)
@@ -298,9 +288,7 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
- "[%s %s] Found all missing prev_events",
- room_id,
- event_id,
+ "Found all missing prev_events",
)
elif missing_prevs:
logger.info(
@@ -338,9 +326,7 @@ class FederationHandler(BaseHandler):
if sent_to_us_directly:
logger.warning(
- "[%s %s] Rejecting: failed to fetch %d prev events: %s",
- room_id,
- event_id,
+ "Rejecting: failed to fetch %d prev events: %s",
len(prevs - seen),
shortstr(prevs - seen),
)
@@ -370,19 +356,16 @@ class FederationHandler(BaseHandler):
# Ask the remote server for the states we don't
# know about
for p in prevs - seen:
- logger.info(
- "[%s %s] Requesting state at missing prev_event %s",
- room_id,
- event_id,
- p,
- )
+ logger.info("Requesting state after missing prev_event %s", p)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- (remote_state, _,) = await self._get_state_for_room(
- origin, room_id, p, include_event_in_state=True
+ remote_state = (
+ await self._get_state_after_missing_prev_event(
+ origin, room_id, p
+ )
)
remote_state_map = {
@@ -394,12 +377,14 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id)
- state_map = await self._state_resolution_handler.resolve_events_with_store(
- room_id,
- room_version,
- state_maps,
- event_map,
- state_res_store=StateResolutionStore(self.store),
+ state_map = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ room_id,
+ room_version,
+ state_maps,
+ event_map,
+ state_res_store=StateResolutionStore(self.store),
+ )
)
# We need to give _process_received_pdu the actual state events
@@ -408,17 +393,15 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
- list(state_map.values()), get_prev_content=False,
+ list(state_map.values()),
+ get_prev_content=False,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
- "[%s %s] Error attempting to resolve state at missing "
- "prev_events",
- room_id,
- event_id,
+ "Error attempting to resolve state at missing " "prev_events",
exc_info=True,
)
raise FederationError(
@@ -455,9 +438,7 @@ class FederationHandler(BaseHandler):
latest |= seen
logger.info(
- "[%s %s]: Requesting missing events between %s and %s",
- room_id,
- event_id,
+ "Requesting missing events between %s and %s",
shortstr(latest),
event_id,
)
@@ -524,15 +505,11 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
- logger.warning(
- "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
- )
+ logger.warning("Failed to get prev_events: %s", e)
return
logger.info(
- "[%s %s]: Got %d prev_events: %s",
- room_id,
- event_id,
+ "Got %d prev_events: %s",
len(missing_events),
shortstr(missing_events),
)
@@ -543,9 +520,7 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
- "[%s %s] Handling received prev_event %s",
- room_id,
- event_id,
+ "Handling received prev_event %s",
ev.event_id,
)
with nested_logging_context(ev.event_id):
@@ -554,9 +529,7 @@ class FederationHandler(BaseHandler):
except FederationError as e:
if e.code == 403:
logger.warning(
- "[%s %s] Received prev_event %s failed history check.",
- room_id,
- event_id,
+ "Received prev_event %s failed history check.",
ev.event_id,
)
else:
@@ -567,7 +540,6 @@ class FederationHandler(BaseHandler):
destination: str,
room_id: str,
event_id: str,
- include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
@@ -575,11 +547,9 @@ class FederationHandler(BaseHandler):
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
- include_event_in_state: if true, the event itself will be included in the
- returned state event list.
Returns:
- A list of events in the state, possibly including the event itself, and
+ A list of events in the state, not including the event itself, and
a list of events in the auth chain for the given event.
"""
(
@@ -591,9 +561,6 @@ class FederationHandler(BaseHandler):
desired_events = set(state_event_ids + auth_event_ids)
- if include_event_in_state:
- desired_events.add(event_id)
-
event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
@@ -610,13 +577,6 @@ class FederationHandler(BaseHandler):
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
- if include_event_in_state:
- remote_event = event_map.get(event_id)
- if not remote_event:
- raise Exception("Unable to get missing prev_event %s" % (event_id,))
- if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
-
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
@@ -690,10 +650,138 @@ class FederationHandler(BaseHandler):
return fetched_events
+ async def _get_state_after_missing_prev_event(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ ) -> List[EventBase]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+
+ Returns:
+ A list of events in the state, including the event itself
+ """
+ # TODO: This function is basically the same as _get_state_for_room. Can
+ # we make backfill() use it, rather than having two code paths? I think the
+ # only difference is that backfill() persists the prev events separately.
+
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ logger.debug(
+ "state_ids returned %i state events, %i auth events",
+ len(state_event_ids),
+ len(auth_event_ids),
+ )
+
+ # start by just trying to fetch the events from the store
+ desired_events = set(state_event_ids)
+ desired_events.add(event_id)
+ logger.debug("Fetching %i events from cache/store", len(desired_events))
+ fetched_events = await self.store.get_events(
+ desired_events, allow_rejected=True
+ )
+
+ missing_desired_events = desired_events - fetched_events.keys()
+ logger.debug(
+ "We are missing %i events (got %i)",
+ len(missing_desired_events),
+ len(fetched_events),
+ )
+
+ # We probably won't need most of the auth events, so let's just check which
+ # we have for now, rather than thrashing the event cache with them all
+ # unnecessarily.
+
+ # TODO: we probably won't actually need all of the auth events, since we
+ # already have a bunch of the state events. It would be nice if the
+ # federation api gave us a way of finding out which we actually need.
+
+ missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events.difference_update(
+ await self.store.have_seen_events(missing_auth_events)
+ )
+ logger.debug("We are also missing %i auth events", len(missing_auth_events))
+
+ missing_events = missing_desired_events | missing_auth_events
+ logger.debug("Fetching %i events from remote", len(missing_events))
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ (await self.store.get_events(missing_desired_events, allow_rejected=True))
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ del fetched_events[bad_event_id]
+
+ # if we couldn't get the prev event in question, that's a problem.
+ remote_event = fetched_events.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+
+ # missing state at that event is a warning, not a blocker
+ # XXX: this doesn't sound right? it means that we'll end up with incomplete
+ # state.
+ failed_to_fetch = desired_events - fetched_events.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
+ ]
+
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
+
+ return remote_state
+
async def _process_received_pdu(
- self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
):
- """ Called when we have a new pdu. We need to do auth checks and put it
+ """Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
Args:
@@ -705,10 +793,7 @@ class FederationHandler(BaseHandler):
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
- room_id = event.room_id
- event_id = event.event_id
-
- logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+ logger.debug("Processing event: %s", event)
try:
await self._handle_new_event(origin, event, state=state)
@@ -805,7 +890,7 @@ class FederationHandler(BaseHandler):
@log_function
async def backfill(self, dest, room_id, limit, extremities):
- """ Trigger a backfill request to `dest` for the given `room_id`
+ """Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
has no new events to offer, this will return an empty list.
@@ -869,7 +954,6 @@ class FederationHandler(BaseHandler):
destination=dest,
room_id=room_id,
event_id=e_id,
- include_event_in_state=False,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
@@ -1208,11 +1292,16 @@ class FederationHandler(BaseHandler):
with nested_logging_context(event_id):
try:
event = await self.federation_client.get_pdu(
- [destination], event_id, room_version, outlier=True,
+ [destination],
+ event_id,
+ room_version,
+ outlier=True,
)
if event is None:
logger.warning(
- "Server %s didn't return event %s", destination, event_id,
+ "Server %s didn't return event %s",
+ destination,
+ event_id,
)
return
@@ -1239,7 +1328,8 @@ class FederationHandler(BaseHandler):
if aid not in event_map
]
persisted_events = await self.store.get_events(
- auth_events, allow_rejected=True,
+ auth_events,
+ allow_rejected=True,
)
event_infos = []
@@ -1255,7 +1345,9 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
- destination, room_id, event_infos,
+ destination,
+ room_id,
+ event_infos,
)
def _sanity_check_event(self, ev):
@@ -1291,7 +1383,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def send_invite(self, target_host, event):
- """ Sends the invite to the remote server for signing.
+ """Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
"""
@@ -1307,14 +1399,14 @@ class FederationHandler(BaseHandler):
async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
auth = await self.store.get_auth_chain(
- list(event.auth_event_ids()), include_given=True
+ event.room_id, list(event.auth_event_ids()), include_given=True
)
return list(auth)
async def do_invite_join(
self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
) -> Tuple[str, int]:
- """ Attempts to join the `joinee` to the room `room_id` via the
+ """Attempts to join the `joinee` to the room `room_id` via the
servers contained in `target_hosts`.
This first triggers a /make_join/ request that returns a partial
@@ -1358,8 +1450,6 @@ class FederationHandler(BaseHandler):
await self._clean_room_for_join(room_id)
- handled_events = set()
-
try:
# Try the host we successfully got a response to /make_join/
# request first.
@@ -1379,10 +1469,6 @@ class FederationHandler(BaseHandler):
auth_chain = ret["auth_chain"]
auth_chain.sort(key=lambda e: e.depth)
- handled_events.update([s.event_id for s in state])
- handled_events.update([a.event_id for a in auth_chain])
- handled_events.add(event.event_id)
-
logger.debug("do_invite_join auth_chain: %s", auth_chain)
logger.debug("do_invite_join state: %s", state)
@@ -1398,7 +1484,8 @@ class FederationHandler(BaseHandler):
# so we can rely on it now.
#
await self.store.upsert_room_on_join(
- room_id=room_id, room_version=room_version_obj,
+ room_id=room_id,
+ room_version=room_version_obj,
)
max_stream_id = await self._persist_auth_tree(
@@ -1445,7 +1532,11 @@ class FederationHandler(BaseHandler):
@log_function
async def do_knock(
- self, target_hosts: List[str], room_id: str, knockee: str, content: JsonDict,
+ self,
+ target_hosts: List[str],
+ room_id: str,
+ knockee: str,
+ content: JsonDict,
) -> Tuple[str, int]:
"""Sends the knock to the remote server.
@@ -1535,7 +1626,7 @@ class FederationHandler(BaseHandler):
async def on_make_join_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
- """ We've received a /make_join/ request, so we create a partial
+ """We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
@@ -1560,7 +1651,8 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
- "Got /make_join request for room %s we are no longer in", room_id,
+ "Got /make_join request for room %s we are no longer in",
+ room_id,
)
raise NotFoundError("Not an active room on this server")
@@ -1594,7 +1686,7 @@ class FederationHandler(BaseHandler):
return event
async def on_send_join_request(self, origin, pdu):
- """ We have received a join event for a room. Fully process it and
+ """We have received a join event for a room. Fully process it and
respond with the current state and auth chains.
"""
event = pdu
@@ -1641,7 +1733,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
- auth_chain = await self.store.get_auth_chain(state_ids)
+ auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
state = await self.store.get_events(list(prev_state_ids.values()))
@@ -1650,7 +1742,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
):
- """ We've got an invite event. Process and persist it. Sign it.
+ """We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event.
"""
@@ -1666,7 +1758,7 @@ class FederationHandler(BaseHandler):
is_published = await self.store.is_room_published(event.room_id)
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
event.sender,
event.state_key,
None,
@@ -1695,6 +1787,12 @@ class FederationHandler(BaseHandler):
if event.state_key == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
+ # We retrieve the room member handler here so as not to cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ # We don't rate limit based on room ID, as that should be done by the
+ # sending server.
+ member_handler.ratelimit_invite(None, event.state_key)
+
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
@@ -1778,7 +1876,7 @@ class FederationHandler(BaseHandler):
async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
- """ We've received a /make_leave/ request, so we create a partial
+ """We've received a /make_leave/ request, so we create a partial
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
@@ -1968,8 +2066,7 @@ class FederationHandler(BaseHandler):
return context
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
- """Returns the state at the event. i.e. not including said event.
- """
+ """Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
@@ -1995,8 +2092,7 @@ class FederationHandler(BaseHandler):
return []
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
- """Returns the state at the event. i.e. not including said event.
- """
+ """Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -2202,7 +2298,11 @@ class FederationHandler(BaseHandler):
for e_id in missing_auth_events:
m_ev = await self.federation_client.get_pdu(
- [origin], e_id, room_version=room_version, outlier=True, timeout=10000,
+ [origin],
+ e_id,
+ room_version=room_version,
+ outlier=True,
+ timeout=10000,
)
if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev
@@ -2285,6 +2385,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
+ # If we are going to send this event over federation we precalculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
return context
async def _check_for_soft_fail(
@@ -2347,7 +2452,9 @@ class FederationHandler(BaseHandler):
)
logger.debug(
- "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+ "Doing soft-fail check for %s: state %s",
+ event.event_id,
+ current_state_ids,
)
# Now check if event pass auth against said current state
@@ -2386,7 +2493,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
- list(event.auth_event_ids()), include_given=True
+ room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
@@ -2700,7 +2807,7 @@ class FederationHandler(BaseHandler):
async def construct_auth_difference(
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
) -> Dict:
- """ Given a local and remote auth chain, find the differences. This
+ """Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
Params:
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index abd8d2af44..a41ca5df9c 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -29,7 +33,7 @@ def _create_rerouter(func_name):
async def f(self, group_id, *args, **kwargs):
if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
if self.is_mine_id(group_id):
return await getattr(self.groups_server_handler, func_name)(
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
class GroupsLocalWorkerHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
- async def get_group_summary(self, group_id, requester_user_id):
+ async def get_group_summary(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,14 @@ class GroupsLocalWorkerHandler:
return res
- async def get_users_in_group(self, group_id, requester_user_id):
- """Get users in a group
- """
+ async def get_users_in_group(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
+ """Get users in a group"""
if self.is_mine_id(group_id):
- res = await self.groups_server_handler.get_users_in_group(
+ return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
- return res
group_server_name = get_domain_from_id(group_id)
@@ -178,11 +184,11 @@ class GroupsLocalWorkerHandler:
return res
- async def get_joined_groups(self, user_id):
+ async def get_joined_groups(self, user_id: str) -> JsonDict:
group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}
- async def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
if self.hs.is_mine_id(user_id):
result = await self.store.get_publicised_groups_for_user(user_id)
@@ -206,8 +212,10 @@ class GroupsLocalWorkerHandler:
# TODO: Verify attestations
return {"groups": result}
- async def bulk_get_publicised_groups(self, user_ids, proxy=True):
- destinations = {}
+ async def bulk_get_publicised_groups(
+ self, user_ids: Iterable[str], proxy: bool = True
+ ) -> JsonDict:
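+ # Map from remote homeserver name to the set of user IDs we need to query
+ # on that server.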
+ destinations = {} # type: Dict[str, Set[str]]
local_users = set()
for user_id in user_ids:
@@ -220,7 +228,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
- failed_results = []
+ failed_results = [] # type: List[str]
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +250,7 @@ class GroupsLocalWorkerHandler:
class GroupsLocalHandler(GroupsLocalWorkerHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
# Ensure attestations get renewed
@@ -271,9 +279,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
- async def create_group(self, group_id, user_id, content):
- """Create a group
- """
+ async def create_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
+ """Create a group"""
logger.info("Asking to create group with ID: %r", group_id)
@@ -284,27 +293,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = None
remote_attestation = None
else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content["attestation"] = local_attestation
-
- content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
- try:
- res = await self.transport_client.create_group(
- get_domain_from_id(group_id), group_id, user_id, content
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- remote_attestation = res["attestation"]
- await self.attestations.verify_attestation(
- remote_attestation,
- group_id=group_id,
- user_id=user_id,
- server_name=get_domain_from_id(group_id),
- )
+ raise SynapseError(400, "Unable to create remote groups")
is_publicised = content.get("publicise", False)
token = await self.store.register_user_group_membership(
@@ -320,9 +309,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def join_group(self, group_id, user_id, content):
- """Request to join a group
- """
+ async def join_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
+ """Request to join a group"""
if self.is_mine_id(group_id):
await self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
@@ -365,9 +355,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- async def accept_invite(self, group_id, user_id, content):
- """Accept an invite to a group
- """
+ async def accept_invite(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
+ """Accept an invite to a group"""
if self.is_mine_id(group_id):
await self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
@@ -410,9 +401,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- async def invite(self, group_id, user_id, requester_user_id, config):
- """Invite a user to a group
- """
+ async def invite(
+ self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+ ) -> JsonDict:
+ """Invite a user to a group"""
content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
res = await self.groups_server_handler.invite_to_group(
@@ -434,9 +426,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def on_invite(self, group_id, user_id, content):
- """One of our users were invited to a group
- """
+ async def on_invite(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
+ """One of our users were invited to a group"""
# TODO: Support auto join and rejection
if not self.is_mine_id(user_id):
@@ -465,10 +458,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
async def remove_user_from_group(
- self, group_id, user_id, requester_user_id, content
- ):
- """Remove a user from a group
- """
+ self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
+ """Remove a user from a group"""
if user_id == requester_user_id:
token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
@@ -499,9 +491,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def user_removed_from_group(self, group_id, user_id, content):
- """One of our users was removed/kicked from a group
- """
+ async def user_removed_from_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> None:
+ """One of our users was removed/kicked from a group"""
# TODO: Check if user in group
token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 4291c4d81f..382d1dbe6f 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -29,9 +29,11 @@ from synapse.api.errors import (
ProxiedRequestError,
SynapseError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
@@ -47,7 +49,7 @@ class IdentityHandler(BaseHandler):
super().__init__(hs)
# An HTTP client for contacting trusted URLs.
- self.http_client = hs.get_simple_http_client()
+ self.http_client = SimpleHttpClient(hs)
# An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
@@ -55,13 +57,40 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
- self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
- self.trust_any_id_server_just_for_testing_do_not_use = (
- hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
- )
self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
self._enable_lookup = hs.config.enable_3pid_lookup
+ self._web_client_location = hs.config.invite_client_location
+
+ # Ratelimiters for `/requestToken` endpoints.
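+ # One limiter is keyed on the requesting client's IP address, the other on
+ # the (medium, address) pair being validated.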
+ self._3pid_validation_ratelimiter_ip = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+ burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+ )
+ self._3pid_validation_ratelimiter_address = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+ burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+ )
+
+ def ratelimit_request_token_requests(
+ self,
+ request: SynapseRequest,
+ medium: str,
+ address: str,
+ ):
+ """Used to ratelimit requests to `/requestToken` by IP and address.
+
+ Args:
+ request: The associated request
+ medium: The type of threepid, e.g. "msisdn" or "email"
+ address: The actual threepid ID, e.g. the phone number or email address
+ """
+
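+ # Check both limiters; if either limit has been exceeded the request is
+ # rejected.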
+ self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+ self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
async def threepid_from_creds(
self, id_server_url: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
@@ -529,6 +558,8 @@ class IdentityHandler(BaseHandler):
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
+ # It is already checked that public_baseurl is configured since this code
+ # should only be used if account_threepid_delegate_msisdn is true.
assert self.hs.config.public_baseurl
# we need to tell the client to send the token back to us, since it doesn't
@@ -940,6 +971,9 @@ class IdentityHandler(BaseHandler):
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
+ # If a custom web client location is available, include it in the request.
+ if self._web_client_location:
+ invite_config["org.matrix.web_client_location"] = self._web_client_location
# Rewrite the identity server URL if necessary
id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
@@ -983,7 +1017,9 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
- "Error trying to call /store-invite on %s: %s", id_server_url, e,
+ "Error trying to call /store-invite on %s: %s",
+ id_server_url,
+ e,
)
if data is None:
@@ -1020,7 +1056,10 @@ class IdentityHandler(BaseHandler):
return token, public_keys, fallback_public_key, display_name
async def bind_email_using_internal_sydent_api(
- self, id_server_url: str, email: str, user_id: str,
+ self,
+ id_server_url: str,
+ email: str,
+ user_id: str,
):
"""Bind an email to a fully qualified user ID using the internal API of an
instance of Sydent.
@@ -1052,7 +1091,10 @@ class IdentityHandler(BaseHandler):
# Remember where we bound the threepid
await self.store.add_user_bound_threepid(
- user_id=user_id, medium="email", address=email, id_server=id_server,
+ user_id=user_id,
+ medium="email",
+ address=email,
+ id_server=id_server,
)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index cb11754bf8..13f8152283 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
@@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(
- hs, "initial_sync_cache"
+ hs.get_clock(), "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
@@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler):
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
receipt = await self.store.get_linearized_receipts_for_rooms(
- joined_rooms, to_key=int(now_token.receipt_key),
+ joined_rooms,
+ to_key=int(now_token.receipt_key),
)
tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
- room_end_token = RoomStreamToken(None, event.stream_ordering,)
+ room_end_token = RoomStreamToken(
+ None,
+ event.stream_ordering,
+ )
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)
@@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler):
membership,
member_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True,
+ room_id,
+ user_id,
+ allow_departed_users=True,
)
is_peeking = member_event_id is None
@@ -323,9 +329,7 @@ class InitialSyncHandler(BaseHandler):
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
- room_state = await self.state_store.get_state_for_events([member_event_id])
-
- room_state = room_state[member_event_id]
+ room_state = await self.state_store.get_state_for_event(member_event_id)
limit = pagin_config.limit if pagin_config else None
if limit is None:
@@ -408,7 +412,7 @@ class InitialSyncHandler(BaseHandler):
return [
{
- "type": EventTypes.Presence,
+ "type": EduTypes.Presence,
"content": format_user_presence_state(s, time_now),
}
for s in states
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2fcba6abfc..e06e8ff60c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -67,8 +67,7 @@ logger = logging.getLogger(__name__)
class MessageHandler:
- """Contains some read only APIs to get state about a room
- """
+ """Contains some read only APIs to get state about a room"""
def __init__(self, hs):
self.auth = hs.get_auth()
@@ -90,9 +89,13 @@ class MessageHandler:
)
async def get_room_data(
- self, user_id: str, room_id: str, event_type: str, state_key: str,
+ self,
+ user_id: str,
+ room_id: str,
+ event_type: str,
+ state_key: str,
) -> dict:
- """ Get data from a room.
+ """Get data from a room.
Args:
user_id
@@ -176,7 +179,10 @@ class MessageHandler:
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = await filter_events_for_client(
- self.storage, user_id, last_events, filter_send_to_client=False
+ self.storage,
+ user_id,
+ last_events,
+ filter_send_to_client=False,
)
event = last_events[0]
@@ -383,6 +389,12 @@ class EventCreationHandler:
self.room_invite_state_types = self.hs.config.room_invite_state_types
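+ # Membership types for which we add profile data (displayname / avatar URL)
+ # to the event content; invites can be excluded via config.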
+ self.membership_types_to_include_profile_data_in = (
+ {Membership.JOIN, Membership.INVITE, Membership.KNOCK}
+ if self.hs.config.include_profile_data_on_invite
+ else {Membership.JOIN, Membership.KNOCK}
+ )
+
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users
@@ -434,6 +446,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._external_cache = hs.get_external_cache()
+
async def create_event(
self,
requester: Requester,
@@ -494,7 +508,7 @@ class EventCreationHandler:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
- if membership in {Membership.JOIN, Membership.INVITE, Membership.KNOCK}:
+ if membership in self.membership_types_to_include_profile_data_in:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
@@ -571,7 +585,7 @@ class EventCreationHandler:
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
) -> bool:
- """"Determine if an event to be sent is exempt from having to consent
+ """ "Determine if an event to be sent is exempt from having to consent
to the privacy policy
Args:
@@ -746,7 +760,7 @@ class EventCreationHandler:
event.sender,
)
- spam_error = self.spam_checker.check_event_for_spam(event)
+ spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
@@ -793,9 +807,10 @@ class EventCreationHandler:
"""
if prev_event_ids is not None:
- assert len(prev_event_ids) <= 10, (
- "Attempting to create an event with %i prev_events"
- % (len(prev_event_ids),)
+ assert (
+ len(prev_event_ids) <= 10
+ ), "Attempting to create an event with %i prev_events" % (
+ len(prev_event_ids),
)
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
@@ -821,7 +836,8 @@ class EventCreationHandler:
)
if not third_party_result:
logger.info(
- "Event %s forbidden by third-party rules", event,
+ "Event %s forbidden by third-party rules",
+ event,
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
@@ -941,6 +957,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context)
+ await self.cache_joined_hosts_for_event(event)
+
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -980,6 +998,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
+ async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+ """Precalculate the joined hosts at the event, when using Redis, so that
+ external federation senders don't have to recalculate it themselves.
+ """
+
+ if not self._external_cache.is_enabled():
+ return
+
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We always set the state group -> joined hosts cache, even if
+ # we already set it, so that the expiry time is reset.
+
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+
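+ # If the prev events don't resolve to a persisted state group there is
+ # nothing useful to key the cache on, so skip.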
+ if state_entry.state_group:
+ joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
+
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:
@@ -1131,7 +1187,8 @@ class EventCreationHandler:
event.unsigned[
"knock_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
- context, DEFAULT_ROOM_STATE_TYPES,
+ context,
+ DEFAULT_ROOM_STATE_TYPES,
)
if event.type == EventTypes.Redaction:
@@ -1270,7 +1327,7 @@ class EventCreationHandler:
event, context = await self.create_event(
requester,
{
- "type": "org.matrix.dummy_event",
+ "type": EventTypes.Dummy,
"content": {},
"room_id": room_id,
"sender": user_id,
@@ -1283,7 +1340,11 @@ class EventCreationHandler:
# Since this is a dummy-event it is OK if it is sent by a
# shadow-banned user.
await self.handle_new_client_event(
- requester, event, context, ratelimit=False, ignore_shadow_ban=True,
+ requester,
+ event,
+ context,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
return True
except AuthError:
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..6624212d6f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,13 +15,13 @@
# limitations under the License.
import inspect
import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode
import attr
import pymacaroons
from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken
+from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
@@ -28,26 +29,52 @@ from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
MacaroonDeserializationException,
+ MacaroonInitException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict
from twisted.web.client import readBody
+from twisted.web.http_headers import Headers
from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
+from synapse.config.oidc_config import (
+ OidcProviderClientSecretJwtKey,
+ OidcProviderConfig,
+)
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
-from synapse.util import json_decoder
+from synapse.util import Clock, json_decoder
+from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-SESSION_COOKIE_NAME = b"oidc_session"
+# we want the cookie to be returned to us even when the request is the POSTed
+# result of a form on another domain, as is used with `response_mode=form_post`.
+#
+# Modern browsers will not do so unless we set SameSite=None; however *older*
+# browsers (including all versions of Safari on iOS 12?) don't support
+# SameSite=None, and interpret it as SameSite=Strict:
+# https://bugs.webkit.org/show_bug.cgi?id=198181
+#
+# As a rather painful workaround, we set *two* cookies, one with SameSite=None
+# and one with no SameSite, in the hope that at least one of them will get
+# back to us.
+#
+# Secure is necessary for SameSite=None (and, empirically, also breaks things
+# on iOS 12.)
+#
+# Here we have the names of the cookies, and the options we use to set them.
+_SESSION_COOKIES = [
+ (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
+ (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
+]
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
@@ -71,9 +98,160 @@ JWK = Dict[str, str]
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+class OidcHandler:
+ """Handles requests related to the OpenID Connect login flow."""
+
+ def __init__(self, hs: "HomeServer"):
+ self._sso_handler = hs.get_sso_handler()
+
+ provider_confs = hs.config.oidc.oidc_providers
+ # we should not have been instantiated if there is no configured provider.
+ assert provider_confs
+
+ self._token_generator = OidcSessionTokenGenerator(hs)
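+ # One OidcProvider per configured IdP, keyed by its idp_id.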
+ self._providers = {
+ p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+ } # type: Dict[str, OidcProvider]
+
+ async def load_metadata(self) -> None:
+ """Validate the config and load the metadata from the remote endpoint.
+
+ Called at startup to ensure we have everything we need.
+ """
+ for idp_id, p in self._providers.items():
+ try:
+ await p.load_metadata()
+ await p.load_jwks()
+ except Exception as e:
+ raise Exception(
+ "Error while initialising OIDC provider %r" % (idp_id,)
+ ) from e
+
+ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+ """Handle an incoming request to /_synapse/client/oidc/callback
+
+ Since we might want to display OIDC-related errors in a user-friendly
+ way, we don't raise SynapseError from here. Instead, we call
+ ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+ Most of the OpenID Connect logic happens here:
+
+ - first, we check if there was any error returned by the provider and
+ display it
+ - then we fetch the session cookie, decode and verify it
+ - the ``state`` query parameter should match with the one stored in the
+ session cookie
+
+ Once we know the session is legit, we then delegate to the OIDC Provider
+ implementation, which will exchange the code with the provider and complete the
+ login/authentication.
+
+ Args:
+ request: the incoming request from the browser.
+ """
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
+ # The provider might redirect with an error.
+ # In that case, just display it as-is.
+ if b"error" in request.args:
+ # error response from the auth server. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+ # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+ error = request.args[b"error"][0].decode()
+ description = request.args.get(b"error_description", [b""])[0].decode()
+
+ # Most of the errors returned by the provider could be due to
+ # either the provider misbehaving or Synapse being misconfigured.
+ # The only exception to that is "access_denied", where the user
+ # probably cancelled the login flow. In other cases, log those errors.
+ logger.log(
+ logging.INFO if error == "access_denied" else logging.ERROR,
+ "Received OIDC callback with error: %s %s",
+ error,
+ description,
+ )
+
+ self._sso_handler.render_error(request, error, description)
+ return
+
+ # otherwise, it is presumably a successful response. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+ # Fetch the session cookie. See the comments on _SESSION_COOKIES for why there
+ # are two.
+
+ for cookie_name, _ in _SESSION_COOKIES:
+ session = request.getCookie(cookie_name) # type: Optional[bytes]
+ if session is not None:
+ break
+ else:
+ logger.info("Received OIDC callback, with no session cookie")
+ self._sso_handler.render_error(
+ request, "missing_session", "No session cookie found"
+ )
+ return
+
+ # Remove the cookies. There is a good chance that if the callback failed
+ # once, it will fail next time and the code will already be exchanged.
+ # Removing the cookies early avoids spamming the provider with token requests.
+ #
+ # we have to build the header by hand rather than calling request.addCookie
+ # because the latter does not support SameSite=None
+ # (https://twistedmatrix.com/trac/ticket/10088)
+
+ for cookie_name, options in _SESSION_COOKIES:
+ request.cookies.append(
+ b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s"
+ % (cookie_name, options)
+ )
+
+ # Check for the state query parameter
+ if b"state" not in request.args:
+ logger.info("Received OIDC callback, with no state parameter")
+ self._sso_handler.render_error(
+ request, "invalid_request", "State parameter is missing"
+ )
+ return
+
+ state = request.args[b"state"][0].decode()
+
+ # Deserialize the session token and verify it.
+ try:
+ session_data = self._token_generator.verify_oidc_session_token(
+ session, state
+ )
+ except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
+ logger.exception("Invalid session for OIDC callback")
+ self._sso_handler.render_error(request, "invalid_session", str(e))
+ return
+ except MacaroonInvalidSignatureException as e:
+ logger.exception("Could not verify session for OIDC callback")
+ self._sso_handler.render_error(request, "mismatching_session", str(e))
+ return
+
+ logger.info("Received OIDC callback for IdP %s", session_data.idp_id)
+
+ oidc_provider = self._providers.get(session_data.idp_id)
+ if not oidc_provider:
+ logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+ self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+ return
+
+ if b"code" not in request.args:
+ logger.info("Code parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "Code parameter is missing"
+ )
+ return
+
+ code = request.args[b"code"][0].decode()
+
+ await oidc_provider.handle_oidc_callback(request, session_data, code)
+
+
class OidcError(Exception):
- """Used to catch errors when calling the token_endpoint
- """
+ """Used to catch errors when calling the token_endpoint"""
def __init__(self, error, error_description=None):
self.error = error
@@ -85,47 +263,86 @@ class OidcError(Exception):
return self.error
-class OidcHandler(BaseHandler):
- """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+ """Wraps the config for a single OIDC IdentityProvider
+
+ Provides methods for handling redirect requests and callbacks via that particular
+ IdP.
"""
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ def __init__(
+ self,
+ hs: "HomeServer",
+ token_generator: "OidcSessionTokenGenerator",
+ provider: OidcProviderConfig,
+ ):
+ self._store = hs.get_datastore()
+
+ self._token_generator = token_generator
+
+ self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
- self._scopes = hs.config.oidc_scopes # type: List[str]
- self._user_profile_method = hs.config.oidc_user_profile_method # type: str
+
+ self._oidc_attribute_requirements = provider.attribute_requirements
+ self._scopes = provider.scopes
+ self._user_profile_method = provider.user_profile_method
+
+ client_secret = None # type: Union[None, str, JwtClientSecret]
+ if provider.client_secret:
+ client_secret = provider.client_secret
+ elif provider.client_secret_jwt_key:
+ client_secret = JwtClientSecret(
+ provider.client_secret_jwt_key,
+ provider.client_id,
+ provider.issuer,
+ hs.get_clock(),
+ )
+
self._client_auth = ClientAuth(
- hs.config.oidc_client_id,
- hs.config.oidc_client_secret,
- hs.config.oidc_client_auth_method,
+ provider.client_id,
+ client_secret,
+ provider.client_auth_method,
) # type: ClientAuth
- self._client_auth_method = hs.config.oidc_client_auth_method # type: str
- self._provider_metadata = OpenIDProviderMetadata(
- issuer=hs.config.oidc_issuer,
- authorization_endpoint=hs.config.oidc_authorization_endpoint,
- token_endpoint=hs.config.oidc_token_endpoint,
- userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
- jwks_uri=hs.config.oidc_jwks_uri,
- ) # type: OpenIDProviderMetadata
- self._provider_needs_discovery = hs.config.oidc_discover # type: bool
- self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
- hs.config.oidc_user_mapping_provider_config
- ) # type: OidcMappingProvider
- self._skip_verification = hs.config.oidc_skip_verification # type: bool
- self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
+ self._client_auth_method = provider.client_auth_method
+
+ # cache of metadata for the identity provider (endpoint uris, mostly). This is
+ # loaded on-demand from the discovery endpoint (if discovery is enabled), with
+ # possible overrides from the config. Access via `load_metadata`.
+ self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+
+ # cache of JWKs used by the identity provider to sign tokens. Loaded on demand
+ # from the IdP's jwks_uri, if required.
+ self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
+
+ self._user_mapping_provider = provider.user_mapping_provider_class(
+ provider.user_mapping_provider_config
+ )
+ self._skip_verification = provider.skip_verification
+ self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._server_name = hs.config.server_name # type: str
- self._macaroon_secret_key = hs.config.macaroon_secret_key
# identifier for the external_ids table
- self._auth_provider_id = "oidc"
+ self.idp_id = provider.idp_id
+
+ # user-facing name of this auth provider
+ self.idp_name = provider.idp_name
+
+ # MXC URI for icon for this auth provider
+ self.idp_icon = provider.idp_icon
+
+ # optional brand identifier for this auth provider
+ self.idp_brand = provider.idp_brand
+
+ # Optional brand identifier for the unstable API (see MSC2858).
+ self.unstable_idp_brand = provider.unstable_idp_brand
self._sso_handler = hs.get_sso_handler()
- def _validate_metadata(self):
+ self._sso_handler.register_identity_provider(self)
+
+ def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
"""Verifies the provider metadata.
This checks the validity of the currently loaded provider. Not
@@ -144,7 +361,6 @@ class OidcHandler(BaseHandler):
if self._skip_verification is True:
return
- m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()
@@ -179,11 +395,7 @@ class OidcHandler(BaseHandler):
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
- if m.get("jwks") is None:
- if m.get("jwks_uri") is not None:
- m.validate_jwks_uri()
- else:
- raise ValueError('"jwks_uri" must be set')
+ m.validate_jwks_uri()
@property
def _uses_userinfo(self) -> bool:
@@ -200,11 +412,15 @@ class OidcHandler(BaseHandler):
or self._user_profile_method == "userinfo_endpoint"
)
- async def load_metadata(self) -> OpenIDProviderMetadata:
- """Load and validate the provider metadata.
+ async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
+ """Return the provider metadata.
+
+ If this is the first call, the metadata is built from the config and from the
+ metadata discovery endpoint (if enabled), and then validated. If the metadata
+ is successfully validated, it is then cached for future use.
- The values metadatas are discovered if ``oidc_config.discovery`` is
- ``True`` and then cached.
+ Args:
+ force: If true, any cached metadata is discarded to force a reload.
Raises:
ValueError: if something in the provider is not valid
@@ -212,18 +428,41 @@ class OidcHandler(BaseHandler):
Returns:
The provider's metadata.
"""
- # If we are using the OpenID Discovery documents, it needs to be loaded once
- # FIXME: should there be a lock here?
- if self._provider_needs_discovery:
- url = get_well_known_url(self._provider_metadata["issuer"], external=True)
+ if force:
+ # reset the cached call to ensure we get a new result
+ self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+
+ return await self._provider_metadata.get()
+
+ async def _load_metadata(self) -> OpenIDProviderMetadata:
+ # start out with just the issuer (unlike the other settings, discovered issuer
+ # takes precedence over configured issuer, because configured issuer is
+ # required for discovery to take place.)
+ #
+ metadata = OpenIDProviderMetadata(issuer=self._config.issuer)
+
+ # load any data from the discovery endpoint, if enabled
+ if self._config.discover:
+ url = get_well_known_url(self._config.issuer, external=True)
metadata_response = await self._http_client.get_json(url)
- # TODO: maybe update the other way around to let user override some values?
- self._provider_metadata.update(metadata_response)
- self._provider_needs_discovery = False
+ metadata.update(metadata_response)
- self._validate_metadata()
+ # override any discovered data with any settings in our config
+ if self._config.authorization_endpoint:
+ metadata["authorization_endpoint"] = self._config.authorization_endpoint
- return self._provider_metadata
+ if self._config.token_endpoint:
+ metadata["token_endpoint"] = self._config.token_endpoint
+
+ if self._config.userinfo_endpoint:
+ metadata["userinfo_endpoint"] = self._config.userinfo_endpoint
+
+ if self._config.jwks_uri:
+ metadata["jwks_uri"] = self._config.jwks_uri
+
+ self._validate_metadata(metadata)
+
+ return metadata
async def load_jwks(self, force: bool = False) -> JWKS:
"""Load the JSON Web Key Set used to sign ID tokens.
@@ -253,27 +492,27 @@ class OidcHandler(BaseHandler):
]
}
"""
+ if force:
+ # reset the cached call to ensure we get a new result
+ self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
+ return await self._jwks.get()
+
+ async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}
- # First check if the JWKS are loaded in the provider metadata.
- # It can happen either if the provider gives its JWKS in the discovery
- # document directly or if it was already loaded once.
metadata = await self.load_metadata()
- jwk_set = metadata.get("jwks")
- if jwk_set is not None and not force:
- return jwk_set
- # Loading the JWKS using the `jwks_uri` metadata
+ # Load the JWKS using the `jwks_uri` metadata.
uri = metadata.get("jwks_uri")
if not uri:
+ # this should be unreachable: load_metadata validates that
+ # there is a jwks_uri in the metadata if _uses_userinfo is unset
raise RuntimeError('Missing "jwks_uri" in metadata')
jwk_set = await self._http_client.get_json(uri)
- # Caching the JWKS in the provider's metadata
- self._provider_metadata["jwks"] = jwk_set
return jwk_set
async def _exchange_code(self, code: str) -> Token:
@@ -308,7 +547,7 @@ class OidcHandler(BaseHandler):
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
- headers = {
+ raw_headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
@@ -322,16 +561,19 @@ class OidcHandler(BaseHandler):
body = urlencode(args, True)
# Fill the body/headers with credentials
- uri, headers, body = self._client_auth.prepare(
- method="POST", uri=token_endpoint, headers=headers, body=body
+ uri, raw_headers, body = self._client_auth.prepare(
+ method="POST", uri=token_endpoint, headers=raw_headers, body=body
)
- headers = {k: [v] for (k, v) in headers.items()}
+ headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code and we do the body encoding ourself.
response = await self._http_client.request(
- method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
+ method="POST",
+ uri=uri,
+ data=body.encode("utf-8"),
+ headers=headers,
)
# This is used in multiple error messages below
@@ -409,6 +651,7 @@ class OidcHandler(BaseHandler):
Returns:
UserInfo: an object representing the user.
"""
+ logger.debug("Using the OAuth2 access_token to request userinfo")
metadata = await self.load_metadata()
resp = await self._http_client.get_json(
@@ -416,6 +659,8 @@ class OidcHandler(BaseHandler):
headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
)
+ logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
+
return UserInfo(resp)
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
@@ -444,17 +689,19 @@ class OidcHandler(BaseHandler):
claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
-
jwt = JsonWebToken(alg_values)
claim_options = {"iss": {"values": [metadata["issuer"]]}}
+ id_token = token["id_token"]
+ logger.debug("Attempting to decode JWT id_token %r", id_token)
+
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
- token["id_token"],
+ id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
@@ -464,20 +711,22 @@ class OidcHandler(BaseHandler):
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
- token["id_token"],
+ id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)
+ logger.debug("Decoded id_token JWT %r; validating", claims)
+
claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)
async def handle_redirect_request(
self,
request: SynapseRequest,
- client_redirect_url: bytes,
+ client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Handle an incoming request to /login/sso/redirect
@@ -487,7 +736,7 @@ class OidcHandler(BaseHandler):
- ``client_id``: the client ID set in ``oidc_config.client_id``
- ``response_type``: ``code``
- - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+ - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
- ``scope``: the list of scopes set in ``oidc_config.scopes``
- ``state``: a random string
- ``nonce``: a random string
@@ -501,7 +750,7 @@ class OidcHandler(BaseHandler):
request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to
- when everything is done
+ when everything is done (or None for UI Auth)
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
@@ -513,21 +762,31 @@ class OidcHandler(BaseHandler):
state = generate_token()
nonce = generate_token()
- cookie = self._generate_oidc_session_token(
+ if not client_redirect_url:
+ client_redirect_url = b""
+
+ cookie = self._token_generator.generate_oidc_session_token(
state=state,
- nonce=nonce,
- client_redirect_url=client_redirect_url.decode(),
- ui_auth_session_id=ui_auth_session_id,
- )
- request.addCookie(
- SESSION_COOKIE_NAME,
- cookie,
- path="/_synapse/oidc",
- max_age="3600",
- httpOnly=True,
- sameSite="lax",
+ session_data=OidcSessionData(
+ idp_id=self.idp_id,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url.decode(),
+ ui_auth_session_id=ui_auth_session_id or "",
+ ),
)
+ # Set the cookies. See the comments on _SESSION_COOKIES for why there are two.
+ #
+ # we have to build the header by hand rather than calling request.addCookie
+ # because the latter does not support SameSite=None
+ # (https://twistedmatrix.com/trac/ticket/10088)
+
+ for cookie_name, options in _SESSION_COOKIES:
+ request.cookies.append(
+ b"%s=%s; Max-Age=3600; %s"
+ % (cookie_name, cookie.encode("utf-8"), options)
+ )
+
metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
return prepare_grant_uri(
@@ -540,22 +799,16 @@ class OidcHandler(BaseHandler):
nonce=nonce,
)
- async def handle_oidc_callback(self, request: SynapseRequest) -> None:
- """Handle an incoming request to /_synapse/oidc/callback
+ async def handle_oidc_callback(
+ self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+ ) -> None:
+ """Handle an incoming request to /_synapse/client/oidc/callback
- Since we might want to display OIDC-related errors in a user-friendly
- way, we don't raise SynapseError from here. Instead, we call
- ``self._sso_handler.render_error`` which displays an HTML page for the error.
+ By this time we have already validated the session on the synapse side, and
+ now need to do the provider-specific operations. This includes:
- Most of the OpenID Connect logic happens here:
-
- - first, we check if there was any error returned by the provider and
- display it
- - then we fetch the session cookie, decode and verify it
- - the ``state`` query parameter should match with the one stored in the
- session cookie
- - once we known this session is legit, exchange the code with the
- provider using the ``token_endpoint`` (see ``_exchange_code``)
+ - exchange the code with the provider using the ``token_endpoint`` (see
+ ``_exchange_code``)
- once we have the token, use it to either extract the UserInfo from
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
to fetch UserInfo from the ``userinfo_endpoint``
@@ -565,100 +818,23 @@ class OidcHandler(BaseHandler):
Args:
request: the incoming request from the browser.
+ session_data: the session data, extracted from our cookie
+ code: The authorization code we got from the callback.
"""
-
- # The provider might redirect with an error.
- # In that case, just display it as-is.
- if b"error" in request.args:
- # error response from the auth server. see:
- # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
- # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
- error = request.args[b"error"][0].decode()
- description = request.args.get(b"error_description", [b""])[0].decode()
-
- # Most of the errors returned by the provider could be due by
- # either the provider misbehaving or Synapse being misconfigured.
- # The only exception of that is "access_denied", where the user
- # probably cancelled the login flow. In other cases, log those errors.
- if error != "access_denied":
- logger.error("Error from the OIDC provider: %s %s", error, description)
-
- self._sso_handler.render_error(request, error, description)
- return
-
- # otherwise, it is presumably a successful response. see:
- # https://tools.ietf.org/html/rfc6749#section-4.1.2
-
- # Fetch the session cookie
- session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
- if session is None:
- logger.info("No session cookie found")
- self._sso_handler.render_error(
- request, "missing_session", "No session cookie found"
- )
- return
-
- # Remove the cookie. There is a good chance that if the callback failed
- # once, it will fail next time and the code will already be exchanged.
- # Removing it early avoids spamming the provider with token requests.
- request.addCookie(
- SESSION_COOKIE_NAME,
- b"",
- path="/_synapse/oidc",
- expires="Thu, Jan 01 1970 00:00:00 UTC",
- httpOnly=True,
- sameSite="lax",
- )
-
- # Check for the state query parameter
- if b"state" not in request.args:
- logger.info("State parameter is missing")
- self._sso_handler.render_error(
- request, "invalid_request", "State parameter is missing"
- )
- return
-
- state = request.args[b"state"][0].decode()
-
- # Deserialize the session token and verify it.
- try:
- (
- nonce,
- client_redirect_url,
- ui_auth_session_id,
- ) = self._verify_oidc_session_token(session, state)
- except MacaroonDeserializationException as e:
- logger.exception("Invalid session")
- self._sso_handler.render_error(request, "invalid_session", str(e))
- return
- except MacaroonInvalidSignatureException as e:
- logger.exception("Could not verify session")
- self._sso_handler.render_error(request, "mismatching_session", str(e))
- return
-
# Exchange the code with the provider
- if b"code" not in request.args:
- logger.info("Code parameter is missing")
- self._sso_handler.render_error(
- request, "invalid_request", "Code parameter is missing"
- )
- return
-
- logger.debug("Exchanging code")
- code = request.args[b"code"][0].decode()
try:
+ logger.debug("Exchanging OAuth2 code for a token")
token = await self._exchange_code(code)
except OidcError as e:
- logger.exception("Could not exchange code")
+ logger.exception("Could not exchange OAuth2 code")
self._sso_handler.render_error(request, e.error, e.error_description)
return
- logger.debug("Successfully obtained OAuth2 access token")
+ logger.debug("Successfully obtained OAuth2 token data: %r", token)
# Now that we have a token, get the userinfo, either by decoding the
# `id_token` or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
- logger.debug("Fetching userinfo")
try:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
@@ -666,172 +842,57 @@ class OidcHandler(BaseHandler):
self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
- logger.debug("Extracting userinfo from id_token")
try:
- userinfo = await self._parse_id_token(token, nonce=nonce)
+ userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
+ # first check if we're doing a UIA
+ if session_data.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
+ except Exception as e:
+ logger.exception("Could not extract remote user id")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
+ )
+
+ # otherwise, it's a login
+ logger.debug("Userinfo for OIDC login: %s", userinfo)
+
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes by checking the userinfo against attribute_requirements
+ # In order to deal with the fact that OIDC userinfo can contain many
+ # types of data, we wrap non-list values in lists.
+ if not self._sso_handler.check_required_attributes(
+ request,
+ {k: v if isinstance(v, list) else [v] for k, v in userinfo.items()},
+ self._oidc_attribute_requirements,
+ ):
+ return
# Call the mapper to register/login the user
try:
- user_id = await self._map_userinfo_to_user(
- userinfo, token, user_agent, ip_address
+ await self._complete_oidc_login(
+ userinfo, token, request, session_data.client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
- # Mapping providers might not have get_extra_attributes: only call this
- # method if it exists.
- extra_attributes = None
- get_extra_attributes = getattr(
- self._user_mapping_provider, "get_extra_attributes", None
- )
- if get_extra_attributes:
- extra_attributes = await get_extra_attributes(userinfo, token)
-
- # and finally complete the login
- if ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, ui_auth_session_id, request
- )
- else:
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url, extra_attributes
- )
-
- def _generate_oidc_session_token(
+ async def _complete_oidc_login(
self,
- state: str,
- nonce: str,
+ userinfo: UserInfo,
+ token: Token,
+ request: SynapseRequest,
client_redirect_url: str,
- ui_auth_session_id: Optional[str],
- duration_in_ms: int = (60 * 60 * 1000),
- ) -> str:
- """Generates a signed token storing data about an OIDC session.
-
- When Synapse initiates an authorization flow, it creates a random state
- and a random nonce. Those parameters are given to the provider and
- should be verified when the client comes back from the provider.
- It is also used to store the client_redirect_url, which is used to
- complete the SSO login flow.
-
- Args:
- state: The ``state`` parameter passed to the OIDC provider.
- nonce: The ``nonce`` parameter passed to the OIDC provider.
- client_redirect_url: The URL the client gave when it initiated the
- flow.
- ui_auth_session_id: The session ID of the ongoing UI Auth (or
- None if this is a login).
- duration_in_ms: An optional duration for the token in milliseconds.
- Defaults to an hour.
-
- Returns:
- A signed macaroon token with the session information.
- """
- macaroon = pymacaroons.Macaroon(
- location=self._server_name, identifier="key", key=self._macaroon_secret_key,
- )
- macaroon.add_first_party_caveat("gen = 1")
- macaroon.add_first_party_caveat("type = session")
- macaroon.add_first_party_caveat("state = %s" % (state,))
- macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
- macaroon.add_first_party_caveat(
- "client_redirect_url = %s" % (client_redirect_url,)
- )
- if ui_auth_session_id:
- macaroon.add_first_party_caveat(
- "ui_auth_session_id = %s" % (ui_auth_session_id,)
- )
- now = self.clock.time_msec()
- expiry = now + duration_in_ms
- macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
- return macaroon.serialize()
-
- def _verify_oidc_session_token(
- self, session: bytes, state: str
- ) -> Tuple[str, str, Optional[str]]:
- """Verifies and extract an OIDC session token.
-
- This verifies that a given session token was issued by this homeserver
- and extract the nonce and client_redirect_url caveats.
-
- Args:
- session: The session token to verify
- state: The state the OIDC provider gave back
-
- Returns:
- The nonce, client_redirect_url, and ui_auth_session_id for this session
- """
- macaroon = pymacaroons.Macaroon.deserialize(session)
-
- v = pymacaroons.Verifier()
- v.satisfy_exact("gen = 1")
- v.satisfy_exact("type = session")
- v.satisfy_exact("state = %s" % (state,))
- v.satisfy_general(lambda c: c.startswith("nonce = "))
- v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
- # Sometimes there's a UI auth session ID, it seems to be OK to attempt
- # to always satisfy this.
- v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
- v.satisfy_general(self._verify_expiry)
-
- v.verify(macaroon, self._macaroon_secret_key)
-
- # Extract the `nonce`, `client_redirect_url`, and maybe the
- # `ui_auth_session_id` from the token.
- nonce = self._get_value_from_macaroon(macaroon, "nonce")
- client_redirect_url = self._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
- try:
- ui_auth_session_id = self._get_value_from_macaroon(
- macaroon, "ui_auth_session_id"
- ) # type: Optional[str]
- except ValueError:
- ui_auth_session_id = None
-
- return nonce, client_redirect_url, ui_auth_session_id
-
- def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
- """Extracts a caveat value from a macaroon token.
-
- Args:
- macaroon: the token
- key: the key of the caveat to extract
-
- Returns:
- The extracted value
-
- Raises:
- Exception: if the caveat was not in the macaroon
- """
- prefix = key + " = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(prefix):
- return caveat.caveat_id[len(prefix) :]
- raise ValueError("No %s caveat in macaroon" % (key,))
-
- def _verify_expiry(self, caveat: str) -> bool:
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self.clock.time_msec()
- return now < expiry
-
- async def _map_userinfo_to_user(
- self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
- ) -> str:
- """Maps a UserInfo object to a mxid.
+ ) -> None:
+ """Given a UserInfo response, complete the login flow
UserInfo should have a claim that uniquely identifies users. This claim
is usually `sub`, but can be configured with `oidc_config.subject_claim`.
@@ -843,27 +904,23 @@ class OidcHandler(BaseHandler):
If a user already exists with the mxid we've mapped and allow_existing_users
is disabled, raise an exception.
+ Otherwise, render a redirect back to the client_redirect_url with a loginToken.
+
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
Raises:
MappingException: if there was an error while mapping some properties
-
- Returns:
- The mxid of the user
"""
try:
- remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
- # Some OIDC providers use integer IDs, but Synapse expects external IDs
- # to be strings.
- remote_user_id = str(remote_user_id)
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
@@ -904,8 +961,8 @@ class OidcHandler(BaseHandler):
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- users = await self.store.get_users_by_id_case_insensitive(user_id)
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
@@ -924,18 +981,232 @@ class OidcHandler(BaseHandler):
return None
- return await self._sso_handler.get_mxid_from_sso(
- self._auth_provider_id,
+ # Mapping providers might not have get_extra_attributes: only call this
+ # method if it exists.
+ extra_attributes = None
+ get_extra_attributes = getattr(
+ self._user_mapping_provider, "get_extra_attributes", None
+ )
+ if get_extra_attributes:
+ extra_attributes = await get_extra_attributes(userinfo, token)
+
+ await self._sso_handler.complete_sso_login_request(
+ self.idp_id,
remote_user_id,
- user_agent,
- ip_address,
+ request,
+ client_redirect_url,
oidc_response_to_user_attributes,
grandfather_existing_users,
+ extra_attributes,
+ )
+
+ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+ """Extract the unique remote id from an OIDC UserInfo block
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ Returns:
+ remote user id
+ """
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ # Some OIDC providers use integer IDs, but Synapse expects external IDs
+ # to be strings.
+ return str(remote_user_id)
+
+
+# number of seconds a newly-generated client secret should be valid for
+CLIENT_SECRET_VALIDITY_SECONDS = 3600
+
+# minimum remaining validity on a client secret before we should generate a new one
+CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
+
+
+class JwtClientSecret:
+ """A class which generates a new client secret on demand, based on a JWK
+
+ This implementation is designed to comply with the requirements for Apple Sign in:
+ https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
+
+ It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
+ but it's worth noting that we still put the generated secret in the "client_secret"
+ field (or rather, wherever client_auth_method puts it) rather than in a
+ client_assertion field in the body as that RFC seems to require.
+ """
+
+ def __init__(
+ self,
+ key: OidcProviderClientSecretJwtKey,
+ oauth_client_id: str,
+ oauth_issuer: str,
+ clock: Clock,
+ ):
+ self._key = key
+ self._oauth_client_id = oauth_client_id
+ self._oauth_issuer = oauth_issuer
+ self._clock = clock
+ self._cached_secret = b""
+ self._cached_secret_replacement_time = 0
+
+ def __str__(self):
+ # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
+ # encode_client_secret_basic, which calls "{}".format(secret), which ends up
+ # here.
+ return self._get_secret().decode("ascii")
+
+ def __bytes__(self):
+ # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
+ # encode_client_secret_post, which ends up here.
+ return self._get_secret()
+
+ def _get_secret(self) -> bytes:
+ now = self._clock.time()
+
+ # if we have enough validity on our existing secret, use it
+ if now < self._cached_secret_replacement_time:
+ return self._cached_secret
+
+ issued_at = int(now)
+ expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
+
+ # we copy the configured header because jwt.encode modifies it.
+ header = dict(self._key.jwt_header)
+
+ # see https://tools.ietf.org/html/rfc7523#section-3
+ payload = {
+ "sub": self._oauth_client_id,
+ "aud": self._oauth_issuer,
+ "iat": issued_at,
+ "exp": expires_at,
+ **self._key.jwt_payload,
+ }
+ logger.info(
+ "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
+ )
+ self._cached_secret = jwt.encode(header, payload, self._key.key)
+ self._cached_secret_replacement_time = (
+ expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
+ )
+ return self._cached_secret
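As an illustrative aside (not part of the patch), the renewal rule above boils down to: reuse the cached secret until fewer than CLIENT_SECRET_MIN_VALIDITY_SECONDS of its lifetime remain, then mint a new one. A minimal self-contained sketch of that check, with made-up timestamps:

    CLIENT_SECRET_VALIDITY_SECONDS = 3600
    CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600

    def needs_new_secret(now: float, cached_secret_replacement_time: float) -> bool:
        # mirrors the `now < self._cached_secret_replacement_time` fast path above
        return now >= cached_secret_replacement_time

    issued_at = 1_000_000
    replacement_time = (
        issued_at + CLIENT_SECRET_VALIDITY_SECONDS - CLIENT_SECRET_MIN_VALIDITY_SECONDS
    )
    assert not needs_new_secret(issued_at + 100, replacement_time)   # plenty of validity: reuse
    assert needs_new_secret(issued_at + 3000, replacement_time)      # under 10 minutes left: regenerate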
+
+
+class OidcSessionTokenGenerator:
+ """Methods for generating and checking OIDC Session cookies."""
+
+ def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
+ self._server_name = hs.hostname
+ self._macaroon_secret_key = hs.config.key.macaroon_secret_key
+
+ def generate_oidc_session_token(
+ self,
+ state: str,
+ session_data: "OidcSessionData",
+ duration_in_ms: int = (60 * 60 * 1000),
+ ) -> str:
+ """Generates a signed token storing data about an OIDC session.
+
+ When Synapse initiates an authorization flow, it creates a random state
+ and a random nonce. Those parameters are given to the provider and
+ should be verified when the client comes back from the provider.
+ The token also stores the client_redirect_url, which is needed to
+ complete the SSO login flow.
+
+ Args:
+ state: The ``state`` parameter passed to the OIDC provider.
+ session_data: data to include in the session token.
+ duration_in_ms: An optional duration for the token in milliseconds.
+ Defaults to an hour.
+
+ Returns:
+ A signed macaroon token with the session information.
+ """
+ macaroon = pymacaroons.Macaroon(
+ location=self._server_name,
+ identifier="key",
+ key=self._macaroon_secret_key,
+ )
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = session")
+ macaroon.add_first_party_caveat("state = %s" % (state,))
+ macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
+ macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
+ macaroon.add_first_party_caveat(
+ "client_redirect_url = %s" % (session_data.client_redirect_url,)
+ )
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+ )
+ now = self._clock.time_msec()
+ expiry = now + duration_in_ms
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+ return macaroon.serialize()
+
+ def verify_oidc_session_token(
+ self, session: bytes, state: str
+ ) -> "OidcSessionData":
+ """Verifies and extract an OIDC session token.
+
+ This verifies that a given session token was issued by this homeserver
+ and extracts the nonce and client_redirect_url caveats.
+
+ Args:
+ session: The session token to verify
+ state: The state the OIDC provider gave back
+
+ Returns:
+ The data extracted from the session cookie
+
+ Raises:
+ KeyError if an expected caveat is missing from the macaroon.
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(session)
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = session")
+ v.satisfy_exact("state = %s" % (state,))
+ v.satisfy_general(lambda c: c.startswith("nonce = "))
+ v.satisfy_general(lambda c: c.startswith("idp_id = "))
+ v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+ v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+ satisfy_expiry(v, self._clock.time_msec)
+
+ v.verify(macaroon, self._macaroon_secret_key)
+
+ # Extract the session data from the token.
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ idp_id = get_value_from_macaroon(macaroon, "idp_id")
+ client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+ ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
+ return OidcSessionData(
+ nonce=nonce,
+ idp_id=idp_id,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=ui_auth_session_id,
)
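For illustration only (not part of the patch, and assuming pymacaroons is installed; the key and state value are made up), the caveat layout generated and verified above round-trips like this:

    import pymacaroons

    key = "macaroon-secret"  # hypothetical macaroon_secret_key
    m = pymacaroons.Macaroon(location="example.com", identifier="key", key=key)
    m.add_first_party_caveat("gen = 1")
    m.add_first_party_caveat("type = session")
    m.add_first_party_caveat("state = abc123")
    token = m.serialize()

    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("type = session")
    v.satisfy_exact("state = abc123")
    assert v.verify(pymacaroons.Macaroon.deserialize(token), key)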
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+ """The attributes which are stored in a OIDC session cookie"""
+
+ # the Identity Provider being used
+ idp_id = attr.ib(type=str)
+
+ # The `nonce` parameter passed to the OIDC provider.
+ nonce = attr.ib(type=str)
+
+ # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+ client_redirect_url = attr.ib(type=str)
+
+ # The session ID of the ongoing UI Auth ("" if this is a login)
+ ui_auth_session_id = attr.ib(type=str)
+
+
UserAttributeDict = TypedDict(
- "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+ "UserAttributeDict",
+ {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
)
C = TypeVar("C")
@@ -1014,12 +1285,13 @@ def jinja_finalize(thing):
env = Environment(finalize=jinja_finalize)
-@attr.s
+@attr.s(slots=True, frozen=True)
class JinjaOidcMappingConfig:
- subject_claim = attr.ib() # type: str
- localpart_template = attr.ib() # type: Template
- display_name_template = attr.ib() # type: Optional[Template]
- extra_attributes = attr.ib() # type: Dict[str, Template]
+ subject_claim = attr.ib(type=str)
+ localpart_template = attr.ib(type=Optional[Template])
+ display_name_template = attr.ib(type=Optional[Template])
+ email_template = attr.ib(type=Optional[Template])
+ extra_attributes = attr.ib(type=Dict[str, Template])
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,50 +1307,37 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub")
- if "localpart_template" not in config:
- raise ConfigError(
- "missing key: oidc_config.user_mapping_provider.config.localpart_template"
- )
-
- try:
- localpart_template = env.from_string(config["localpart_template"])
- except Exception as e:
- raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
- % (e,)
- )
-
- display_name_template = None # type: Optional[Template]
- if "display_name_template" in config:
+ def parse_template_config(option_name: str) -> Optional[Template]:
+ if option_name not in config:
+ return None
try:
- display_name_template = env.from_string(config["display_name_template"])
+ return env.from_string(config[option_name])
except Exception as e:
- raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
- % (e,)
- )
+ raise ConfigError("invalid jinja template", path=[option_name]) from e
+
+ localpart_template = parse_template_config("localpart_template")
+ display_name_template = parse_template_config("display_name_template")
+ email_template = parse_template_config("email_template")
extra_attributes = {} # type Dict[str, Template]
if "extra_attributes" in config:
extra_attributes_config = config.get("extra_attributes") or {}
if not isinstance(extra_attributes_config, dict):
- raise ConfigError(
- "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
- )
+ raise ConfigError("must be a dict", path=["extra_attributes"])
for key, value in extra_attributes_config.items():
try:
extra_attributes[key] = env.from_string(value)
except Exception as e:
raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
- % (key, e)
- )
+ "invalid jinja template", path=["extra_attributes", key]
+ ) from e
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
+ email_template=email_template,
extra_attributes=extra_attributes,
)
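To make the template handling above concrete, a sketch of a config dict accepted by parse_config (the key names come from the parsing code; the template strings and claim names are made-up examples):

    example_config = {
        "subject_claim": "sub",
        "localpart_template": "{{ user.preferred_username }}",
        "display_name_template": "{{ user.name }}",
        "email_template": "{{ user.email }}",
        "extra_attributes": {"phone": "{{ user.phone_number }}"},
    }
    # Each template string is compiled into a Jinja2 Template; any key that is
    # omitted simply leaves the corresponding field as None (or {} for
    # extra_attributes).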
@@ -1088,25 +1347,35 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
- localpart = self._config.localpart_template.render(user=userinfo).strip()
+ localpart = None
- # Ensure only valid characters are included in the MXID.
- localpart = map_username_to_mxid_localpart(localpart)
+ if self._config.localpart_template:
+ localpart = self._config.localpart_template.render(user=userinfo).strip()
- # Append suffix integer if last call to this function failed to produce
- # a usable mxid.
- localpart += str(failures) if failures else ""
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
- display_name = None # type: Optional[str]
- if self._config.display_name_template is not None:
- display_name = self._config.display_name_template.render(
- user=userinfo
- ).strip()
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
- if display_name == "":
- display_name = None
+ def render_template_field(template: Optional[Template]) -> Optional[str]:
+ if template is None:
+ return None
+ return template.render(user=userinfo).strip()
- return UserAttributeDict(localpart=localpart, display_name=display_name)
+ display_name = render_template_field(self._config.display_name_template)
+ if display_name == "":
+ display_name = None
+
+ emails = [] # type: List[str]
+ email = render_template_field(self._config.email_template)
+ if email:
+ emails.append(email)
+
+ return UserAttributeDict(
+ localpart=localpart, display_name=display_name, emails=emails
+ )
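A minimal sketch of the collision-handling rule above: when a previous mapping attempt failed to produce a usable mxid, an integer suffix is appended to the rendered localpart.

    def candidate_localpart(base: str, failures: int) -> str:
        # failures == 0 is the first attempt; later attempts get a numeric suffix
        return base + (str(failures) if failures else "")

    assert candidate_localpart("alice", 0) == "alice"
    assert candidate_localpart("alice", 2) == "alice2"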
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5372753707..66dc886c81 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -197,7 +197,8 @@ class PaginationHandler:
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
r = await self.store.get_room_event_before_stream_ordering(
- room_id, stream_ordering,
+ room_id,
+ stream_ordering,
)
if not r:
logger.warning(
@@ -223,7 +224,12 @@ class PaginationHandler:
# the background so that it's not blocking any other operation apart from
# other purges in the same room.
run_as_background_process(
- "_purge_history", self._purge_history, purge_id, room_id, token, True,
+ "_purge_history",
+ self._purge_history,
+ purge_id,
+ room_id,
+ token,
+ True,
)
def start_purge_history(
@@ -279,7 +285,7 @@ class PaginationHandler:
except Exception:
f = Failure()
logger.error(
- "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
+ "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
@@ -389,7 +395,9 @@ class PaginationHandler:
)
await self.hs.get_federation_handler().maybe_backfill(
- room_id, curr_topo, limit=pagin_config.limit,
+ room_id,
+ curr_topo,
+ limit=pagin_config.limit,
)
to_room_key = None
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index 6c635cc31b..92cefa11aa 100644
--- a/synapse/handlers/password_policy.py
+++ b/synapse/handlers/password_policy.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from synapse.api.errors import Codes, PasswordRefusedError
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 22d1e9d35c..da92feacc9 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -104,6 +104,8 @@ class BasePresenceHandler(abc.ABC):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
+
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -274,22 +276,25 @@ class PresenceHandler(BasePresenceHandler):
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
- # Start a LoopingCall in 30s that fires every 5s.
- # The initial delay is to allow disconnected clients a chance to
- # reconnect before we treat them as offline.
- def run_timeout_handler():
- return run_as_background_process(
- "handle_presence_timeouts", self._handle_timeouts
- )
-
- self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)
+ if self._presence_enabled:
+ # Start a LoopingCall in 30s that fires every 5s.
+ # The initial delay is to allow disconnected clients a chance to
+ # reconnect before we treat them as offline.
+ def run_timeout_handler():
+ return run_as_background_process(
+ "handle_presence_timeouts", self._handle_timeouts
+ )
- def run_persister():
- return run_as_background_process(
- "persist_presence_changes", self._persist_unpersisted_changes
+ self.clock.call_later(
+ 30, self.clock.looping_call, run_timeout_handler, 5000
)
- self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
+ def run_persister():
+ return run_as_background_process(
+ "persist_presence_changes", self._persist_unpersisted_changes
+ )
+
+ self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
LaterGauge(
"synapse_handlers_presence_wheel_timer_size",
@@ -299,7 +304,7 @@ class PresenceHandler(BasePresenceHandler):
)
# Used to handle sending of presence to newly joined users/servers
- if hs.config.use_presence:
+ if self._presence_enabled:
self.notifier.add_replication_callback(self.notify_new_event)
# Presence is best effort and quickly heals itself, so lets just always
@@ -349,10 +354,13 @@ class PresenceHandler(BasePresenceHandler):
[self.user_to_current_state[user_id] for user_id in unpersisted]
)
- async def _update_states(self, new_states):
+ async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
should be sent to clients/servers.
+
+ Args:
+ new_states: The new user presence state updates to process.
"""
now = self.clock.time_msec()
@@ -368,7 +376,7 @@ class PresenceHandler(BasePresenceHandler):
new_states_dict = {}
for new_state in new_states:
new_states_dict[new_state.user_id] = new_state
- new_state = new_states_dict.values()
+ new_states = new_states_dict.values()
for new_state in new_states:
user_id = new_state.user_id
@@ -635,8 +643,7 @@ class PresenceHandler(BasePresenceHandler):
self.external_process_last_updated_ms.pop(process_id, None)
async def current_state_for_user(self, user_id):
- """Get the current presence state for a user.
- """
+ """Get the current presence state for a user."""
res = await self.current_state_for_users([user_id])
return res[user_id]
@@ -658,17 +665,6 @@ class PresenceHandler(BasePresenceHandler):
self._push_to_remotes(states)
- async def notify_for_states(self, state, stream_id):
- parties = await get_interested_parties(self.store, [state])
- room_ids_to_states, users_to_states = parties
-
- self.notifier.on_new_event(
- "presence_key",
- stream_id,
- rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states],
- )
-
def _push_to_remotes(self, states):
"""Sends state updates to remote servers.
@@ -678,8 +674,7 @@ class PresenceHandler(BasePresenceHandler):
self.federation.send_presence(states)
async def incoming_presence(self, origin, content):
- """Called when we receive a `m.presence` EDU from a remote server.
- """
+ """Called when we receive a `m.presence` EDU from a remote server."""
if not self._presence_enabled:
return
@@ -729,8 +724,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states(updates)
async def set_state(self, target_user, state, ignore_status_msg=False):
- """Set the presence state of the user.
- """
+ """Set the presence state of the user."""
status_msg = state.get("status_msg", None)
presence = state["presence"]
@@ -738,8 +732,12 @@ class PresenceHandler(BasePresenceHandler):
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
+ PresenceState.BUSY,
)
- if presence not in valid_presence:
+
+ if presence not in valid_presence or (
+ presence == PresenceState.BUSY and not self._busy_presence_enabled
+ ):
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
@@ -752,14 +750,15 @@ class PresenceHandler(BasePresenceHandler):
msg = status_msg if presence != PresenceState.OFFLINE else None
new_fields["status_msg"] = msg
- if presence == PresenceState.ONLINE:
+ if presence == PresenceState.ONLINE or (
+ presence == PresenceState.BUSY and self._busy_presence_enabled
+ ):
new_fields["last_active_ts"] = self.clock.time_msec()
await self._update_states([prev_state.copy_and_replace(**new_fields)])
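A self-contained sketch of the validation rule introduced above (the presence-state strings here are placeholders, not the exact PresenceState constants): "busy" is only accepted when the experimental MSC3026 flag is enabled.

    VALID_PRESENCE = {"online", "unavailable", "offline", "busy"}

    def presence_is_acceptable(presence: str, busy_presence_enabled: bool) -> bool:
        if presence not in VALID_PRESENCE:
            return False
        return not (presence == "busy" and not busy_presence_enabled)

    assert presence_is_acceptable("busy", busy_presence_enabled=True)
    assert not presence_is_acceptable("busy", busy_presence_enabled=False)
    assert presence_is_acceptable("online", busy_presence_enabled=False)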
async def is_visible(self, observed_user, observer_user):
- """Returns whether a user can see another user's presence.
- """
+ """Returns whether a user can see another user's presence."""
observer_room_ids = await self.store.get_rooms_for_user(
observer_user.to_string()
)
@@ -861,6 +860,9 @@ class PresenceHandler(BasePresenceHandler):
"""Process current state deltas to find new joins that need to be
handled.
"""
+ # A map of destination to the set of user presence states it should receive
+ presence_destinations = {} # type: Dict[str, Set[UserPresenceState]]
+
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
@@ -870,6 +872,7 @@ class PresenceHandler(BasePresenceHandler):
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+ # Drop any event that isn't a membership join
if typ != EventTypes.Member:
continue
@@ -892,13 +895,38 @@ class PresenceHandler(BasePresenceHandler):
# Ignore changes to join events.
continue
- await self._on_user_joined_room(room_id, state_key)
+ # Retrieve any user presence state updates that need to be sent as a result,
+ # and the destinations that need to receive it
+ destinations, user_presence_states = await self._on_user_joined_room(
+ room_id, state_key
+ )
+
+ # Insert the destinations and respective updates into our destinations dict
+ for destination in destinations:
+ presence_destinations.setdefault(destination, set()).update(
+ user_presence_states
+ )
- async def _on_user_joined_room(self, room_id: str, user_id: str) -> None:
+ # Send out user presence updates for each destination
+ for destination, user_state_set in presence_destinations.items():
+ self.federation.send_presence_to_destinations(
+ destinations=[destination], states=user_state_set
+ )
+
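For illustration (a standalone sketch of the batching pattern above, with strings standing in for UserPresenceState objects): each destination accumulates the union of the presence updates produced by the joins it is interested in, and is then sent exactly one update.

    from typing import Dict, List, Set, Tuple

    def batch_presence(joins: List[Tuple[List[str], List[str]]]) -> Dict[str, Set[str]]:
        presence_destinations = {}  # type: Dict[str, Set[str]]
        for destinations, states in joins:
            for destination in destinations:
                presence_destinations.setdefault(destination, set()).update(states)
        return presence_destinations

    assert batch_presence([(["a.org"], ["@u:x"]), (["a.org", "b.org"], ["@v:x"])]) == {
        "a.org": {"@u:x", "@v:x"},
        "b.org": {"@v:x"},
    }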
+ async def _on_user_joined_room(
+ self, room_id: str, user_id: str
+ ) -> Tuple[List[str], List[UserPresenceState]]:
"""Called when we detect a user joining the room via the current state
- delta stream.
- """
+ delta stream. Returns the destinations that need to be updated and the
+ presence updates to send to them.
+
+ Args:
+ room_id: The ID of the room that the user has joined.
+ user_id: The ID of the user that has joined the room.
+ Returns:
+ A tuple of destinations and presence updates to send to them.
+ """
if self.is_mine_id(user_id):
# If this is a local user then we need to send their presence
# out to hosts in the room (who don't already have it)
@@ -906,15 +934,15 @@ class PresenceHandler(BasePresenceHandler):
# TODO: We should be able to filter the hosts down to those that
# haven't previously seen the user
- state = await self.current_state_for_user(user_id)
- hosts = await self.state.get_current_hosts_in_room(room_id)
+ remote_hosts = await self.state.get_current_hosts_in_room(room_id)
# Filter out ourselves.
- hosts = {host for host in hosts if host != self.server_name}
+ filtered_remote_hosts = [
+ host for host in remote_hosts if host != self.server_name
+ ]
- self.federation.send_presence_to_destinations(
- states=[state], destinations=hosts
- )
+ state = await self.current_state_for_user(user_id)
+ return filtered_remote_hosts, [state]
else:
# A remote user has joined the room, so we need to:
# 1. Check if this is a new server in the room
@@ -927,6 +955,8 @@ class PresenceHandler(BasePresenceHandler):
# TODO: Check that this is actually a new server joining the
# room.
+ remote_host = get_domain_from_id(user_id)
+
users = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, users))
@@ -946,15 +976,11 @@ class PresenceHandler(BasePresenceHandler):
or state.status_msg is not None
]
- if states:
- self.federation.send_presence_to_destinations(
- states=states, destinations=[get_domain_from_id(user_id)]
- )
+ return [remote_host], states
def should_notify(old_state, new_state):
- """Decides if a presence state change should be sent to interested parties.
- """
+ """Decides if a presence state change should be sent to interested parties."""
if old_state == new_state:
return False
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 40b1c25a36..ee8f868791 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -45,7 +45,7 @@ from synapse.types import (
from ._base import BaseHandler
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -87,8 +87,8 @@ class ProfileHandler(BaseHandler):
)
if len(self.hs.config.replicate_user_profiles_to) > 0:
- reactor.callWhenRunning(self._do_assign_profile_replication_batches)
- reactor.callWhenRunning(self._start_replicate_profiles)
+ reactor.callWhenRunning(self._do_assign_profile_replication_batches) # type: ignore
+ reactor.callWhenRunning(self._start_replicate_profiles) # type: ignore
# Add a looping call to replicate_profiles: this handles retries
# if the replication is unsuccessful when the user updated their
# profile.
@@ -254,7 +254,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- return result["displayname"]
+ return result.get("displayname")
async def set_displayname(
self,
@@ -305,7 +305,8 @@ class ProfileHandler(BaseHandler):
# This must be done by the target user himself.
if by_admin:
requester = create_requester(
- target_user, authenticated_entity=requester.authenticated_entity,
+ target_user,
+ authenticated_entity=requester.authenticated_entity,
)
if len(self.hs.config.replicate_user_profiles_to) > 0:
@@ -332,7 +333,10 @@ class ProfileHandler(BaseHandler):
run_in_background(self._replicate_profiles)
async def set_active(
- self, users: List[UserID], active: bool, hide: bool,
+ self,
+ users: List[UserID],
+ active: bool,
+ hide: bool,
):
"""
Sets the 'active' flag on a set of user profiles. If set to false, the
@@ -392,7 +396,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- return result["avatar_url"]
+ return result.get("avatar_url")
async def set_avatar_url(
self,
@@ -432,9 +436,17 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
+ avatar_url_to_set = new_avatar_url # type: Optional[str]
+ if new_avatar_url == "":
+ avatar_url_to_set = None
+
# Enforce a max avatar size if one is defined
- if self.max_avatar_size or self.allowed_avatar_mimetypes:
- media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url)
+ if avatar_url_to_set and (
+ self.max_avatar_size or self.allowed_avatar_mimetypes
+ ):
+ media_id = self._validate_and_parse_media_id_from_avatar_url(
+ avatar_url_to_set
+ )
# Check that this media exists locally
media_info = await self.store.get_local_media(media_id)
@@ -477,7 +489,7 @@ class ProfileHandler(BaseHandler):
new_batchnum = None
await self.store.set_profile_avatar_url(
- target_user.localpart, new_avatar_url, new_batchnum
+ target_user.localpart, avatar_url_to_set, new_batchnum
)
if self.hs.config.user_directory_search_all_users:
@@ -506,6 +518,15 @@ class ProfileHandler(BaseHandler):
return avatar_pieces[-1]
async def on_profile_query(self, args: JsonDict) -> JsonDict:
+ """Handles federation profile query requests."""
+
+ if not self.hs.config.allow_profile_lookup_over_federation:
+ raise SynapseError(
+ 403,
+ "Profile lookup over federation is disabled on this homeserver",
+ Codes.FORBIDDEN,
+ )
+
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index a7550806e6..a54fe1968e 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -21,7 +21,7 @@ from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
+ self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")
- self.notifier = hs.get_notifier()
async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
if should_update:
content = {"event_id": event_id}
- max_id = await self.store.add_account_data_to_room(
+ await self.account_data_handler.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 153cbae7b9..dbfe9bfaca 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -13,33 +13,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
-from synapse.util.async_helpers import maybe_awaitable
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
- self.federation = hs.get_federation_sender()
- hs.get_federation_registry().register_edu_handler(
- "m.receipt", self._received_remote_receipt
- )
+
+ # We only need to poke the federation sender explicitly if it's on the
+ # same instance. Other federation sender instances will get notified by
+ # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+ # in the receipts stream.
+ self.federation_sender = None
+ if hs.should_send_federation():
+ self.federation_sender = hs.get_federation_sender()
+
+ # If we can handle the receipt EDUs we do so, otherwise we route them
+ # to the appropriate worker.
+ if hs.get_instance_name() in hs.config.worker.writers.receipts:
+ hs.get_federation_registry().register_edu_handler(
+ "m.receipt", self._received_remote_receipt
+ )
+ else:
+ hs.get_federation_registry().register_instances_for_edu(
+ "m.receipt",
+ hs.config.worker.writers.receipts,
+ )
+
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- async def _received_remote_receipt(self, origin, content):
- """Called when we receive an EDU of type m.receipt from a remote HS.
- """
+ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
+ """Called when we receive an EDU of type m.receipt from a remote HS."""
receipts = []
for room_id, room_values in content.items():
for receipt_type, users in room_values.items():
@@ -64,11 +82,10 @@ class ReceiptsHandler(BaseHandler):
await self._handle_new_receipts(receipts)
- async def _handle_new_receipts(self, receipts):
- """Takes a list of receipts, stores them and informs the notifier.
- """
- min_batch_id = None
- max_batch_id = None
+ async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
+ """Takes a list of receipts, stores them and informs the notifier."""
+ min_batch_id = None # type: Optional[int]
+ max_batch_id = None # type: Optional[int]
for receipt in receipts:
res = await self.store.insert_receipt(
@@ -90,7 +107,8 @@ class ReceiptsHandler(BaseHandler):
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id
- if min_batch_id is None:
+ # Either both of these should be None or neither.
+ if min_batch_id is None or max_batch_id is None:
# no new receipts
return False
@@ -98,15 +116,15 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- await maybe_awaitable(
- self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
- )
+ await self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
)
return True
- async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+ async def received_client_receipt(
+ self, room_id: str, receipt_type: str, user_id: str, event_id: str
+ ) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -122,14 +140,17 @@ class ReceiptsHandler(BaseHandler):
if not is_new:
return
- await self.federation.send_read_receipt(receipt)
+ if self.federation_sender:
+ await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(
+ self, from_key: int, room_ids: List[str], **kwargs
+ ) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()
@@ -174,5 +195,5 @@ class ReceiptEventSource:
return (events, to_key)
- def get_current_key(self, direction="f"):
+ def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c227c4fe91..1b2a515ee9 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -14,8 +14,11 @@
# limitations under the License.
"""Contains functions for registering clients."""
+
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+
+from prometheus_client import Counter
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -35,11 +38,24 @@ from synapse.types import RoomAlias, UserID, create_requester
from ._base import BaseHandler
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
+registration_counter = Counter(
+ "synapse_user_registrations_total",
+ "Number of new users registered (since restart)",
+ ["guest", "shadow_banned", "auth_provider"],
+)
+
+login_counter = Counter(
+ "synapse_user_logins_total",
+ "Number of user logins (since restart)",
+ ["guest", "auth_provider"],
+)
+
+
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -64,11 +80,12 @@ class RegistrationHandler(BaseHandler):
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
hs
)
- self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
- hs
+ self._post_registration_client = (
+ ReplicationPostRegisterActionsServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
+ self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
@@ -167,9 +184,10 @@ class RegistrationHandler(BaseHandler):
user_type: Optional[str] = None,
default_display_name: Optional[str] = None,
address: Optional[str] = None,
- bind_emails: List[str] = [],
+ bind_emails: Iterable[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+ auth_provider_id: Optional[str] = None,
) -> str:
"""Registers a new client on the server.
@@ -195,20 +213,25 @@ class RegistrationHandler(BaseHandler):
admin api, otherwise False.
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
+ auth_provider_id: The SSO IdP the user used, if any.
Returns:
- The registere user_id.
+ The registered user_id.
Raises:
SynapseError if there was a problem registering.
"""
self.check_registration_ratelimit(address)
- result = self.spam_checker.check_registration_for_spam(
- threepid, localpart, user_agent_ips or [],
+ result = await self.spam_checker.check_registration_for_spam(
+ threepid,
+ localpart,
+ user_agent_ips or [],
+ auth_provider_id=auth_provider_id,
)
if result == RegistrationBehaviour.DENY:
logger.info(
- "Blocked registration of %r", localpart,
+ "Blocked registration of %r",
+ localpart,
)
# We return a 429 to make it not obvious that they've been
# denied.
@@ -217,7 +240,8 @@ class RegistrationHandler(BaseHandler):
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
if shadow_banned:
logger.info(
- "Shadow banning registration of %r", localpart,
+ "Shadow banning registration of %r",
+ localpart,
)
# do not check_auth_blocking if the call is coming through the Admin API
@@ -266,8 +290,6 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
- # If a default display name is not given, generate one.
- generate_display_name = default_display_name is None
# This breaks on successful registration *or* errors after 10 failures.
while True:
# Fail after being unable to find a suitable ID a few times
@@ -278,7 +300,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
- if generate_display_name:
+ if default_display_name is None:
default_display_name = localpart
try:
await self.register_with_store(
@@ -301,6 +323,12 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
fail_count += 1
+ registration_counter.labels(
+ guest=make_guest,
+ shadow_banned=shadow_banned,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
if not self.hs.config.user_consent_at_registration:
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
logger.info(
@@ -402,7 +430,9 @@ class RegistrationHandler(BaseHandler):
config["room_alias_name"] = room_alias.localpart
info, _ = await room_creation_handler.create_room(
- fake_requester, config=config, ratelimit=False,
+ fake_requester,
+ config=config,
+ ratelimit=False,
)
# If the room does not require an invite, but another user
@@ -439,10 +469,10 @@ class RegistrationHandler(BaseHandler):
if RoomAlias.is_valid(r):
(
- room_id,
+ room,
remote_room_hosts,
) = await room_member_handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
+ room_id = room.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (r,)
@@ -712,6 +742,8 @@ class RegistrationHandler(BaseHandler):
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
+ is_appservice_ghost: bool = False,
+ auth_provider_id: Optional[str] = None,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@@ -722,20 +754,40 @@ class RegistrationHandler(BaseHandler):
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
-
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ Prometheus metrics).
Returns:
Tuple of device ID and access token
"""
+ res = await self._register_device_client(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ is_guest=is_guest,
+ is_appservice_ghost=is_appservice_ghost,
+ )
- if self.hs.config.worker_app:
- r = await self._register_device_client(
- user_id=user_id,
- device_id=device_id,
- initial_display_name=initial_display_name,
- is_guest=is_guest,
- )
- return r["device_id"], r["access_token"]
+ login_counter.labels(
+ guest=is_guest,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
+ return res["device_id"], res["access_token"]
+
+ async def register_device_inner(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ is_appservice_ghost: bool = False,
+ ) -> Dict[str, str]:
+ """Helper for register_device
+ Does the bits that need doing on the main process. Not for use outside this
+ class and RegisterDeviceReplicationServlet.
+ """
+ assert not self.hs.config.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@@ -754,10 +806,13 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
+ user_id,
+ device_id=registered_device_id,
+ valid_until_ms=valid_until_ms,
+ is_appservice_ghost=is_appservice_ghost,
)
- return (registered_device_id, access_token)
+ return {"device_id": registered_device_id, "access_token": access_token}
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
@@ -770,6 +825,8 @@ class RegistrationHandler(BaseHandler):
access_token: The access token of the newly logged in device, or
None if `inhibit_login` enabled.
"""
+ # TODO: 3pid registration can actually happen on the workers. Consider
+ # refactoring it.
if self.hs.config.worker_app:
await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
@@ -853,7 +910,10 @@ class RegistrationHandler(BaseHandler):
return
await self._auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
)
# And we add an email pusher for them by default, but only
@@ -905,5 +965,8 @@ class RegistrationHandler(BaseHandler):
raise
await self._auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 469f6cf10f..f3da38a71e 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
+ HistoryVisibility,
JoinRules,
Membership,
RoomCreationPreset,
@@ -37,7 +38,7 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
-from synapse.http.endpoint import parse_and_validate_server_name
+from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
@@ -54,6 +55,7 @@ from synapse.types import (
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -81,21 +83,21 @@ class RoomCreationHandler(BaseHandler):
self._presets_dict = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": False,
"guest_can_join": True,
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": True,
"guest_can_join": True,
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": False,
"guest_can_join": False,
"power_level_content_override": {},
@@ -119,12 +121,16 @@ class RoomCreationHandler(BaseHandler):
# succession, only process the first attempt and return its result to
# subsequent requests
self._upgrade_response_cache = ResponseCache(
- hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+ hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules()
+ self._invite_burst_count = (
+ hs.config.ratelimiting.rc_invites_per_room.burst_count
+ )
+
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
) -> str:
@@ -192,7 +198,9 @@ class RoomCreationHandler(BaseHandler):
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = await self._generate_room_id(
- creator_id=user_id, is_public=r["is_public"], room_version=new_version,
+ creator_id=user_id,
+ is_public=r["is_public"],
+ room_version=new_version,
)
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
@@ -230,7 +238,9 @@ class RoomCreationHandler(BaseHandler):
# now send the tombstone
await self.event_creation_handler.handle_new_client_event(
- requester=requester, event=tombstone_event, context=tombstone_context,
+ requester=requester,
+ event=tombstone_event,
+ context=tombstone_context,
)
old_room_state = await tombstone_context.get_current_state_ids()
@@ -251,7 +261,10 @@ class RoomCreationHandler(BaseHandler):
# finally, shut down the PLs in the old room, and update them in the new
# room.
await self._update_upgraded_room_pls(
- requester, old_room_id, new_room_id, old_room_state,
+ requester,
+ old_room_id,
+ new_room_id,
+ old_room_state,
)
return new_room_id
@@ -368,7 +381,7 @@ class RoomCreationHandler(BaseHandler):
else:
is_requester_admin = await self.auth.is_server_admin(requester.user)
- if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id, invite_list=[], third_party_invite_list=[], cloning=True
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -376,7 +389,7 @@ class RoomCreationHandler(BaseHandler):
creation_content = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
- }
+ } # type: JsonDict
# Check if old room was non-federatable
@@ -431,17 +444,20 @@ class RoomCreationHandler(BaseHandler):
# Copy over user power levels now as this will not be possible with >100PL users once
# the room has been created
-
# Calculate the minimum power level needed to clone the room
event_power_levels = power_levels.get("events", {})
- state_default = power_levels.get("state_default", 0)
- ban = power_levels.get("ban")
+ state_default = power_levels.get("state_default", 50)
+ ban = power_levels.get("ban", 50)
needed_power_level = max(state_default, ban, max(event_power_levels.values()))
+ # Get the user's current power level, this matches the logic in get_user_power_level,
+ # but without the entire state map.
+ user_power_levels = power_levels.setdefault("users", {})
+ users_default = power_levels.get("users_default", 0)
+ current_power_level = user_power_levels.get(user_id, users_default)
# Raise the requester's power level in the new room if necessary
- current_power_level = power_levels["users"][user_id]
if current_power_level < needed_power_level:
- power_levels["users"][user_id] = needed_power_level
+ user_power_levels[user_id] = needed_power_level
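To make the arithmetic above concrete, a small worked sketch (the values are hypothetical; the dict mirrors an m.room.power_levels content block) using the 50/50 defaults and the users_default fallback:

    power_levels = {
        "events": {"m.room.name": 100},
        "state_default": 50,
        "users_default": 0,
        "users": {},  # requester not listed explicitly
    }
    needed_power_level = max(
        power_levels.get("state_default", 50),
        power_levels.get("ban", 50),
        max(power_levels["events"].values()),
    )  # == 100
    user_power_levels = power_levels.setdefault("users", {})
    current_power_level = user_power_levels.get(
        "@alice:example.com", power_levels.get("users_default", 0)
    )
    if current_power_level < needed_power_level:
        user_power_levels["@alice:example.com"] = needed_power_level  # raised 0 -> 100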
await self._send_events_for_new_room(
requester,
@@ -452,6 +468,7 @@ class RoomCreationHandler(BaseHandler):
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
+ ratelimit=False,
)
# Transfer membership events
@@ -572,7 +589,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[dict, int]:
- """ Creates a new room.
+ """Creates a new room.
Args:
requester:
@@ -623,7 +640,7 @@ class RoomCreationHandler(BaseHandler):
invite_list = config.get("invite", [])
invite_3pid_list = config.get("invite_3pid", [])
- if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id,
invite_list=invite_list,
third_party_invite_list=invite_3pid_list,
@@ -678,6 +695,9 @@ class RoomCreationHandler(BaseHandler):
invite_3pid_list = []
invite_list = []
+ if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+ raise SynapseError(400, "Cannot invite so many users at once")
+
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
@@ -696,7 +716,9 @@ class RoomCreationHandler(BaseHandler):
is_public = visibility == "public"
room_id = await self._generate_room_id(
- creator_id=user_id, is_public=is_public, room_version=room_version,
+ creator_id=user_id,
+ is_public=is_public,
+ room_version=room_version,
)
# Check whether this visibility value is blocked by a third party module
@@ -753,6 +775,7 @@ class RoomCreationHandler(BaseHandler):
room_alias=room_alias,
power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
+ ratelimit=ratelimit,
)
if "name" in config:
@@ -838,7 +861,7 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
- # Always wait for room creation to progate before returning
+ # Always wait for room creation to propagate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
@@ -858,6 +881,7 @@ class RoomCreationHandler(BaseHandler):
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
+ ratelimit: bool = True,
) -> int:
"""Sends the initial events into a new room.
@@ -889,7 +913,10 @@ class RoomCreationHandler(BaseHandler):
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- creator, event, ratelimit=False, ignore_shadow_ban=True,
+ creator,
+ event,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
return last_stream_id
@@ -904,7 +931,7 @@ class RoomCreationHandler(BaseHandler):
creator.user,
room_id,
"join",
- ratelimit=False,
+ ratelimit=ratelimit,
content=creator_join_profile,
new_room=True,
)
@@ -990,7 +1017,10 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
- self, creator_id: str, is_public: bool, room_version: RoomVersion,
+ self,
+ creator_id: str,
+ is_public: bool,
+ room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -1014,41 +1044,51 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
+ self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
async def get_event_context(
self,
- user: UserID,
+ requester: Requester,
room_id: str,
event_id: str,
limit: int,
event_filter: Optional[Filter],
+ use_admin_priviledge: bool = False,
) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
- user
+ requester
room_id
event_id
limit: The maximum number of events to return in total
(excluding state).
event_filter: the filter to apply to the events returned
(excluding the target event_id)
-
+ use_admin_priviledge: if `True`, return all events, regardless
+ of whether `user` has access to them. To be used **ONLY**
+ from the admin API.
Returns:
dict, or None if the event isn't found
"""
+ user = requester.user
+ if use_admin_priviledge:
+ await assert_user_is_admin(self.auth, requester.user)
+
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
- def filter_evts(events):
- return filter_events_for_client(
+ async def filter_evts(events):
+ if use_admin_priviledge:
+ return events
+ return await filter_events_for_client(
self.storage, user.to_string(), events, is_peeking=is_peeking
)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index ddb8e44e21..01fade718e 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,19 +15,22 @@
import logging
from collections import namedtuple
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.api.errors import Codes, HttpResponseException
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
@@ -37,37 +40,38 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
- self.response_cache = ResponseCache(hs, "room_list")
+ self.response_cache = ResponseCache(
+ hs.get_clock(), "room_list"
+ ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
- hs, "remote_room_list", timeout_ms=30 * 1000
- )
+ hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
+ ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(
self,
- limit=None,
- since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False,
- ):
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ from_federation: bool = False,
+ ) -> JsonDict:
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
party network. A client can ask for a specific list or to return all.
Args:
- limit (int|None)
- since_token (str|None)
- search_filter (dict|None)
- network_tuple (ThirdPartyInstanceID): Which public list to use.
+ limit
+ since_token
+ search_filter
+ network_tuple: Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
- from_federation (bool): true iff the request comes from the federation
- API
+ from_federation: true iff the request comes from the federation API
"""
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -107,10 +111,10 @@ class RoomListHandler(BaseHandler):
self,
limit: Optional[int] = None,
since_token: Optional[str] = None,
- search_filter: Optional[Dict] = None,
+ search_filter: Optional[dict] = None,
network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
- ) -> Dict[str, Any]:
+ ) -> JsonDict:
"""Generate a public room list.
Args:
limit: Maximum amount of rooms to return.
@@ -131,13 +135,17 @@ class RoomListHandler(BaseHandler):
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
- bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+ bounds = (
+ batch_token.last_joined_members,
+ batch_token.last_room_id,
+ ) # type: Optional[Tuple[int, str]]
forwards = batch_token.direction_is_forward
+ has_batch_token = True
else:
- batch_token = None
bounds = None
forwards = True
+ has_batch_token = False
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
@@ -159,7 +167,8 @@ class RoomListHandler(BaseHandler):
"canonical_alias": room["canonical_alias"],
"num_joined_members": room["joined_members"],
"avatar_url": room["avatar"],
- "world_readable": room["history_visibility"] == "world_readable",
+ "world_readable": room["history_visibility"]
+ == HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
"join_rule": room["join_rules"],
}
@@ -169,7 +178,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results]
- response = {}
+ response = {} # type: JsonDict
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
@@ -187,7 +196,7 @@ class RoomListHandler(BaseHandler):
initial_entry = results[0]
if forwards:
- if batch_token:
+ if has_batch_token:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
@@ -203,7 +212,7 @@ class RoomListHandler(BaseHandler):
direction_is_forward=True,
).to_token()
else:
- if batch_token:
+ if has_batch_token:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
@@ -293,7 +302,7 @@ class RoomListHandler(BaseHandler):
return None
# Return whether this room is open to federation users or not
- create_event = current_state.get((EventTypes.Create, ""))
+ create_event = current_state[EventTypes.Create, ""]
result["m.federate"] = create_event.content.get("m.federate", True)
name_event = current_state.get((EventTypes.Name, ""))
@@ -318,7 +327,7 @@ class RoomListHandler(BaseHandler):
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
- result["world_readable"] = visibility == "world_readable"
+ result["world_readable"] = visibility == HistoryVisibility.WORLD_READABLE
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
@@ -336,13 +345,13 @@ class RoomListHandler(BaseHandler):
async def get_remote_public_room_list(
self,
- server_name,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
- ):
+ server_name: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
+ ) -> JsonDict:
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -399,13 +408,13 @@ class RoomListHandler(BaseHandler):
async def _get_remote_list_cached(
self,
- server_name,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
- ):
+ server_name: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
+ ) -> JsonDict:
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
@@ -456,24 +465,24 @@ class RoomListNextBatch(
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod
- def from_token(cls, token):
+ def from_token(cls, token: str) -> "RoomListNextBatch":
decoded = msgpack.loads(decode_base64(token), raw=False)
return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
)
- def to_token(self):
+ def to_token(self) -> str:
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in self._asdict().items()}
)
)
- def copy_and_replace(self, **kwds):
+ def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
return self._replace(**kwds)
-def _matches_room_entry(room_entry, search_filter):
+def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper():
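As an aside on the hunks above: RoomListNextBatch serialises its fields with msgpack and wraps them in unpadded base64, using shortened wire keys. A minimal sketch of that round trip, assuming the msgpack and unpaddedbase64 packages this code already imports; the Batch tuple and key map here are illustrative stand-ins, not the real class:

from collections import namedtuple

import msgpack
from unpaddedbase64 import decode_base64, encode_base64

Batch = namedtuple(
    "Batch", ["last_joined_members", "last_room_id", "direction_is_forward"]
)

# shortened wire keys, mirroring the KEY_DICT / REVERSE_KEY_DICT idea above
KEY_DICT = {"last_joined_members": "m", "last_room_id": "r", "direction_is_forward": "d"}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}

def to_token(batch: Batch) -> str:
    # pack the fields under their short keys, then base64-encode without padding
    return encode_base64(
        msgpack.dumps({KEY_DICT[k]: v for k, v in batch._asdict().items()})
    )

def from_token(token: str) -> Batch:
    decoded = msgpack.loads(decode_base64(token), raw=False)
    return Batch(**{REVERSE_KEY_DICT[k]: v for k, v in decoded.items()})

batch = Batch(last_joined_members=42, last_room_id="!abc:example.org", direction_is_forward=True)
assert from_token(to_token(batch)) == batch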
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 55deff987c..10af3782f4 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -71,6 +71,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.account_data_handler = hs.get_account_data_handler()
self.member_linearizer = Linearizer(name="member")
@@ -92,6 +93,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
+ self._invites_per_room_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+ )
+ self._invites_per_user_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+ )
+
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
@@ -119,7 +131,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def remote_knock(
- self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
) -> Tuple[str, int]:
"""Try and knock on a room that this server is not in
@@ -186,6 +202,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
+ @abc.abstractmethod
+ async def forget(self, user: UserID, room_id: str) -> None:
+ raise NotImplementedError()
+
+ def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
+ """Ratelimit invites by room and by target user.
+
+ If room ID is missing then we just rate limit by target user.
+ """
+ if room_id:
+ self._invites_per_room_limiter.ratelimit(room_id)
+
+ self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
async def _local_membership_update(
self,
requester: Requester,
@@ -212,7 +242,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# do it up front for efficiency.)
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
- room_id, requester.user.to_string(), requester.access_token_id, txn_id,
+ room_id,
+ requester.user.to_string(),
+ requester.access_token_id,
+ txn_id,
)
if existing_event_id:
event_pos = await self.store.get_position_for_event(existing_event_id)
@@ -246,7 +279,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
- if newly_joined:
+ if newly_joined and ratelimit:
time_now_s = self.clock.time()
(
allowed,
@@ -259,7 +292,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
result_event = await self.event_creation_handler.handle_new_client_event(
- requester, event, context, extra_users=[target], ratelimit=ratelimit,
+ requester,
+ event,
+ context,
+ extra_users=[target],
+ ratelimit=ratelimit,
)
if event.membership == Membership.LEAVE:
@@ -296,7 +333,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
- await self.store.add_account_data_for_user(
+ await self.account_data_handler.add_account_data_for_user(
user_id, AccountDataTypes.DIRECT, direct_rooms
)
break
@@ -306,7 +343,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
- await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+ await self.account_data_handler.add_tag_to_room(
+ user_id, new_room_id, tag, tag_content
+ )
async def update_membership(
self,
@@ -430,8 +469,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(403, "This room has been blocked on this server")
if effective_membership_state == Membership.INVITE:
+ target_id = target.to_string()
+ if ratelimit:
+ # Don't ratelimit application services.
+ if not requester.app_service or requester.app_service.is_rate_limited():
+ self.ratelimit_invite(room_id, target_id)
+
# block any attempts to invite the server notices mxid
- if target.to_string() == self._server_notices_mxid:
+ if target_id == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
block_invite = False
@@ -456,9 +501,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_published = await self.store.is_room_published(room_id)
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
requester.user.to_string(),
- target.to_string(),
+ target_id,
third_party_invite=None,
room_id=room_id,
new_room=new_room,
@@ -560,17 +605,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(403, "Not allowed to join this room")
if not is_host_in_room:
- time_now_s = self.clock.time()
- (
- allowed,
- time_allowed,
- ) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ if ratelimit:
+ time_now_s = self.clock.time()
+ (
+ allowed,
+ time_allowed,
+ ) = self._join_rate_limiter_remote.can_requester_do_action(
+ requester,
)
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -624,7 +672,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# send the rejection to the inviter's HS (with fallback to
# local event)
return await self.remote_reject_invite(
- invite.event_id, txn_id, requester, content,
+ invite.event_id,
+ txn_id,
+ requester,
+ content,
)
# the inviter was on our server, but has now left. Carry on
@@ -947,7 +998,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_published = await self.store.is_room_published(room_id)
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
requester.user.to_string(),
invitee,
third_party_invite={"medium": medium, "address": address},
@@ -1145,8 +1196,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
user: UserID,
content: dict,
) -> Tuple[str, int]:
- """Implements RoomMemberHandler._remote_join
- """
+ """Implements RoomMemberHandler._remote_join"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
# and if it is the only entry we'd like to return a 404 rather than a
# 500.
@@ -1329,7 +1379,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
event.internal_metadata.out_of_band_membership = True
result_event = await self.event_creation_handler.handle_new_client_event(
- requester, event, context, extra_users=[UserID.from_string(target_user)],
+ requester,
+ event,
+ context,
+ extra_users=[UserID.from_string(target_user)],
)
# we know it was persisted, so must have a stream ordering
assert result_event.internal_metadata.stream_ordering
@@ -1337,7 +1390,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return result_event.event_id, result_event.internal_metadata.stream_ordering
async def remote_knock(
- self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
) -> Tuple[str, int]:
"""Sends a knock to a room. Attempts to do so via one remote out of a given list.
@@ -1363,8 +1420,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
async def _user_left_room(self, target: UserID, room_id: str) -> None:
- """Implements RoomMemberHandler._user_left_room
- """
+ """Implements RoomMemberHandler._user_left_room"""
user_left_room(self.distributor, target, room_id)
async def forget(self, user: UserID, room_id: str) -> None:
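The new invite rate limiting above uses two limiters, keyed on the room and on the invitee, each configured with a per_second rate and a burst_count. A rough sketch of those semantics, assuming a plain token bucket; Synapse's own Ratelimiter is more involved, and the rates below are placeholders rather than the real defaults:

import time
from typing import Dict, Tuple

class SimpleRateLimiter:
    """Token bucket: burst_count tokens, refilled at rate_hz tokens per second."""

    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self._buckets = {}  # type: Dict[str, Tuple[float, float]]

    def can_do_action(self, key: str) -> bool:
        now = time.monotonic()
        tokens, last = self._buckets.get(key, (float(self.burst_count), now))
        # refill since the last action, capped at the burst size
        tokens = min(float(self.burst_count), tokens + (now - last) * self.rate_hz)
        allowed = tokens >= 1.0
        self._buckets[key] = (tokens - 1.0 if allowed else tokens, now)
        return allowed

# placeholder rates; the real values come from rc_invites_per_room / rc_invites_per_user
invites_per_room = SimpleRateLimiter(rate_hz=0.3, burst_count=10)
invites_per_user = SimpleRateLimiter(rate_hz=0.003, burst_count=5)

def ratelimit_invite(room_id, invitee_user_id):
    # mirrors the handler logic: limit by room when known, and always by target user
    if room_id is not None and not invites_per_room.can_do_action(room_id):
        raise RuntimeError("invite limit exceeded for room %s" % (room_id,))
    if not invites_per_user.can_do_action(invitee_user_id):
        raise RuntimeError("invite limit exceeded for user %s" % (invitee_user_id,))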
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 3de63e885e..9f2bde25dd 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
@@ -28,11 +28,14 @@ from synapse.replication.http.membership import (
)
from synapse.types import JsonDict, Requester, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class RoomMemberWorkerHandler(RoomMemberHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
@@ -49,8 +52,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
user: UserID,
content: dict,
) -> Tuple[str, int]:
- """Implements RoomMemberHandler._remote_join
- """
+ """Implements RoomMemberHandler._remote_join"""
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
@@ -113,7 +115,11 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret["event_id"], ret["stream_id"]
async def remote_knock(
- self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
) -> Tuple[str, int]:
"""Sends a knock to a room.
@@ -128,8 +134,10 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret["event_id"], ret["stream_id"]
async def _user_left_room(self, target: UserID, room_id: str) -> None:
- """Implements RoomMemberHandler._user_left_room
- """
+ """Implements RoomMemberHandler._user_left_room"""
await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left"
)
+
+ async def forget(self, target: UserID, room_id: str) -> None:
+ raise RuntimeError("Cannot forget rooms on workers.")
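Several hunks here and in the following files annotate __init__ with hs: "HomeServer" while importing HomeServer only under TYPE_CHECKING. A generic sketch of that idiom, with illustrative module names, in case the pattern is unfamiliar:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # only evaluated by type checkers, so it cannot cause an import cycle at runtime
    from myapp.server import Server  # illustrative module, not a real import

class Handler:
    def __init__(self, server: "Server"):  # quoted so it isn't resolved at runtime
        self._server = server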
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..ec2ba11c75 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
-from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
@@ -34,7 +33,6 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
-from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -59,8 +57,6 @@ class SamlHandler(BaseHandler):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
@@ -76,30 +72,47 @@ class SamlHandler(BaseHandler):
)
# identifier for the external_ids table
- self._auth_provider_id = "saml"
+ self.idp_id = "saml"
+
+ # user-facing name of this auth provider
+ self.idp_name = "SAML"
+
+ # we do not currently support icons/brands for SAML auth, but this is required by
+ # the SsoIdentityProvider protocol type.
+ self.idp_icon = None
+ self.idp_brand = None
+ self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
- # a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
self._sso_handler = hs.get_sso_handler()
+ self._sso_handler.register_identity_provider(self)
- def handle_redirect_request(
- self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
- ) -> bytes:
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: Optional[bytes],
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
"""Handle an incoming request to /login/sso/redirect
Args:
+ request: the incoming HTTP request
client_redirect_url: the URL that we should redirect the
- client to when everything is done
+ client to after login (or None for UI Auth).
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
URL to redirect to
"""
+ if not client_redirect_url:
+ # Some SAML identity providers (e.g. Google) require a
+ # RelayState parameter on requests, so pass in a dummy redirect URL
+ # (which will never get used).
+ client_redirect_url = b"unused"
+
reqid, info = self._saml_client.prepare_for_authenticate(
entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
@@ -109,7 +122,8 @@ class SamlHandler(BaseHandler):
now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
- creation_time=now, ui_auth_session_id=ui_auth_session_id,
+ creation_time=now,
+ ui_auth_session_id=ui_auth_session_id,
)
for key, value in info["headers"]:
@@ -120,7 +134,7 @@ class SamlHandler(BaseHandler):
raise Exception("prepare_for_authenticate didn't return a Location header")
async def handle_saml_response(self, request: SynapseRequest) -> None:
- """Handle an incoming request to /_matrix/saml2/authn_response
+ """Handle an incoming request to /_synapse/client/saml2/authn_response
Args:
request: the incoming request from the browser. We'll
@@ -167,6 +181,29 @@ class SamlHandler(BaseHandler):
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+ await self._handle_authn_response(request, saml2_auth, relay_state)
+
+ async def _handle_authn_response(
+ self,
+ request: SynapseRequest,
+ saml2_auth: saml2.response.AuthnResponse,
+ relay_state: str,
+ ) -> None:
+ """Handle an AuthnResponse, having parsed it from the request params
+
+ Assumes that the signature on the response object has been checked. Maps
+ the user onto an MXID, registering them if necessary, and returns a response
+ to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+ saml2_auth: the parsed AuthnResponse object
+ relay_state: the RelayState query param, which encodes the URI to redirect
+ back to
+ """
+
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
@@ -183,72 +220,64 @@ class SamlHandler(BaseHandler):
saml2_auth.in_response_to, None
)
- # Ensure that the attributes of the logged in user meet the required
- # attributes.
- for requirement in self._saml2_attribute_requirements:
- if not _check_attribute_requirement(saml2_auth.ava, requirement):
- self._sso_handler.render_error(
- request, "unauthorised", "You are not authorised to log in here."
- )
+ # first check if we're doing a UIA
+ if current_session and current_session.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+ except MappingException as e:
+ logger.exception("Failed to extract remote user id from SAML response")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self.idp_id,
+ remote_user_id,
+ current_session.ui_auth_session_id,
+ request,
+ )
+
+ # otherwise, we're handling a login request.
+
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ if not self._sso_handler.check_required_attributes(
+ request, saml2_auth.ava, self._saml2_attribute_requirements
+ ):
+ return
# Call the mapper to register/login the user
try:
- user_id = await self._map_saml_response_to_user(
- saml2_auth, relay_state, user_agent, ip_address
- )
+ await self._complete_saml_login(saml2_auth, request, relay_state)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
-
- # Complete the interactive auth session or the login.
- if current_session and current_session.ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, current_session.ui_auth_session_id, request
- )
-
- else:
- await self._auth_handler.complete_sso_login(user_id, request, relay_state)
- async def _map_saml_response_to_user(
+ async def _complete_saml_login(
self,
saml2_auth: saml2.response.AuthnResponse,
+ request: SynapseRequest,
client_redirect_url: str,
- user_agent: str,
- ip_address: str,
- ) -> str:
+ ) -> None:
"""
- Given a SAML response, retrieve the user ID for it and possibly register the user.
+ Given a SAML response, complete the login flow
+
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
Args:
saml2_auth: The parsed SAML2 response.
+ request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
-
- Returns:
- The user ID associated with this response.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
-
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ remote_user_id = self._remote_id_from_saml_response(
saml2_auth, client_redirect_url
)
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
@@ -294,16 +323,44 @@ class SamlHandler(BaseHandler):
return None
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- return await self._sso_handler.get_mxid_from_sso(
- self._auth_provider_id,
- remote_user_id,
- user_agent,
- ip_address,
- saml_response_to_remapped_user_attributes,
- grandfather_existing_users,
+ await self._sso_handler.complete_sso_login_request(
+ self.idp_id,
+ remote_user_id,
+ request,
+ client_redirect_url,
+ saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
+ )
+
+ def _remote_id_from_saml_response(
+ self,
+ saml2_auth: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
+ ) -> str:
+ """Extract the unique remote id from a SAML2 AuthnResponse
+
+ Args:
+ saml2_auth: The parsed SAML2 response.
+ client_redirect_url: The redirect URL passed in by the client.
+ Returns:
+ remote user id
+
+ Raises:
+ MappingException if there was an error extracting the user id
+ """
+ # It's not obvious why we need to pass in the redirect URI to the mapping
+ # provider, but we do :/
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise MappingException(
+ "Failed to extract remote user id from SAML response"
)
+ return remote_user_id
+
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
@@ -315,21 +372,6 @@ class SamlHandler(BaseHandler):
del self._outstanding_requests_dict[reqid]
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
- values = ava.get(req.attribute, [])
- for v in values:
- if v == req.value:
- return True
-
- logger.info(
- "SAML2 attribute %s did not match required value '%s' (was '%s')",
- req.attribute,
- req.value,
- values,
- )
- return False
-
-
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
@@ -410,7 +452,8 @@ class DefaultSamlMappingProvider:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
- "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ "SAML2 response lacks a '%s' attestation",
+ self._mxid_source_attribute,
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
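The SAML changes above amount to turning the handler into one of the pluggable identity providers described later in this diff (synapse/handlers/sso.py): it exposes idp_id, idp_name, the icon/brand fields and an async handle_redirect_request, and registers itself with the SSO handler. A minimal sketch of a provider satisfying the same surface; the class, URL and registration call below are illustrative, not part of this change:

from typing import Optional
from urllib.parse import urlencode

class ExampleIdentityProvider:
    idp_id = "example"        # unique id, recorded against users in user_external_ids
    idp_name = "Example IdP"  # user-facing name shown on the picker page
    idp_icon = None           # no icon/brand support in this sketch
    idp_brand = None
    unstable_idp_brand = None

    async def handle_redirect_request(
        self,
        request,                               # the incoming SynapseRequest
        client_redirect_url: Optional[bytes],
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        # Hand the browser off to the external IdP. A real implementation would
        # also record the outstanding request so the response can be matched up.
        relay = client_redirect_url or b"unused"
        return "https://idp.example.com/auth?" + urlencode({"RelayState": relay})

# registration mirrors self._sso_handler.register_identity_provider(self) above:
# hs.get_sso_handler().register_identity_provider(ExampleIdentityProvider())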
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..d742dfbd53 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
import itertools
import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
return historical_room_ids
- async def search(self, user, content, batch=None):
+ async def search(
+ self, user: UserID, content: JsonDict, batch: Optional[str] = None
+ ) -> JsonDict:
"""Performs a full text search for a user.
Args:
- user (UserID)
- content (dict): Search parameters
- batch (str): The next_batch parameter. Used for pagination.
+ user
+ content: Search parameters
+ batch: The next_batch parameter. Used for pagination.
Returns:
dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
# If searching a subset of all rooms, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
- historical_room_ids = []
+ historical_room_ids = [] # type: List[str]
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
- room_groups = {} # Holds result of grouping by room, if applicable
- sender_group = {} # Holds result of grouping by sender, if applicable
+ # Holds result of grouping by room, if applicable
+ room_groups = {} # type: Dict[str, JsonDict]
+ # Holds result of grouping by sender, if applicable
+ sender_group = {} # type: Dict[str, JsonDict]
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
- room_events = []
+ room_events = [] # type: List[EventBase]
i = 0
pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
state_results = {}
if include_state:
- rooms = {e.room_id for e in allowed_events}
- for room_id in rooms:
+ for room_id in {e.room_id for e in allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
- state_results.values()
-
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
- for room_id, state in state_results.items():
+ for room_id, state_events in state_results.items():
s[room_id] = await self._event_serializer.serialize_events(
- state, time_now
+ state_events, time_now
)
rooms_cat_res["state"] = s
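For context on the newly-annotated room_groups / sender_group dicts above: each maps a room id or sender to a small dict holding the ordered list of matching event ids. A toy sketch of that grouping, using plain dicts in place of real events:

from typing import Dict

events = [
    {"event_id": "$1", "room_id": "!a:hs", "sender": "@alice:hs"},
    {"event_id": "$2", "room_id": "!a:hs", "sender": "@bob:hs"},
    {"event_id": "$3", "room_id": "!b:hs", "sender": "@alice:hs"},
]

room_groups = {}   # type: Dict[str, dict]
sender_group = {}  # type: Dict[str, dict]

for order, ev in enumerate(events):
    # record the position of the group's first hit, then append the event id
    rg = room_groups.setdefault(ev["room_id"], {"order": order, "results": []})
    rg["results"].append(ev["event_id"])
    sg = sender_group.setdefault(ev["sender"], {"order": order, "results": []})
    sg["results"].append(ev["event_id"])

assert room_groups["!a:hs"]["results"] == ["$1", "$2"]
assert sender_group["@alice:hs"]["results"] == ["$1", "$3"]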
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 7713c3cf91..6243fab091 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -14,24 +14,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- self._password_policy_handler = hs.get_password_policy_handler()
async def set_password(
self,
@@ -39,8 +41,8 @@ class SetPasswordHandler(BaseHandler):
password_hash: str,
logout_devices: bool,
requester: Optional[Requester] = None,
- ):
- if not self.hs.config.password_localdb_enabled:
+ ) -> None:
+ if not self._auth_handler.can_change_password():
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
try:
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
new file mode 100644
index 0000000000..5d9418969d
--- /dev/null
+++ b/synapse/handlers/space_summary.py
@@ -0,0 +1,395 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from collections import deque
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple, cast
+
+import attr
+
+from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility
+from synapse.api.errors import AuthError
+from synapse.events import EventBase
+from synapse.events.utils import format_event_for_client_v2
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+# maximum number of rooms to return. We'll stop once we hit this limit.
+# TODO: allow clients to reduce this with a request param.
+MAX_ROOMS = 50
+
+# max number of events to return per room.
+MAX_ROOMS_PER_SPACE = 50
+
+# max number of federation servers to hit per room
+MAX_SERVERS_PER_SPACE = 3
+
+
+class SpaceSummaryHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
+ self._auth = hs.get_auth()
+ self._room_list_handler = hs.get_room_list_handler()
+ self._state_handler = hs.get_state_handler()
+ self._store = hs.get_datastore()
+ self._event_serializer = hs.get_event_client_serializer()
+ self._server_name = hs.hostname
+ self._federation_client = hs.get_federation_client()
+
+ async def get_space_summary(
+ self,
+ requester: str,
+ room_id: str,
+ suggested_only: bool = False,
+ max_rooms_per_space: Optional[int] = None,
+ ) -> JsonDict:
+ """
+ Implementation of the space summary C-S API
+
+ Args:
+ requester: user id of the user making this request
+
+ room_id: room id to start the summary at
+
+ suggested_only: whether we should only return children with the "suggested"
+ flag set.
+
+ max_rooms_per_space: an optional limit on the number of child rooms we will
+ return. This does not apply to the root room (i.e., room_id), and
+ is capped at MAX_ROOMS_PER_SPACE.
+
+ Returns:
+ summary dict to return
+ """
+ # first of all, check that the user is in the room in question (or it's
+ # world-readable)
+ await self._auth.check_user_in_room_or_world_readable(room_id, requester)
+
+ # the queue of rooms to process
+ room_queue = deque((_RoomQueueEntry(room_id, ()),))
+
+ # rooms we have already processed
+ processed_rooms = set() # type: Set[str]
+
+ # events we have already processed. We don't necessarily have their event ids,
+ # so instead we key on (room id, state key)
+ processed_events = set() # type: Set[Tuple[str, str]]
+
+ rooms_result = [] # type: List[JsonDict]
+ events_result = [] # type: List[JsonDict]
+
+ while room_queue and len(rooms_result) < MAX_ROOMS:
+ queue_entry = room_queue.popleft()
+ room_id = queue_entry.room_id
+ if room_id in processed_rooms:
+ # already done this room
+ continue
+
+ logger.debug("Processing room %s", room_id)
+
+ is_in_room = await self._store.is_host_joined(room_id, self._server_name)
+
+ # The client-specified max_rooms_per_space limit doesn't apply to the
+ # room_id specified in the request, so we ignore it if this is the
+ # first room we are processing.
+ max_children = max_rooms_per_space if processed_rooms else None
+
+ if is_in_room:
+ rooms, events = await self._summarize_local_room(
+ requester, room_id, suggested_only, max_children
+ )
+ else:
+ rooms, events = await self._summarize_remote_room(
+ queue_entry,
+ suggested_only,
+ max_children,
+ exclude_rooms=processed_rooms,
+ )
+
+ logger.debug(
+ "Query of %s returned rooms %s, events %s",
+ queue_entry.room_id,
+ [room.get("room_id") for room in rooms],
+ ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
+ )
+
+ rooms_result.extend(rooms)
+
+ # any rooms returned don't need visiting again
+ processed_rooms.update(cast(str, room.get("room_id")) for room in rooms)
+
+ # the room we queried may or may not have been returned, but don't process
+ # it again, anyway.
+ processed_rooms.add(room_id)
+
+ # XXX: is it ok that we blindly iterate through any events returned by
+ # a remote server, whether or not they actually link to any rooms in our
+ # tree?
+ for ev in events:
+ # remote servers might return events we have already processed
+ # (eg, Dendrite returns inward pointers as well as outward ones), so
+ # we need to filter them out, to avoid returning duplicate links to the
+ # client.
+ ev_key = (ev["room_id"], ev["state_key"])
+ if ev_key in processed_events:
+ continue
+ events_result.append(ev)
+
+ # add the child to the queue. we have already validated
+ # that the vias are a list of server names.
+ room_queue.append(
+ _RoomQueueEntry(ev["state_key"], ev["content"]["via"])
+ )
+ processed_events.add(ev_key)
+
+ return {"rooms": rooms_result, "events": events_result}
+
+ async def federation_space_summary(
+ self,
+ room_id: str,
+ suggested_only: bool,
+ max_rooms_per_space: Optional[int],
+ exclude_rooms: Iterable[str],
+ ) -> JsonDict:
+ """
+ Implementation of the space summary Federation API
+
+ Args:
+ room_id: room id to start the summary at
+
+ suggested_only: whether we should only return children with the "suggested"
+ flag set.
+
+ max_rooms_per_space: an optional limit on the number of child rooms we will
+ return. Unlike the C-S API, this applies to the root room (room_id).
+ It is clipped to MAX_ROOMS_PER_SPACE.
+
+ exclude_rooms: a list of rooms to skip over (presumably because the
+ calling server has already seen them).
+
+ Returns:
+ summary dict to return
+ """
+ # the queue of rooms to process
+ room_queue = deque((room_id,))
+
+ # the set of rooms that we should not walk further. Initialise it with the
+ # excluded-rooms list; we will add other rooms as we process them so that
+ # we do not loop.
+ processed_rooms = set(exclude_rooms) # type: Set[str]
+
+ rooms_result = [] # type: List[JsonDict]
+ events_result = [] # type: List[JsonDict]
+
+ while room_queue and len(rooms_result) < MAX_ROOMS:
+ room_id = room_queue.popleft()
+ if room_id in processed_rooms:
+ # already done this room
+ continue
+
+ logger.debug("Processing room %s", room_id)
+
+ rooms, events = await self._summarize_local_room(
+ None, room_id, suggested_only, max_rooms_per_space
+ )
+
+ processed_rooms.add(room_id)
+
+ rooms_result.extend(rooms)
+ events_result.extend(events)
+
+ # add any children to the queue
+ room_queue.extend(edge_event["state_key"] for edge_event in events)
+
+ return {"rooms": rooms_result, "events": events_result}
+
+ async def _summarize_local_room(
+ self,
+ requester: Optional[str],
+ room_id: str,
+ suggested_only: bool,
+ max_children: Optional[int],
+ ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]:
+ if not await self._is_room_accessible(room_id, requester):
+ return (), ()
+
+ room_entry = await self._build_room_entry(room_id)
+
+ # look for child rooms/spaces.
+ child_events = await self._get_child_events(room_id)
+
+ if suggested_only:
+ # we only care about suggested children
+ child_events = filter(_is_suggested_child_event, child_events)
+
+ if max_children is None or max_children > MAX_ROOMS_PER_SPACE:
+ max_children = MAX_ROOMS_PER_SPACE
+
+ now = self._clock.time_msec()
+ events_result = [] # type: List[JsonDict]
+ for edge_event in itertools.islice(child_events, max_children):
+ events_result.append(
+ await self._event_serializer.serialize_event(
+ edge_event,
+ time_now=now,
+ event_format=format_event_for_client_v2,
+ )
+ )
+ return (room_entry,), events_result
+
+ async def _summarize_remote_room(
+ self,
+ room: "_RoomQueueEntry",
+ suggested_only: bool,
+ max_children: Optional[int],
+ exclude_rooms: Iterable[str],
+ ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]:
+ room_id = room.room_id
+ logger.info("Requesting summary for %s via %s", room_id, room.via)
+
+ # we need to make the exclusion list json-serialisable
+ exclude_rooms = list(exclude_rooms)
+
+ via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE)
+ try:
+ res = await self._federation_client.get_space_summary(
+ via,
+ room_id,
+ suggested_only=suggested_only,
+ max_rooms_per_space=max_children,
+ exclude_rooms=exclude_rooms,
+ )
+ except Exception as e:
+ logger.warning(
+ "Unable to get summary of %s via federation: %s",
+ room_id,
+ e,
+ exc_info=logger.isEnabledFor(logging.DEBUG),
+ )
+ return (), ()
+
+ return res.rooms, tuple(
+ ev.data
+ for ev in res.events
+ if ev.event_type == EventTypes.MSC1772_SPACE_CHILD
+ )
+
+ async def _is_room_accessible(self, room_id: str, requester: Optional[str]) -> bool:
+ # if we have an authenticated requesting user, first check if they are in the
+ # room
+ if requester:
+ try:
+ await self._auth.check_user_in_room(room_id, requester)
+ return True
+ except AuthError:
+ pass
+
+ # otherwise, check if the room is peekable
+ hist_vis_ev = await self._state_handler.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
+ )
+ if hist_vis_ev:
+ hist_vis = hist_vis_ev.content.get("history_visibility")
+ if hist_vis == HistoryVisibility.WORLD_READABLE:
+ return True
+
+ logger.info(
+ "room %s is unpeekable and user %s is not a member, omitting from summary",
+ room_id,
+ requester,
+ )
+ return False
+
+ async def _build_room_entry(self, room_id: str) -> JsonDict:
+ """Generate en entry suitable for the 'rooms' list in the summary response"""
+ stats = await self._store.get_room_with_stats(room_id)
+
+ # currently this should be impossible because we call
+ # check_user_in_room_or_world_readable on the room before we get here, so
+ # there should always be an entry
+ assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
+
+ current_state_ids = await self._store.get_current_state_ids(room_id)
+ create_event = await self._store.get_event(
+ current_state_ids[(EventTypes.Create, "")]
+ )
+
+ # TODO: update once MSC1772 lands
+ room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
+
+ entry = {
+ "room_id": stats["room_id"],
+ "name": stats["name"],
+ "topic": stats["topic"],
+ "canonical_alias": stats["canonical_alias"],
+ "num_joined_members": stats["joined_members"],
+ "avatar_url": stats["avatar"],
+ "world_readable": (
+ stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
+ ),
+ "guest_can_join": stats["guest_access"] == "can_join",
+ "room_type": room_type,
+ }
+
+ # Filter out Nones; we'd rather omit the field altogether
+ room_entry = {k: v for k, v in entry.items() if v is not None}
+
+ return room_entry
+
+ async def _get_child_events(self, room_id: str) -> Iterable[EventBase]:
+ # look for child rooms/spaces.
+ current_state_ids = await self._store.get_current_state_ids(room_id)
+
+ events = await self._store.get_events_as_list(
+ [
+ event_id
+ for key, event_id in current_state_ids.items()
+ # TODO: update once MSC1772 lands
+ if key[0] == EventTypes.MSC1772_SPACE_CHILD
+ ]
+ )
+
+ # filter out any events without a "via" (which implies they have been redacted)
+ return (e for e in events if _has_valid_via(e))
+
+
+@attr.s(frozen=True, slots=True)
+class _RoomQueueEntry:
+ room_id = attr.ib(type=str)
+ via = attr.ib(type=Sequence[str])
+
+
+def _has_valid_via(e: EventBase) -> bool:
+ via = e.content.get("via")
+ if not via or not isinstance(via, Sequence):
+ return False
+ for v in via:
+ if not isinstance(v, str):
+ logger.debug("Ignoring edge event %s with invalid via entry", e.event_id)
+ return False
+ return True
+
+
+def _is_suggested_child_event(edge_event: EventBase) -> bool:
+ suggested = edge_event.content.get("suggested")
+ if isinstance(suggested, bool) and suggested:
+ return True
+ logger.debug("Ignorning not-suggested child %s", edge_event.state_key)
+ return False
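The core of the new handler above is a bounded breadth-first walk over the space's child links, with a processed-rooms set to avoid loops and the MAX_ROOMS cap on the result. A stripped-down sketch of just that traversal, with an in-memory dict standing in for the m.space.child state events:

from collections import deque
from typing import Dict, List

MAX_ROOMS = 50

def walk_space(root: str, children: Dict[str, List[str]]) -> List[str]:
    """children maps room_id -> child room ids (illustrative stand-in for state events)."""
    room_queue = deque([root])
    processed_rooms = set()
    rooms_result = []
    while room_queue and len(rooms_result) < MAX_ROOMS:
        room_id = room_queue.popleft()
        if room_id in processed_rooms:
            continue  # already done this room
        processed_rooms.add(room_id)
        rooms_result.append(room_id)
        # add any children to the queue
        room_queue.extend(children.get(room_id, []))
    return rooms_result

graph = {"!space:hs": ["!a:hs", "!b:hs"], "!a:hs": ["!c:hs"], "!b:hs": ["!a:hs"]}
assert walk_space("!space:hs", graph) == ["!space:hs", "!a:hs", "!b:hs", "!c:hs"]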
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..415b1c2d17 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,15 +12,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+)
+from urllib.parse import urlencode
import attr
-
-from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
-from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from typing_extensions import NoReturn, Protocol
+
+from twisted.web.iweb import IRequest
+from twisted.web.server import Request
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+from synapse.http import get_request_user_agent
+from synapse.http.server import respond_with_html, respond_with_redirect
+from synapse.http.site import SynapseRequest
+from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -35,24 +58,195 @@ class MappingException(Exception):
"""
+class SsoIdentityProvider(Protocol):
+ """Abstract base class to be implemented by SSO Identity Providers
+
+ An Identity Provider, or IdP, is an external HTTP service which authenticates a user
+ to say whether they should be allowed to log in, or perform a given action.
+
+ Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+ and CAS.
+
+ The main entry point is `handle_redirect_request`, which should return a URI to
+ redirect the user's browser to the IdP's authentication page.
+
+ Each IdP should be registered with the SsoHandler via
+ `hs.get_sso_handler().register_identity_provider()`, so that requests to
+ `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+ """
+
+ @property
+ @abc.abstractmethod
+ def idp_id(self) -> str:
+ """A unique identifier for this SSO provider
+
+ Eg, "saml", "cas", "github"
+ """
+
+ @property
+ @abc.abstractmethod
+ def idp_name(self) -> str:
+ """User-facing name for this provider"""
+
+ @property
+ def idp_icon(self) -> Optional[str]:
+ """Optional MXC URI for user-facing icon"""
+ return None
+
+ @property
+ def idp_brand(self) -> Optional[str]:
+ """Optional branding identifier"""
+ return None
+
+ @property
+ def unstable_idp_brand(self) -> Optional[str]:
+ """Optional brand identifier for the unstable API (see MSC2858)."""
+ return None
+
+ @abc.abstractmethod
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: Optional[bytes],
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
+ """Handle an incoming request to /login/sso/redirect
+
+ Args:
+ request: the incoming HTTP request
+ client_redirect_url: the URL that we should redirect the
+ client to after login (or None for UI Auth).
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
+
+ Returns:
+ URL to redirect to
+ """
+ raise NotImplementedError()
+
+
@attr.s
class UserAttributes:
- localpart = attr.ib(type=str)
+ # the localpart of the mxid that the mapper has assigned to the user.
+ # if `None`, the mapper has not picked a userid, and the user should be prompted to
+ # enter one.
+ localpart = attr.ib(type=Optional[str])
display_name = attr.ib(type=Optional[str], default=None)
- emails = attr.ib(type=List[str], default=attr.Factory(list))
+ emails = attr.ib(type=Collection[str], default=attr.Factory(list))
+
+
+@attr.s(slots=True)
+class UsernameMappingSession:
+ """Data we track about SSO sessions"""
+
+ # A unique identifier for this SSO provider, e.g. "oidc" or "saml".
+ auth_provider_id = attr.ib(type=str)
+
+ # user ID on the IdP server
+ remote_user_id = attr.ib(type=str)
+
+ # attributes returned by the ID mapper
+ display_name = attr.ib(type=Optional[str])
+ emails = attr.ib(type=Collection[str])
+
+ # An optional dictionary of extra attributes to be provided to the client in the
+ # login response.
+ extra_login_attributes = attr.ib(type=Optional[JsonDict])
+ # where to redirect the client back to
+ client_redirect_url = attr.ib(type=str)
-class SsoHandler(BaseHandler):
+ # expiry time for the session, in milliseconds
+ expiry_time_ms = attr.ib(type=int)
+
+ # choices made by the user
+ chosen_localpart = attr.ib(type=Optional[str], default=None)
+ use_display_name = attr.ib(type=bool, default=True)
+ emails_to_use = attr.ib(type=Collection[str], default=())
+ terms_accepted_version = attr.ib(type=Optional[str], default=None)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
+class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
+ # the time a UsernameMappingSession remains valid for
+ _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._clock = hs.get_clock()
+ self._store = hs.get_datastore()
+ self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
+ self._auth_handler = hs.get_auth_handler()
self._error_template = hs.config.sso_error_template
+ self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+ # The following template is shown after a successful user interactive
+ # authentication session. It tells the user they can close the window.
+ self._sso_auth_success_template = hs.config.sso_auth_success_template
+
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+
+ # a map from session id to session data
+ self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
+
+ # map from idp_id to SsoIdentityProvider
+ self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
+
+ self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
+ def register_identity_provider(self, p: SsoIdentityProvider):
+ p_id = p.idp_id
+ assert p_id not in self._identity_providers
+ self._identity_providers[p_id] = p
+
+ def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
+ """Get the configured identity providers"""
+ return self._identity_providers
+
+ async def get_identity_providers_for_user(
+ self, user_id: str
+ ) -> Mapping[str, SsoIdentityProvider]:
+ """Get the SsoIdentityProviders which a user has used
+
+ Given a user id, get the identity providers that the user has previously logged
+ in with (and thus could use to re-identify themselves for UI Auth).
+
+ Args:
+ user_id: MXID of user to look up
+
+ Returns:
+ a map of idp_id to SsoIdentityProvider
+ """
+ external_ids = await self._store.get_external_ids_by_user(user_id)
+
+ valid_idps = {}
+ for idp_id, _ in external_ids:
+ idp = self._identity_providers.get(idp_id)
+ if not idp:
+ logger.warning(
+ "User %r has an SSO mapping for IdP %r, but this is no longer "
+ "configured.",
+ user_id,
+ idp_id,
+ )
+ else:
+ valid_idps[idp_id] = idp
+
+ return valid_idps
def render_error(
- self, request, error: str, error_description: Optional[str] = None
+ self,
+ request: Request,
+ error: str,
+ error_description: Optional[str] = None,
+ code: int = 400,
) -> None:
"""Renders the error template and responds with it.
@@ -64,11 +258,53 @@ class SsoHandler(BaseHandler):
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
+ code: The integer error code (an HTTP response code)
"""
html = self._error_template.render(
error=error, error_description=error_description
)
- respond_with_html(request, 400, html)
+ respond_with_html(request, code, html)
+
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: bytes,
+ idp_id: Optional[str],
+ ) -> str:
+ """Handle a request to /login/sso/redirect
+
+ Args:
+ request: incoming HTTP request
+ client_redirect_url: the URL that we should redirect the
+ client to after login.
+ idp_id: optional identity provider chosen by the client
+
+ Returns:
+ the URI to redirect to
+ """
+ if not self._identity_providers:
+ raise SynapseError(
+ 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+ )
+
+ # if the client chose an IdP, use that
+ idp = None # type: Optional[SsoIdentityProvider]
+ if idp_id:
+ idp = self._identity_providers.get(idp_id)
+ if not idp:
+ raise NotFoundError("Unknown identity provider")
+
+ # if we only have one auth provider, redirect to it directly
+ elif len(self._identity_providers) == 1:
+ idp = next(iter(self._identity_providers.values()))
+
+ if idp:
+ return await idp.handle_redirect_request(request, client_redirect_url)
+
+ # otherwise, redirect to the IDP picker
+ return "/_synapse/client/pick_idp?" + urlencode(
+ (("redirectUrl", client_redirect_url),)
+ )
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
@@ -95,8 +331,9 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
- previously_registered_user_id = await self.store.get_user_by_external_id(
- auth_provider_id, remote_user_id,
+ previously_registered_user_id = await self._store.get_user_by_external_id(
+ auth_provider_id,
+ remote_user_id,
)
# A match was found, return the user ID.
@@ -112,15 +349,16 @@ class SsoHandler(BaseHandler):
# No match.
return None
- async def get_mxid_from_sso(
+ async def complete_sso_login_request(
self,
auth_provider_id: str,
remote_user_id: str,
- user_agent: str,
- ip_address: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
- grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
- ) -> str:
+ grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
+ extra_login_attributes: Optional[JsonDict] = None,
+ ) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -139,12 +377,18 @@ class SsoHandler(BaseHandler):
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
+ Finally, we generate a redirect to the supplied redirect uri, with a login token
+
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
+
remote_user_id: The unique identifier from the SSO provider.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+
+ request: The request to respond to
+
+ client_redirect_url: The redirect URL passed in by the client.
+
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the number of
times the returned mxid localpart mapping has failed.
@@ -156,12 +400,13 @@ class SsoHandler(BaseHandler):
to the user.
RedirectException to redirect to an additional page (e.g.
to prompt the user for more information).
+
grandfather_existing_users: A callable which can return a previously
existing matrix ID. The SSO ID is then linked to the returned
matrix ID.
- Returns:
- The user ID associated with the SSO response.
+ extra_login_attributes: An optional dictionary of extra
+ attributes to be provided to the client in the login response.
Raises:
MappingException if there was a problem mapping the response to a user.
@@ -169,24 +414,65 @@ class SsoHandler(BaseHandler):
to an additional page. (e.g. to prompt for more information)
"""
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
- )
- if previously_registered_user_id:
- return previously_registered_user_id
+ new_user = False
+
+ # grab a lock while we try to find a mapping for this user. This seems...
+ # optimistic, especially for implementations that end up redirecting to
+ # interstitial pages.
+ with await self._mapping_lock.queue(auth_provider_id):
+ # first of all, check if we already have a mapping for this user
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id,
+ remote_user_id,
+ )
- # Check for grandfathering of users.
- if grandfather_existing_users:
- previously_registered_user_id = await grandfather_existing_users()
- if previously_registered_user_id:
- # Future logins should also match this user ID.
- await self.store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
+ # Check for grandfathering of users.
+ if not user_id:
+ user_id = await grandfather_existing_users()
+ if user_id:
+ # Future logins should also match this user ID.
+ await self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, user_id
+ )
+
+ # Otherwise, generate a new user.
+ if not user_id:
+ attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+ if attributes.localpart is None:
+ # the mapper doesn't return a username. bail out with a redirect to
+ # the username picker.
+ await self._redirect_to_username_picker(
+ auth_provider_id,
+ remote_user_id,
+ attributes,
+ client_redirect_url,
+ extra_login_attributes,
+ )
+
+ user_id = await self._register_mapped_user(
+ attributes,
+ auth_provider_id,
+ remote_user_id,
+ get_request_user_agent(request),
+ request.getClientIP(),
)
- return previously_registered_user_id
+ new_user = True
- # Otherwise, generate a new user.
+ await self._auth_handler.complete_sso_login(
+ user_id,
+ auth_provider_id,
+ request,
+ client_redirect_url,
+ extra_login_attributes,
+ new_user=new_user,
+ )
+
+ async def _call_attribute_mapper(
+ self,
+ sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ ) -> UserAttributes:
+ """Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
try:
attributes = await sso_to_matrix_id_mapper(i)
@@ -208,14 +494,12 @@ class SsoHandler(BaseHandler):
)
if not attributes.localpart:
- raise MappingException(
- "Error parsing SSO response: SSO mapping provider plugin "
- "did not return a localpart value"
- )
+ # the mapper has not picked a localpart
+ return attributes
# Check if this mxid already exists
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- if not await self.store.get_users_by_id_case_insensitive(user_id):
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -224,10 +508,101 @@ class SsoHandler(BaseHandler):
raise MappingException(
"Unable to generate a Matrix ID from the SSO response"
)
+ return attributes
+
+ async def _redirect_to_username_picker(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ attributes: UserAttributes,
+ client_redirect_url: str,
+ extra_login_attributes: Optional[JsonDict],
+ ) -> NoReturn:
+ """Creates a UsernameMappingSession and redirects the browser
+
+ Called if the user mapping provider doesn't return a localpart for a new user.
+ Raises a RedirectException which redirects the browser to the username picker.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+
+ remote_user_id: The unique identifier from the SSO provider.
+
+ attributes: the user attributes returned by the user mapping provider.
+
+ client_redirect_url: The redirect URL passed in by the client, which we
+ will eventually redirect back to.
+
+ extra_login_attributes: An optional dictionary of extra
+ attributes to be provided to the client in the login response.
+
+ Raises:
+ RedirectException
+ """
+ session_id = random_string(16)
+ now = self._clock.time_msec()
+ session = UsernameMappingSession(
+ auth_provider_id=auth_provider_id,
+ remote_user_id=remote_user_id,
+ display_name=attributes.display_name,
+ emails=attributes.emails,
+ client_redirect_url=client_redirect_url,
+ expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+ extra_login_attributes=extra_login_attributes,
+ )
+
+ self._username_mapping_sessions[session_id] = session
+ logger.info("Recorded registration session id %s", session_id)
+
+ # Set the cookie and redirect to the username picker
+ e = RedirectException(b"/_synapse/client/pick_username/account_details")
+ e.cookies.append(
+ b"%s=%s; path=/"
+ % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+ )
+ raise e
+
+ async def _register_mapped_user(
+ self,
+ attributes: UserAttributes,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
+ """Register a new SSO user.
+
+ This is called once we have successfully mapped the remote user id onto a local
+ user id, one way or another.
+
+ Args:
+ attributes: user attributes returned by the user mapping provider,
+ including a non-empty localpart.
+
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+
+ remote_user_id: The unique identifier from the SSO provider.
+
+ user_agent: The user-agent in the HTTP request (used for potential
+ shadow-banning.)
+
+ ip_address: The IP address of the requester (used for potential
+ shadow-banning.)
+
+ Raises:
+ a MappingException if the localpart is invalid.
+
+ a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+ is already taken.
+ """
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(attributes.localpart):
+ if not attributes.localpart or contains_invalid_mxid_characters(
+ attributes.localpart
+ ):
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -236,9 +611,382 @@ class SsoHandler(BaseHandler):
default_display_name=attributes.display_name,
bind_emails=attributes.emails,
user_agent_ips=[(user_agent, ip_address)],
+ auth_provider_id=auth_provider_id,
)
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
+
+ async def complete_sso_ui_auth_request(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ ui_auth_session_id: str,
+ request: Request,
+ ) -> None:
+ """
+ Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+ Note that this requires that the user is mapped in the "user_external_ids"
+ table. This will be the case if they have ever logged in via SAML or OIDC in
+ recentish synapse versions, but may not be for older users.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ ui_auth_session_id: The ID of the user-interactive auth session.
+ request: The request to complete.
+ """
+
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id,
+ remote_user_id,
+ )
+
+ user_id_to_verify = await self._auth_handler.get_session_data(
+ ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+ ) # type: str
+
+ if not user_id:
+ logger.warning(
+ "Remote user %s/%s has not previously logged in here: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ )
+ elif user_id != user_id_to_verify:
+ logger.warning(
+ "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ user_id,
+ )
+ else:
+ # success!
+ # Mark the stage of the authentication as successful.
+ await self._store.mark_ui_auth_stage_complete(
+ ui_auth_session_id, LoginType.SSO, user_id
+ )
+
+ # Render the HTML confirmation page and return.
+ html = self._sso_auth_success_template
+ respond_with_html(request, 200, html)
+ return
+
+ # the user_id didn't match: mark the stage of the authentication as unsuccessful
+ await self._store.mark_ui_auth_stage_complete(
+ ui_auth_session_id, LoginType.SSO, ""
+ )
+
+ # render an error page.
+ html = self._bad_user_template.render(
+ server_name=self._server_name,
+ user_id_to_verify=user_id_to_verify,
+ )
+ respond_with_html(request, 200, html)
+
+ def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+ """Look up the given username mapping session
+
+ If it is not found, raises a SynapseError with an http code of 400
+
+ Args:
+ session_id: session to look up
+ Returns:
+ active mapping session
+ Raises:
+ SynapseError if the session is not found/has expired
+ """
+ self._expire_old_sessions()
+ session = self._username_mapping_sessions.get(session_id)
+ if session:
+ return session
+ logger.info("Couldn't find session id %s", session_id)
+ raise SynapseError(400, "unknown session")
+
+ async def check_username_availability(
+ self,
+ localpart: str,
+ session_id: str,
+ ) -> bool:
+ """Handle an "is username available" callback check
+
+ Args:
+ localpart: desired localpart
+ session_id: the session id for the username picker
+ Returns:
+ True if the username is available
+ Raises:
+ SynapseError if the localpart is invalid or the session is unknown
+ """
+
+ # make sure that there is a valid mapping session, to stop people dictionary-
+ # scanning for accounts
+ self.get_mapping_session(session_id)
+
+ logger.info(
+ "[session %s] Checking for availability of username %s",
+ session_id,
+ localpart,
+ )
+
+ if contains_invalid_mxid_characters(localpart):
+ raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+ user_id = UserID(localpart, self._server_name).to_string()
+ user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+ logger.info("[session %s] users: %s", session_id, user_infos)
+ return not user_infos
+
+ async def handle_submit_username_request(
+ self,
+ request: SynapseRequest,
+ session_id: str,
+ localpart: str,
+ use_display_name: bool,
+ emails_to_use: Iterable[str],
+ ) -> None:
+ """Handle a request to the username-picker 'submit' endpoint
+
+ Will serve an HTTP response to the request.
+
+ Args:
+ request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+            localpart: localpart requested by the user
+ use_display_name: whether the user wants to use the suggested display name
+ emails_to_use: emails that the user would like to use
+ """
+ try:
+ session = self.get_mapping_session(session_id)
+ except SynapseError as e:
+ self.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ # update the session with the user's choices
+ session.chosen_localpart = localpart
+ session.use_display_name = use_display_name
+
+ emails_from_idp = set(session.emails)
+ filtered_emails = set() # type: Set[str]
+
+ # we iterate through the list rather than just building a set conjunction, so
+ # that we can log attempts to use unknown addresses
+ for email in emails_to_use:
+ if email in emails_from_idp:
+ filtered_emails.add(email)
+ else:
+ logger.warning(
+ "[session %s] ignoring user request to use unknown email address %r",
+ session_id,
+ email,
+ )
+ session.emails_to_use = filtered_emails
+
+ # we may now need to collect consent from the user, in which case, redirect
+ # to the consent-extraction-unit
+ if self._consent_at_registration:
+ redirect_url = b"/_synapse/client/new_user_consent"
+
+ # otherwise, redirect to the completion page
+ else:
+ redirect_url = b"/_synapse/client/sso_register"
+
+ respond_with_redirect(request, redirect_url)
+
+ async def handle_terms_accepted(
+ self, request: Request, session_id: str, terms_version: str
+ ):
+ """Handle a request to the new-user 'consent' endpoint
+
+ Will serve an HTTP response to the request.
+
+ Args:
+ request: HTTP request
+ session_id: ID of the username mapping session, extracted from a cookie
+ terms_version: the version of the terms which the user viewed and consented
+ to
+ """
+ logger.info(
+ "[session %s] User consented to terms version %s",
+ session_id,
+ terms_version,
+ )
+ try:
+ session = self.get_mapping_session(session_id)
+ except SynapseError as e:
+ self.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ session.terms_accepted_version = terms_version
+
+ # we're done; now we can register the user
+ respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+ async def register_sso_user(self, request: Request, session_id: str) -> None:
+ """Called once we have all the info we need to register a new user.
+
+ Does so and serves an HTTP response
+
+ Args:
+ request: HTTP request
+ session_id: ID of the username mapping session, extracted from a cookie
+ """
+ try:
+ session = self.get_mapping_session(session_id)
+ except SynapseError as e:
+ self.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ logger.info(
+ "[session %s] Registering localpart %s",
+ session_id,
+ session.chosen_localpart,
+ )
+
+ attributes = UserAttributes(
+ localpart=session.chosen_localpart,
+ emails=session.emails_to_use,
+ )
+
+ if session.use_display_name:
+ attributes.display_name = session.display_name
+
+ # the following will raise a 400 error if the username has been taken in the
+ # meantime.
+ user_id = await self._register_mapped_user(
+ attributes,
+ session.auth_provider_id,
+ session.remote_user_id,
+ get_request_user_agent(request),
+ request.getClientIP(),
+ )
+
+ logger.info(
+ "[session %s] Registered userid %s with attributes %s",
+ session_id,
+ user_id,
+ attributes,
+ )
+
+ # delete the mapping session and the cookie
+ del self._username_mapping_sessions[session_id]
+
+ # delete the cookie
+ request.addCookie(
+ USERNAME_MAPPING_SESSION_COOKIE_NAME,
+ b"",
+ expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+ path=b"/",
+ )
+
+ auth_result = {}
+ if session.terms_accepted_version:
+ # TODO: make this less awful.
+ auth_result[LoginType.TERMS] = True
+
+ await self._registration_handler.post_registration_actions(
+ user_id, auth_result, access_token=None
+ )
+
+ await self._auth_handler.complete_sso_login(
+ user_id,
+ session.auth_provider_id,
+ request,
+ session.client_redirect_url,
+ session.extra_login_attributes,
+ new_user=True,
+ )
+
+ def _expire_old_sessions(self):
+ to_expire = []
+ now = int(self._clock.time_msec())
+
+ for session_id, session in self._username_mapping_sessions.items():
+ if session.expiry_time_ms <= now:
+ to_expire.append(session_id)
+
+ for session_id in to_expire:
+ logger.info("Expiring mapping session %s", session_id)
+ del self._username_mapping_sessions[session_id]
+
+ def check_required_attributes(
+ self,
+ request: SynapseRequest,
+ attributes: Mapping[str, List[Any]],
+ attribute_requirements: Iterable[SsoAttributeRequirement],
+ ) -> bool:
+ """
+ Confirm that the required attributes were present in the SSO response.
+
+ If all requirements are met, this will return True.
+
+ If any requirement is not met, then the request will be finalized by
+ showing an error page to the user and False will be returned.
+
+ Args:
+ request: The request to (potentially) respond to.
+ attributes: The attributes from the SSO IdP.
+ attribute_requirements: The requirements that attributes must meet.
+
+ Returns:
+ True if all requirements are met, False if any attribute fails to
+ meet the requirement.
+
+ """
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ for requirement in attribute_requirements:
+ if not _check_attribute_requirement(attributes, requirement):
+ self.render_error(
+ request, "unauthorised", "You are not authorised to log in here."
+ )
+ return False
+
+ return True
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+ """Extract the session ID from the cookie
+
+ Raises a SynapseError if the cookie isn't found
+ """
+ session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+ if not session_id:
+ raise SynapseError(code=400, msg="missing session_id")
+ return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+ attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+ """Check if SSO attributes meet the proper requirements.
+
+ Args:
+ attributes: A mapping of attributes to an iterable of one or more values.
+        req: The configured requirement to check.
+
+ Returns:
+ True if the required attribute was found and had a proper value.
+ """
+ if req.attribute not in attributes:
+ logger.info("SSO attribute missing: %s", req.attribute)
+ return False
+
+ # If the requirement is None, the attribute existing is enough.
+ if req.value is None:
+ return True
+
+ values = attributes[req.attribute]
+ if req.value in values:
+ return True
+
+ logger.info(
+ "SSO attribute %s did not match required value '%s' (was '%s')",
+ req.attribute,
+ req.value,
+ values,
+ )
+ return False
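
The attribute-requirement check added above is small but easy to get wrong, so here is a minimal, self-contained sketch of the same policy outside Synapse. The `Requirement` class and the sample IdP attributes below are illustrative stand-ins, not the real `SsoAttributeRequirement` or any particular IdP's output:

```python
from typing import Any, List, Mapping, Optional

import attr


@attr.s(slots=True, frozen=True)
class Requirement:
    """Stand-in for SsoAttributeRequirement: an attribute name plus an optional value."""

    attribute = attr.ib(type=str)
    value = attr.ib(type=Optional[str], default=None)


def meets_requirement(attributes: Mapping[str, List[Any]], req: Requirement) -> bool:
    # mirror of _check_attribute_requirement: the attribute must be present and,
    # if a value is configured, it must be among the values the IdP returned
    if req.attribute not in attributes:
        return False
    if req.value is None:
        return True
    return req.value in attributes[req.attribute]


# e.g. only allow logins for members of a "staff" group
idp_attributes = {"groups": ["staff", "admins"], "department": ["engineering"]}
assert meets_requirement(idp_attributes, Requirement("groups", "staff"))
assert meets_requirement(idp_attributes, Requirement("department"))
assert not meets_requirement(idp_attributes, Requirement("groups", "guests"))
```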
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..ee8f87e59a 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class StateDeltasHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ async def _get_key_change(
+ self,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ key_name: str,
+ public_value: str,
+ ) -> Optional[bool]:
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 8c390c9f4d..4227f8c1a2 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -14,13 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -33,7 +39,7 @@ class StatsHandler:
Heavily derived from UserDirectoryHandler
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -46,7 +52,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
- self.pos = None
+ self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -58,9 +64,8 @@ class StatsHandler:
# we start populating stats
self.clock.call_later(0, self.notify_new_event)
- def notify_new_event(self):
- """Called when there may be more deltas to process
- """
+ def notify_new_event(self) -> None:
+ """Called when there may be more deltas to process"""
if not self.stats_enabled or self._is_processing:
return
@@ -74,7 +79,7 @@ class StatsHandler:
run_as_background_process("stats.notify_new_event", process)
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_stats_positions()
@@ -112,10 +117,10 @@ class StatsHandler:
)
for room_id, fields in room_count.items():
- room_deltas.setdefault(room_id, {}).update(fields)
+ room_deltas.setdefault(room_id, Counter()).update(fields)
for user_id, fields in user_count.items():
- user_deltas.setdefault(user_id, {}).update(fields)
+ user_deltas.setdefault(user_id, Counter()).update(fields)
logger.debug("room_deltas: %s", room_deltas)
logger.debug("user_deltas: %s", user_deltas)
@@ -133,19 +138,20 @@ class StatsHandler:
self.pos = max_pos
- async def _handle_deltas(self, deltas):
+ async def _handle_deltas(
+ self, deltas: Iterable[JsonDict]
+ ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
"""Called with the state deltas to process
Returns:
- tuple[dict[str, Counter], dict[str, counter]]
Two dicts: the room deltas and the user deltas,
mapping from room/user ID to changes in the various fields.
"""
- room_to_stats_deltas = {}
- user_to_stats_deltas = {}
+ room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
+ user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
- room_to_state_updates = {}
+ room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
for delta in deltas:
typ = delta["type"]
@@ -175,7 +181,7 @@ class StatsHandler:
)
continue
- event_content = {}
+ event_content = {} # type: JsonDict
sender = None
if event_id is not None:
@@ -263,13 +269,13 @@ class StatsHandler:
)
if has_changed_joinedness:
- delta = +1 if membership == Membership.JOIN else -1
+ membership_delta = +1 if membership == Membership.JOIN else -1
user_to_stats_deltas.setdefault(user_id, Counter())[
"joined_rooms"
- ] += delta
+ ] += membership_delta
- room_stats_delta["local_users_in_room"] += delta
+ room_stats_delta["local_users_in_room"] += membership_delta
elif typ == EventTypes.Create:
room_state["is_federatable"] = (
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index fe03004a01..17277619ad 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -80,7 +80,7 @@ class SyncConfig:
filter_collection = attr.ib(type=FilterCollection)
is_guest = attr.ib(type=bool)
request_key = attr.ib(type=Tuple[Any, ...])
- device_id = attr.ib(type=str)
+ device_id = attr.ib(type=Optional[str])
@attr.s(slots=True, frozen=True)
@@ -258,7 +258,7 @@ class SyncHandler:
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(
- hs, "sync"
+ hs.get_clock(), "sync"
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
@@ -353,8 +353,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
- """Get the sync for client needed to match what the server has now.
- """
+ """Get the sync for client needed to match what the server has now."""
return await self.generate_sync_result(sync_config, since_token, full_state)
async def push_rules_for_user(self, user: UserID) -> JsonDict:
@@ -568,7 +567,7 @@ class SyncHandler:
event.event_id, state_filter=state_filter
)
if event.is_state():
- state_ids = state_ids.copy()
+ state_ids = dict(state_ids)
state_ids[(event.type, event.state_key)] = event.event_id
return state_ids
@@ -578,7 +577,7 @@ class SyncHandler:
stream_position: StreamToken,
state_filter: StateFilter = StateFilter.all(),
) -> StateMap[str]:
- """ Get the room state at a particular stream position
+ """Get the room state at a particular stream position
Args:
room_id: room for which to get state
@@ -612,7 +611,7 @@ class SyncHandler:
state: MutableStateMap[EventBase],
now_token: StreamToken,
) -> Optional[JsonDict]:
- """ Works out a room summary block for this room, summarising the number
+ """Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds
state events to 'state' if needed to describe the heroes.
@@ -738,7 +737,9 @@ class SyncHandler:
return summary
- def get_lazy_loaded_members_cache(self, cache_key: Tuple[str, str]) -> LruCache:
+ def get_lazy_loaded_members_cache(
+ self, cache_key: Tuple[str, Optional[str]]
+ ) -> LruCache:
cache = self.lazy_loaded_members_cache.get(cache_key)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
@@ -757,7 +758,7 @@ class SyncHandler:
now_token: StreamToken,
full_state: bool,
) -> MutableStateMap[EventBase]:
- """ Works out the difference in state between the start of the timeline
+ """Works out the difference in state between the start of the timeline
and the previous sync.
Args:
@@ -834,8 +835,10 @@ class SyncHandler:
)
elif batch.limited:
if batch:
- state_at_timeline_start = await self.state_store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ state_at_timeline_start = (
+ await self.state_store.get_state_ids_for_event(
+ batch.events[0].event_id, state_filter=state_filter
+ )
)
else:
# We can get here if the user has ignored the senders of all
@@ -969,8 +972,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
- """Generates a sync result.
- """
+ """Generates a sync result."""
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
@@ -1046,8 +1048,8 @@ class SyncHandler:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
- unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
- user_id, device_id
+ unused_fallback_key_types = (
+ await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
logger.debug("Fetching group data")
@@ -1196,8 +1198,10 @@ class SyncHandler:
# weren't in the previous sync *or* they left and rejoined.
users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
- user_signatures_changed = await self.store.get_users_whose_signatures_changed(
- user_id, since_token.device_list_key
+ user_signatures_changed = (
+ await self.store.get_users_whose_signatures_changed(
+ user_id, since_token.device_list_key
+ )
)
users_that_have_changed.update(user_signatures_changed)
@@ -1413,8 +1417,10 @@ class SyncHandler:
logger.debug("no-oping sync")
return set(), set(), set(), set()
- ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
- AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+ ignored_account_data = (
+ await self.store.get_global_account_data_by_type_for_user(
+ AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+ )
)
# If there is ignored users account data and it matches the proper type,
@@ -1524,8 +1530,7 @@ class SyncHandler:
async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
- """Gets the the changes that have happened since the last sync.
- """
+ """Gets the the changes that have happened since the last sync."""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
@@ -1741,7 +1746,11 @@ class SyncHandler:
room_entries.append(entry)
return _RoomChanges(
- room_entries, invited, knocked, newly_joined_rooms, newly_left_rooms,
+ room_entries,
+ invited,
+ knocked,
+ newly_joined_rooms,
+ newly_left_rooms,
)
async def _get_all_rooms(
@@ -2017,8 +2026,10 @@ class SyncHandler:
logger.info("User joined room after current token: %s", room_id)
- extrems = await self.store.get_forward_extremeties_for_room(
- room_id, event_pos.stream
+ extrems = (
+ await self.store.get_forward_extremities_for_room_at_stream_ordering(
+ room_id, event_pos.stream
+ )
)
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..096d199f4c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
import logging
import random
from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -61,23 +61,23 @@ class FollowerTypingHandler:
if hs.config.worker.writers.typing != hs.get_instance_name():
hs.get_federation_registry().register_instance_for_edu(
- "m.typing", hs.config.worker.writers.typing,
+ "m.typing",
+ hs.config.worker.writers.typing,
)
# map room IDs to serial numbers
- self._room_serials = {}
+ self._room_serials = {} # type: Dict[str, int]
# map room IDs to sets of users currently typing
- self._room_typing = {}
+ self._room_typing = {} # type: Dict[str, Set[str]]
- self._member_last_federation_poke = {}
+ self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)
- def _reset(self):
- """Reset the typing handler's data caches.
- """
+ def _reset(self) -> None:
+ """Reset the typing handler's data caches."""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
- def _handle_timeouts(self):
+ def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
for member in members:
self._handle_timeout_for_member(now, member)
- def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
# each person typing.
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
- def is_typing(self, member):
+ def is_typing(self, member: RoomMember) -> bool:
return member.user_id in self._room_typing.get(member.room_id, [])
- async def _push_remote(self, member, typing):
+ async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
return
@@ -148,9 +148,8 @@ class FollowerTypingHandler:
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
- ):
- """Should be called whenever we receive updates for typing stream.
- """
+ ) -> None:
+ """Should be called whenever we receive updates for typing stream."""
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
@@ -178,7 +177,7 @@ class FollowerTypingHandler:
async def _send_changes_in_typing_to_remotes(
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
- ):
+ ) -> None:
"""Process a change in typing of a room from replication, sending EDUs
for any local users.
"""
@@ -194,12 +193,12 @@ class FollowerTypingHandler:
if self.is_mine_id(user_id):
await self._push_remote(RoomMember(room_id, user_id), False)
- def get_current_token(self):
+ def get_current_token(self) -> int:
return self._latest_room_serial
class TypingWriterHandler(FollowerTypingHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +212,15 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- self._member_typing_until = {} # clock time we expect to stop
+ # clock time we expect to stop
+ self._member_typing_until = {} # type: Dict[RoomMember, int]
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", self._latest_room_serial
)
- def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
super()._handle_timeout_for_member(now, member)
if not self.is_typing(member):
@@ -233,7 +233,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
return
- async def started_typing(self, target_user, requester, room_id, timeout):
+ async def started_typing(
+ self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+ ) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -263,11 +265,13 @@ class TypingWriterHandler(FollowerTypingHandler):
if was_present:
# No point sending another notification
- return None
+ return
self._push_update(member=member, typing=True)
- async def stopped_typing(self, target_user, requester, room_id):
+ async def stopped_typing(
+ self, target_user: UserID, requester: Requester, room_id: str
+ ) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -290,23 +294,23 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
- def user_left_room(self, user, room_id):
+ def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
self._stopped_typing(member)
- def _stopped_typing(self, member):
+ def _stopped_typing(self, member: RoomMember) -> None:
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
- return None
+ return
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
self._push_update(member=member, typing=False)
- def _push_update(self, member, typing):
+ def _push_update(self, member: RoomMember, typing: bool) -> None:
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_as_background_process(
@@ -315,7 +319,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self._push_update_local(member=member, typing=typing)
- async def _recv_edu(self, origin, content):
+ async def _recv_edu(self, origin: str, content: JsonDict) -> None:
room_id = content["room_id"]
user_id = content["user_id"]
@@ -340,7 +344,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
self._push_update_local(member=member, typing=content["typing"])
- def _push_update_local(self, member, typing):
+ def _push_update_local(self, member: RoomMember, typing: bool) -> None:
room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
room_set.add(member.user_id)
@@ -386,7 +390,7 @@ class TypingWriterHandler(FollowerTypingHandler):
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
last_id
- )
+ ) # type: Optional[Iterable[str]]
if changed_rooms is None:
changed_rooms = self._room_serials
@@ -412,13 +416,13 @@ class TypingWriterHandler(FollowerTypingHandler):
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
- ):
+ ) -> None:
# The writing process should never get updates from replication.
raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
# We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +431,7 @@ class TypingNotificationEventSource:
#
self.get_typing_handler = hs.get_typing_handler
- def _make_event_for(self, room_id):
+ def _make_event_for(self, room_id: str) -> JsonDict:
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": "m.typing",
@@ -462,7 +466,9 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
- async def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(
+ self, from_key: int, room_ids: Iterable[str], **kwargs
+ ) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
@@ -478,5 +484,5 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
- def get_current_key(self):
+ def get_current_key(self) -> int:
return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 824f37f8f8..a68d5e790e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
"""
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+ """Constants for use with AuthHandler.set_session_data"""
+
+ # used during registration and password reset to store a hashed copy of the
+ # password, so that the client does not need to submit it each time.
+ PASSWORD_HASH = "password_hash"
+
+ # used during registration to store the mxid of the registered user
+ REGISTERED_USER_ID = "registered_user_id"
+
+ # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+ # for.
+ REQUEST_USER_ID = "request_user_id"
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index afbebfc200..b121286d95 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,14 +14,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import synapse.metrics
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
+from synapse.types import JsonDict
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,7 +41,7 @@ class UserDirectoryHandler(StateDeltasHandler):
be in the directory or not when necessary.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
@@ -49,7 +54,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
- self.pos = None
+ self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -61,7 +66,9 @@ class UserDirectoryHandler(StateDeltasHandler):
# we start populating the user directory
self.clock.call_later(0, self.notify_new_event)
- async def search_users(self, user_id, search_term, limit):
+ async def search_users(
+ self, user_id: str, search_term: str, limit: int
+ ) -> JsonDict:
"""Searches for users in directory
Returns:
@@ -81,17 +88,16 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results.
- results["results"] = [
- user
- for user in results["results"]
- if not self.spam_checker.check_username_for_spam(user)
- ]
+ non_spammy_users = []
+ for user in results["results"]:
+ if not await self.spam_checker.check_username_for_spam(user):
+ non_spammy_users.append(user)
+ results["results"] = non_spammy_users
return results
- def notify_new_event(self):
- """Called when there may be more deltas to process
- """
+ def notify_new_event(self) -> None:
+ """Called when there may be more deltas to process"""
if not self.update_user_directory:
return
@@ -107,32 +113,37 @@ class UserDirectoryHandler(StateDeltasHandler):
self._is_processing = True
run_as_background_process("user_directory.notify_new_event", process)
- async def handle_local_profile_change(self, user_id, profile):
+ async def handle_local_profile_change(
+ self, user_id: str, profile: ProfileInfo
+ ) -> None:
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- is_support = await self.store.is_support_user(user_id)
+
# Support users are for diagnostics and should not appear in the user directory.
- if not is_support:
+ is_support = await self.store.is_support_user(user_id)
+        # Changes to the profile of a deactivated user should likewise not appear
+        # in the user directory.
+ is_deactivated = await self.store.get_user_deactivated_status(user_id)
+
+ if not (is_support or is_deactivated):
await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- async def handle_user_deactivated(self, user_id):
- """Called when a user ID is deactivated
- """
+ async def handle_user_deactivated(self, user_id: str) -> None:
+ """Called when a user ID is deactivated"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
await self.store.remove_from_user_dir(user_id)
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
- # If still None then the initial background update hasn't happened yet
+ # If still None then the initial background update hasn't happened yet.
if self.pos is None:
return None
@@ -162,9 +173,8 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos)
- async def _handle_deltas(self, deltas):
- """Called with the state deltas to process
- """
+ async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
+ """Called with the state deltas to process"""
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
@@ -220,6 +230,11 @@ class UserDirectoryHandler(StateDeltasHandler):
if change: # The user joined
event = await self.store.get_event(event_id, allow_none=True)
+ # It isn't expected for this event to not exist, but we
+ # don't want the entire background process to break.
+ if event is None:
+ continue
+
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
@@ -232,16 +247,20 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Ignoring irrelevant type: %r", typ)
async def _handle_room_publicity_change(
- self, room_id, prev_event_id, event_id, typ
- ):
+ self,
+ room_id: str,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ typ: str,
+ ) -> None:
"""Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
- room_id (str)
- prev_event_id (str|None): The previous event before the state change
- event_id (str|None): The new event after the state change
- typ (str): Type of the event
+ room_id: The ID of the room which changed.
+ prev_event_id: The previous event before the state change
+ event_id: The new event after the state change
+ typ: Type of the event
"""
logger.debug("Handling change for %s: %s", typ, room_id)
@@ -250,7 +269,7 @@ class UserDirectoryHandler(StateDeltasHandler):
prev_event_id,
event_id,
key_name="history_visibility",
- public_value="world_readable",
+ public_value=HistoryVisibility.WORLD_READABLE,
)
elif typ == EventTypes.JoinRules:
change = await self._get_key_change(
@@ -299,12 +318,14 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id, profile in users_with_profile.items():
await self._handle_new_user(room_id, user_id, profile)
- async def _handle_new_user(self, room_id, user_id, profile):
+ async def _handle_new_user(
+ self, room_id: str, user_id: str, profile: ProfileInfo
+ ) -> None:
"""Called when we might need to add user to directory
Args:
- room_id (str): room_id that user joined or started being public
- user_id (str)
+            room_id: The room ID that the user joined or started being public in
+            user_id: The Matrix ID of the user to add
"""
logger.debug("Adding new user to dir, %r", user_id)
@@ -352,12 +373,12 @@ class UserDirectoryHandler(StateDeltasHandler):
if to_insert:
await self.store.add_users_who_share_private_room(room_id, to_insert)
- async def _handle_remove_user(self, room_id, user_id):
+ async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
"""Called when we might need to remove user from directory
Args:
- room_id (str): room_id that user left or stopped being public that
- user_id (str)
+            room_id: The room ID that the user left or stopped being public in
+            user_id: The Matrix ID of the user to remove
"""
logger.debug("Removing user %r", user_id)
@@ -370,7 +391,13 @@ class UserDirectoryHandler(StateDeltasHandler):
if len(rooms_user_is_in) == 0:
await self.store.remove_from_user_dir(user_id)
- async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+ async def _handle_profile_change(
+ self,
+ user_id: str,
+ room_id: str,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ ) -> None:
"""Check member event changes for any profile changes and update the
database if there are.
"""
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 59b01b812c..142b007d01 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
+from typing import Union
-from twisted.internet import task
+from twisted.internet import address, task
from twisted.web.client import FileBodyProducer
+from twisted.web.iweb import IRequest
from synapse.api.errors import SynapseError
@@ -50,3 +52,50 @@ class QuieterFileBodyProducer(FileBodyProducer):
FileBodyProducer.stopProducing(self)
except task.TaskStopped:
pass
+
+
+def get_request_uri(request: IRequest) -> bytes:
+ """Return the full URI that was requested by the client"""
+ return b"%s://%s%s" % (
+ b"https" if request.isSecure() else b"http",
+ _get_requested_host(request),
+ # despite its name, "request.uri" is only the path and query-string.
+ request.uri,
+ )
+
+
+def _get_requested_host(request: IRequest) -> bytes:
+ hostname = request.getHeader(b"host")
+ if hostname:
+ return hostname
+
+ # no Host header, use the address/port that the request arrived on
+ host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address]
+
+ hostname = host.host.encode("ascii")
+
+ if request.isSecure() and host.port == 443:
+ # default port for https
+ return hostname
+
+ if not request.isSecure() and host.port == 80:
+ # default port for http
+ return hostname
+
+ return b"%s:%i" % (
+ hostname,
+ host.port,
+ )
+
+
+def get_request_user_agent(request: IRequest, default: str = "") -> str:
+ """Return the last User-Agent header, or the given default."""
+ # There could be raw utf-8 bytes in the User-Agent header.
+
+ # N.B. if you don't do this, the logger explodes cryptically
+ # with maximum recursion trying to log errors about
+ # the charset problem.
+ # c.f. https://github.com/matrix-org/synapse/issues/3471
+
+ h = request.getHeader(b"User-Agent")
+ return h.decode("ascii", "replace") if h else default
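
The host fallback in `_get_requested_host` above boils down to "omit the port when it is the scheme's default". A tiny standalone sketch of that rule, with assumed inputs (this is not the Twisted API, just the formatting decision):

```python
def format_fallback_host(hostname: bytes, port: int, is_secure: bool) -> bytes:
    # mirrors the no-Host-header path above: drop the port for 443/https and
    # 80/http, otherwise include it explicitly
    default_port = 443 if is_secure else 80
    if port == default_port:
        return hostname
    return b"%s:%i" % (hostname, port)


assert format_fallback_host(b"matrix.example.org", 443, True) == b"matrix.example.org"
assert format_fallback_host(b"matrix.example.org", 8448, True) == b"matrix.example.org:8448"
assert format_fallback_host(b"localhost", 80, False) == b"localhost"
```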
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 960d57fb27..a0caba84e4 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -32,19 +32,22 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
-from netaddr import IPAddress, IPSet
+from netaddr import AddrFormatError, IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
+ ITCPTransport,
)
+from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -56,18 +59,25 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import (
+ UNKNOWN_LENGTH,
+ IAgent,
+ IBodyProducer,
+ IPolicyForHTTPS,
+ IResponse,
+)
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -150,16 +160,17 @@ class _IPBlacklistingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
-
- r = recv()
addresses = [] # type: List[IAddress]
def _callback() -> None:
- r.resolutionBegan(None)
-
has_bad_ip = False
- for i in addresses:
- ip_address = IPAddress(i.host)
+ for address in addresses:
+ # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
+ # should go through this path.
+ if not isinstance(address, (IPv4Address, IPv6Address)):
+ continue
+
+ ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
@@ -174,15 +185,15 @@ class _IPBlacklistingResolver:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
- for i in addresses:
- r.addressResolved(i)
- r.resolutionComplete()
+ for address in addresses:
+ recv.addressResolved(address)
+ recv.resolutionComplete()
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
- pass
+ recv.resolutionBegan(resolutionInProgress)
@staticmethod
def addressResolved(address: IAddress) -> None:
@@ -196,10 +207,10 @@ class _IPBlacklistingResolver:
EndpointReceiver, hostname, portNumber=portNumber
)
- return r
+ return recv
-@implementer(IReactorPluggableNameResolver)
+@implementer(ISynapseReactor)
class BlacklistingReactorWrapper:
"""
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
@@ -261,16 +272,16 @@ class BlacklistingAgentWrapper(Agent):
try:
ip_address = IPAddress(h.hostname)
-
+ except AddrFormatError:
+ # Not an IP
+ pass
+ else:
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
- except Exception:
- # Not an IP
- pass
return self._agent.request(
method, uri, headers=headers, bodyProducer=bodyProducer
@@ -324,7 +335,7 @@ class SimpleHttpClient:
# filters out blacklisted IP addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
- )
+ ) # type: ISynapseReactor
else:
self.reactor = hs.get_reactor()
@@ -345,7 +356,7 @@ class SimpleHttpClient:
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
- )
+ ) # type: IAgent
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
@@ -396,7 +407,8 @@ class SimpleHttpClient:
body_producer = None
if data is not None:
body_producer = QuieterFileBodyProducer(
- BytesIO(data), cooperator=self._cooperator,
+ BytesIO(data),
+ cooperator=self._cooperator,
)
request_deferred = treq.request(
@@ -405,13 +417,18 @@ class SimpleHttpClient:
agent=self.agent,
data=body_producer,
headers=headers,
+ # Avoid buffering the body in treq since we do not reuse
+ # response bodies.
+ unbuffered=True,
**self._extra_treq_args,
) # type: defer.Deferred
# we use our own timeout mechanism rather than treq's as a workaround
# for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
- request_deferred, 60, self.hs.get_reactor(),
+ request_deferred,
+ 60,
+ self.hs.get_reactor(),
)
# turn timeouts into RequestTimedOutErrors
@@ -697,18 +714,6 @@ class SimpleHttpClient:
resp_headers = dict(response.headers.getAllRawHeaders())
- if (
- b"Content-Length" in resp_headers
- and max_size
- and int(resp_headers[b"Content-Length"][0]) > max_size
- ):
- logger.warning("Requested URL is too large > %r bytes" % (max_size,))
- raise SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (max_size,),
- Codes.TOO_LARGE,
- )
-
if response.code > 299:
logger.warning("Got %d when downloading %s" % (response.code, url))
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
@@ -719,11 +724,14 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- readBodyToFile(response, output_stream, max_size)
+ read_body_with_max_size(response, output_stream, max_size)
+ )
+ except BodyExceededMaxSize:
+ raise SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (max_size,),
+ Codes.TOO_LARGE,
)
- except SynapseError:
- # This can happen e.g. because the body is too large.
- raise
except Exception as e:
raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
@@ -747,7 +755,41 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-class _ReadBodyToFileProtocol(protocol.Protocol):
+class BodyExceededMaxSize(Exception):
+ """The maximum allowed size of the HTTP body was exceeded."""
+
+
+class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
+ """A protocol which immediately errors upon receiving data."""
+
+ transport = None # type: Optional[ITCPTransport]
+
+ def __init__(self, deferred: defer.Deferred):
+ self.deferred = deferred
+
+ def _maybe_fail(self):
+ """
+ Report a max size exceed error and disconnect the first time this is called.
+ """
+ if not self.deferred.called:
+ self.deferred.errback(BodyExceededMaxSize())
+ # Close the connection (forcefully) since all the data will get
+ # discarded anyway.
+ assert self.transport is not None
+ self.transport.abortConnection()
+
+ def dataReceived(self, data: bytes) -> None:
+ self._maybe_fail()
+
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
+ self._maybe_fail()
+
+
+class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
+ """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -757,20 +799,27 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.max_size = max_size
def dataReceived(self, data: bytes) -> None:
+ # If the deferred was called, bail early.
+ if self.deferred.called:
+ return
+
self.stream.write(data)
self.length += len(data)
+ # The first time the maximum size is exceeded, error and cancel the
+ # connection. dataReceived might be called again if data was received
+ # in the meantime.
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
- self.deferred = defer.Deferred()
- self.transport.loseConnection()
+ self.deferred.errback(BodyExceededMaxSize())
+ # Close the connection (forcefully) since all the data will get
+ # discarded anyway.
+ assert self.transport is not None
+ self.transport.abortConnection()
+
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
+ # If the maximum size was already exceeded, there's nothing to do.
+ if self.deferred.called:
+ return
- def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -781,12 +830,15 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-def readBodyToFile(
+def read_body_with_max_size(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
"""
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+ If the maximum file size is reached, the returned Deferred will resolve to a
+ Failure with a BodyExceededMaxSize exception.
+
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
@@ -795,9 +847,16 @@ def readBodyToFile(
Returns:
A Deferred which resolves to the length of the read body.
"""
-
d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
+
+ # If the Content-Length header gives a size larger than the maximum allowed
+ # size, do not bother downloading the body.
+ if max_size is not None and response.length != UNKNOWN_LENGTH:
+ if response.length > max_size:
+ response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+ return d
+
+ response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
return d
@@ -825,6 +884,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8")
+@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
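
The body-size handling above now works in two stages: bail out immediately if the declared Content-Length already exceeds the limit, otherwise count bytes as they stream in and abort on the first chunk that crosses it. A minimal, Twisted-free sketch of that policy (the names are stand-ins, not the real protocol classes):

```python
from typing import Iterable, Optional


class BodyTooLarge(Exception):
    """Stand-in for BodyExceededMaxSize."""


def read_with_limit(
    chunks: Iterable[bytes],
    max_size: Optional[int],
    declared_length: Optional[int] = None,
) -> bytes:
    # stage 1: if the response declares a length bigger than the limit, don't
    # bother reading the body at all (cf. the UNKNOWN_LENGTH check above)
    if max_size is not None and declared_length is not None and declared_length > max_size:
        raise BodyTooLarge()

    # stage 2: stream the body, erroring the first time the limit is crossed
    body = b""
    for chunk in chunks:
        body += chunk
        if max_size is not None and len(body) >= max_size:
            raise BodyTooLarge()
    return body


print(read_with_limit([b"ab", b"cd"], max_size=10))              # b'abcd'
print(read_with_limit([b"ab"], max_size=10, declared_length=2))  # b'ab'
# read_with_limit([b"x" * 16], max_size=8) would raise BodyTooLarge
```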
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
deleted file mode 100644
index 92a5b606c8..0000000000
--- a/synapse/http/endpoint.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import re
-
-logger = logging.getLogger(__name__)
-
-
-def parse_server_name(server_name):
- """Split a server name into host/port parts.
-
- Args:
- server_name (str): server name to parse
-
- Returns:
- Tuple[str, int|None]: host/port parts.
-
- Raises:
- ValueError if the server name could not be parsed.
- """
- try:
- if server_name[-1] == "]":
- # ipv6 literal, hopefully
- return server_name, None
-
- domain_port = server_name.rsplit(":", 1)
- domain = domain_port[0]
- port = int(domain_port[1]) if domain_port[1:] else None
- return domain, port
- except Exception:
- raise ValueError("Invalid server name '%s'" % server_name)
-
-
-VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
-
-
-def parse_and_validate_server_name(server_name):
- """Split a server name into host/port parts and do some basic validation.
-
- Args:
- server_name (str): server name to parse
-
- Returns:
- Tuple[str, int|None]: host/port parts.
-
- Raises:
- ValueError if the server name could not be parsed.
- """
- host, port = parse_server_name(server_name)
-
- # these tests don't need to be bulletproof as we'll find out soon enough
- # if somebody is giving us invalid data. What we *do* need is to be sure
- # that nobody is sneaking IP literals in that look like hostnames, etc.
-
- # look for ipv6 literals
- if host[0] == "[":
- if host[-1] != "]":
- raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
- return host, port
-
- # otherwise it should only be alphanumerics.
- if not VALID_HOST_REGEX.match(host):
- raise ValueError(
- "Server name '%s' contains invalid characters" % (server_name,)
- )
-
- return host, port
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 3b756a7dc2..5935a125fd 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional
from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
def __init__(
self,
- reactor: IReactorCore,
+ reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
ip_blacklist: IPSet,
@@ -102,7 +103,6 @@ class MatrixFederationAgent:
pool=self._pool,
contextFactory=tls_client_options_factory,
),
- self._reactor,
ip_blacklist=ip_blacklist,
),
user_agent=self.user_agent,
@@ -117,7 +117,7 @@ class MatrixFederationAgent:
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
- ) -> defer.Deferred:
+ ) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""
Args:
method: HTTP method: GET/POST/etc
@@ -178,17 +178,17 @@ class MatrixFederationAgent:
# We need to make sure the host header is set to the netloc of the
# server and that a user-agent is provided.
if headers is None:
- headers = Headers()
+ request_headers = Headers()
else:
- headers = headers.copy()
+ request_headers = headers.copy()
- if not headers.hasHeader(b"host"):
- headers.addRawHeader(b"host", parsed_uri.netloc)
- if not headers.hasHeader(b"user-agent"):
- headers.addRawHeader(b"user-agent", self.user_agent)
+ if not request_headers.hasHeader(b"host"):
+ request_headers.addRawHeader(b"host", parsed_uri.netloc)
+ if not request_headers.hasHeader(b"user-agent"):
+ request_headers.addRawHeader(b"user-agent", self.user_agent)
res = yield make_deferred_yieldable(
- self._agent.request(method, uri, headers, bodyProducer)
+ self._agent.request(method, uri, request_headers, bodyProducer)
)
return res
@@ -196,8 +196,7 @@ class MatrixFederationAgent:
@implementer(IAgentEndpointFactory)
class MatrixHostnameEndpointFactory:
- """Factory for MatrixHostnameEndpoint for parsing to an Agent.
- """
+ """Factory for MatrixHostnameEndpoint for parsing to an Agent."""
def __init__(
self,
@@ -262,8 +261,7 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
- """Implements IStreamClientEndpoint interface
- """
+ """Implements IStreamClientEndpoint interface"""
return run_in_background(self._do_connect, protocol_factory)
@@ -324,12 +322,19 @@ class MatrixHostnameEndpoint:
if port or _is_ip_literal(host):
return [Server(host, port or 8448)]
+ logger.debug("Looking up SRV record for %s", host.decode(errors="replace"))
server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
if server_list:
+ logger.debug(
+ "Got %s from SRV lookup for %s",
+ ", ".join(map(str, server_list)),
+ host.decode(errors="replace"),
+ )
return server_list
# No SRV records, so we fallback to host and 8448
+ logger.debug("No SRV records for %s", host.decode(errors="replace"))
return [Server(host, 8448)]
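The header handling change above (headers renamed to request_headers) keeps the caller's Headers object unmodified. The same pattern in isolation, as a rough sketch (the helper name here is invented for illustration):

    from twisted.web.http_headers import Headers

    def _prepare_headers(headers, netloc: bytes, user_agent: bytes) -> Headers:
        # Copy rather than mutate the caller-supplied Headers instance.
        request_headers = Headers() if headers is None else headers.copy()
        if not request_headers.hasHeader(b"host"):
            request_headers.addRawHeader(b"host", netloc)
        if not request_headers.hasHeader(b"user-agent"):
            request_headers.addRawHeader(b"user-agent", user_agent)
        return request_headers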
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 5e08ef1664..ce4079f15c 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -15,17 +15,19 @@
import logging
import random
import time
+from io import BytesIO
from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
-from twisted.web.client import RedirectAgent, readBody
+from twisted.web.client import RedirectAgent
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IResponse
+from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
from synapse.util.caches.ttlcache import TTLCache
@@ -53,6 +55,9 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
# lower bound for .well-known cache period
WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
+# The maximum size (in bytes) to allow a well-known file to be.
+WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB
+
# Attempt to refetch a cached well-known N% of the TTL before it expires.
# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
# we'll start trying to refetch 1 minute before it expires.
@@ -66,8 +71,10 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
logger = logging.getLogger(__name__)
-_well_known_cache = TTLCache("well-known")
-_had_valid_well_known_cache = TTLCache("had-valid-well-known")
+_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]]
+_had_valid_well_known_cache = TTLCache(
+ "had-valid-well-known"
+) # type: TTLCache[bytes, bool]
@attr.s(slots=True, frozen=True)
@@ -76,16 +83,15 @@ class WellKnownLookupResult:
class WellKnownResolver:
- """Handles well-known lookups for matrix servers.
- """
+ """Handles well-known lookups for matrix servers."""
def __init__(
self,
reactor: IReactorTime,
agent: IAgent,
user_agent: bytes,
- well_known_cache: Optional[TTLCache] = None,
- had_well_known_cache: Optional[TTLCache] = None,
+ well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,
+ had_well_known_cache: Optional[TTLCache[bytes, bool]] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -229,6 +235,9 @@ class WellKnownResolver:
server_name: name of the server, from the requested url
retry: Whether to retry the request if it fails.
+ Raises:
+ _FetchWellKnownFailure if we fail to lookup a result
+
Returns:
Returns the response object and body. Response may be a non-200 response.
"""
@@ -250,7 +259,11 @@ class WellKnownResolver:
b"GET", uri, headers=Headers(headers)
)
)
- body = await make_deferred_yieldable(readBody(response))
+ body_stream = BytesIO()
+ await make_deferred_yieldable(
+ read_body_with_max_size(response, body_stream, WELL_KNOWN_MAX_SIZE)
+ )
+ body = body_stream.getvalue()
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
@@ -259,6 +272,15 @@ class WellKnownResolver:
except defer.CancelledError:
# Bail if we've been cancelled
raise
+ except BodyExceededMaxSize:
+ # If the well-known file was too large, do not keep attempting
+ # to download it, but consider it a temporary error.
+ logger.warning(
+ "Requested .well-known file for %s is too large > %r bytes",
+ server_name.decode("ascii"),
+ WELL_KNOWN_MAX_SIZE,
+ )
+ raise _FetchWellKnownFailure(temporary=True)
except Exception as e:
if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
logger.info("Error fetching %s: %s", uri_str, e)
@@ -302,7 +324,8 @@ def _cache_period_from_headers(
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
- for hdr in headers.getRawHeaders(b"cache-control", []):
+ cache_control_headers = headers.getRawHeaders(b"cache-control") or []
+ for hdr in cache_control_headers:
for directive in hdr.split(b","):
splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower()
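The .well-known fetch above now bounds how much of the body it will read. The core pattern, reduced to a rough sketch (the outer function name is invented; the called helpers are the ones imported in the patch):

    from io import BytesIO

    WELL_KNOWN_MAX_SIZE = 50 * 1024  # 50 KiB cap, matching the new constant

    async def fetch_well_known_body(agent, uri: bytes) -> bytes:
        response = await make_deferred_yieldable(agent.request(b"GET", uri))
        body_stream = BytesIO()
        # Raises BodyExceededMaxSize if the remote sends more than the cap.
        await make_deferred_yieldable(
            read_body_with_max_size(response, body_stream, WELL_KNOWN_MAX_SIZE)
        )
        return body_stream.getvalue()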
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c962994727..5f01ebd3d4 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -37,16 +37,19 @@ from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
+ Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
+ SynapseError,
)
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
+ BodyExceededMaxSize,
encode_query_args,
- readBodyToFile,
+ read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
@@ -56,7 +59,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -171,6 +174,16 @@ async def _handle_json_response(
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = await make_deferred_yieldable(d)
+ except ValueError as e:
+ # The JSON content was invalid.
+ logger.warning(
+ "{%s} [%s] Failed to parse JSON response - %s %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ request.uri.decode("ascii"),
+ )
+ raise RequestSendFailed(e, can_retry=False) from e
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",
@@ -224,14 +237,14 @@ class MatrixFederationHttpClient:
# addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
- )
+ ) # type: ISynapseReactor
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
- self.agent = MatrixFederationAgent(
+ federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
@@ -241,7 +254,8 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
+ federation_agent,
+ ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@@ -520,9 +534,10 @@ class MatrixFederationHttpClient:
response.code, response_phrase, body
)
- # Retry if the error is a 429 (Too Many Requests),
- # otherwise just raise a standard HttpResponseException
- if response.code == 429:
+ # Retry if the error is a 5xx or a 429 (Too Many
+ # Requests), otherwise just raise a standard
+ # `HttpResponseException`
+ if 500 <= response.code < 600 or response.code == 429:
raise RequestSendFailed(exc, can_retry=True) from exc
else:
raise exc
@@ -639,7 +654,7 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
- """ Sends the specified json data using PUT
+ """Sends the specified json data using PUT
Args:
destination: The remote server to send the HTTP request to.
@@ -727,7 +742,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
) -> Union[JsonDict, list]:
- """ Sends the specified json data using POST
+ """Sends the specified json data using POST
Args:
destination: The remote server to send the HTTP request to.
@@ -786,7 +801,11 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms,
+ self.reactor,
+ _sec_timeout,
+ request,
+ response,
+ start_ms,
)
return body
@@ -800,7 +819,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
- """ GETs some json from the given host homeserver and path
+ """GETs some json from the given host homeserver and path
Args:
destination: The remote server to send the HTTP request to.
@@ -975,9 +994,18 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
- d = readBodyToFile(response, output_stream, max_size)
+ d = read_body_with_max_size(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
+ except BodyExceededMaxSize:
+ msg = "Requested file is too large > %r bytes" % (max_size,)
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
@@ -1022,14 +1050,14 @@ def check_content_type_is_json(headers: Headers) -> None:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
"""
- c_type = headers.getRawHeaders(b"Content-Type")
- if c_type is None:
+ content_type_headers = headers.getRawHeaders(b"Content-Type")
+ if content_type_headers is None:
raise RequestSendFailed(
RuntimeError("No Content-Type header received from remote server"),
can_retry=False,
)
- c_type = c_type[0].decode("ascii") # only the first header
+ c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(
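The retry rule changed above is worth stating on its own: transient server-side failures (5xx) are now retried as well as rate-limiting (429). A minimal sketch (helper name invented):

    def _is_retryable(response_code: int) -> bool:
        # 5xx: transient server error; 429: rate-limited, safe to retry later.
        return 500 <= response_code < 600 or response_code == 429

    assert _is_retryable(502) and _is_retryable(429) and not _is_retryable(403)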
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index ee65a6668b..16ec850064 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -179,7 +179,8 @@ class ProxyAgent(_AgentBase):
should_skip_proxy = False
if self.no_proxy is not None:
should_skip_proxy = proxy_bypass_environment(
- parsed_uri.host.decode(), proxies={"no": self.no_proxy},
+ parsed_uri.host.decode(),
+ proxies={"no": self.no_proxy},
)
if (
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 7c5defec82..0ec5d941b8 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -213,8 +213,7 @@ class RequestMetrics:
self.update_metrics()
def update_metrics(self):
- """Updates the in flight metrics with values from this request.
- """
+ """Updates the in flight metrics with values from this request."""
new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 6a4e429a6c..fa89260850 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,11 +21,25 @@ import logging
import types
import urllib
from http import HTTPStatus
+from inspect import isawaitable
from io import BytesIO
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Pattern,
+ Tuple,
+ Union,
+)
import jinja2
from canonicaljson import iterencode_canonical_json
+from typing_extensions import Protocol
from zope.interface import implementer
from twisted.internet import defer, interfaces
@@ -64,14 +78,15 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
- """Sends a JSON error response to clients.
- """
+ """Sends a JSON error response to clients."""
if f.check(SynapseError):
- error_code = f.value.code
- error_dict = f.value.error_dict()
+ # mypy doesn't understand that f.check asserts the type.
+ exc = f.value # type: SynapseError # type: ignore
+ error_code = exc.code
+ error_dict = exc.error_dict()
- logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+ logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -80,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
# Only respond with an error response if we haven't already started writing,
@@ -94,12 +109,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
- request, error_code, error_dict, send_cors=True,
+ request,
+ error_code,
+ error_dict,
+ send_cors=True,
)
def return_html_error(
- f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
+ f: failure.Failure,
+ request: Request,
+ error_template: Union[str, jinja2.Template],
) -> None:
"""Sends an HTML error page corresponding to the given failure.
@@ -112,7 +132,8 @@ def return_html_error(
`{msg}` placeholders), or a jinja2 template
"""
if f.check(CodeMessageException):
- cme = f.value
+ # mypy doesn't understand that f.check asserts the type.
+ cme = f.value # type: CodeMessageException # type: ignore
code = cme.code
msg = cme.msg
@@ -126,7 +147,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
else:
code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -135,7 +156,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
if isinstance(error_template, str):
@@ -168,24 +189,39 @@ def wrap_async_request_handler(h):
return preserve_fn(wrapped_async_request_handler)
-class HttpServer:
- """ Interface for registering callbacks on a HTTP server
- """
+# Type of a callback method for processing requests.
+# It is actually called with a SynapseRequest and a kwargs dict for the params,
+# but I can't figure out how to represent that.
+ServletCallback = Callable[
+ ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
+]
+
- def register_paths(self, method, path_patterns, callback):
- """ Register a callback that gets fired if we receive a http request
+class HttpServer(Protocol):
+ """Interface for registering callbacks on a HTTP server"""
+
+ def register_paths(
+ self,
+ method: str,
+ path_patterns: Iterable[Pattern],
+ callback: ServletCallback,
+ servlet_classname: str,
+ ) -> None:
+ """Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.
If the regex contains groups these gets passed to the callback via
an unpacked tuple.
Args:
- method (str): The method to listen to.
- path_patterns (list<SRE_Pattern>): The regex used to match requests.
- callback (function): The function to fire if we receive a matched
+ method: The HTTP method to listen to.
+ path_patterns: The regex used to match requests.
+ callback: The function to fire if we receive a matched
request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex.
- This should return a tuple of (code, response).
+ This should return either a tuple of (code, response), or None.
+ servlet_classname (str): The name of the handler to be used in prometheus
+ and opentracing logs.
"""
pass
@@ -207,8 +243,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
self._extract_context = extract_context
def render(self, request):
- """ This gets called by twisted every time someone sends us a request.
- """
+ """This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@@ -248,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
- if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+ if isawaitable(raw_callback_return):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return # type: ignore
@@ -259,13 +294,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
@abc.abstractmethod
def _send_response(
- self, request: SynapseRequest, code: int, response_object: Any,
+ self,
+ request: SynapseRequest,
+ code: int,
+ response_object: Any,
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
raise NotImplementedError()
@@ -275,11 +315,17 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON.
"""
+ def __init__(self, canonical_json=False, extract_context=False):
+ super().__init__(extract_context)
+ self.canonical_json = canonical_json
+
def _send_response(
- self, request: Request, code: int, response_object: Any,
+ self,
+ request: Request,
+ code: int,
+ response_object: Any,
):
- """Implements _AsyncResource._send_response
- """
+ """Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
@@ -290,15 +336,16 @@ class DirectServeJsonResource(_AsyncResource):
)
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
- """Implements _AsyncResource._send_error_response
- """
+ """Implements _AsyncResource._send_error_response"""
return_json_error(f, request)
class JsonResource(DirectServeJsonResource):
- """ This implements the HttpServer interface and provides JSON support for
+ """This implements the HttpServer interface and provides JSON support for
Resources.
Register callbacks via register_paths()
@@ -318,9 +365,7 @@ class JsonResource(DirectServeJsonResource):
)
def __init__(self, hs, canonical_json=True, extract_context=False):
- super().__init__(extract_context)
-
- self.canonical_json = canonical_json
+ super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs = {}
self.hs = hs
@@ -352,15 +397,17 @@ class JsonResource(DirectServeJsonResource):
def _get_handler_for_request(
self, request: SynapseRequest
- ) -> Tuple[Callable, str, Dict[str, str]]:
+ ) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
Returns:
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
+ # At this point the path must be bytes.
+ request_path_bytes = request.path # type: bytes # type: ignore
+ request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests.
- request_path = request.path.decode("ascii")
request_method = request.method
if request_method == b"HEAD":
request_method = b"GET"
@@ -413,10 +460,12 @@ class DirectServeHtmlResource(_AsyncResource):
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
- self, request: SynapseRequest, code: int, response_object: Any,
+ self,
+ request: SynapseRequest,
+ code: int,
+ response_object: Any,
):
- """Implements _AsyncResource._send_response
- """
+ """Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
html_bytes = response_object
@@ -424,10 +473,11 @@ class DirectServeHtmlResource(_AsyncResource):
respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
- """Implements _AsyncResource._send_error_response
- """
+ """Implements _AsyncResource._send_error_response"""
return_html_error(f, request, self.ERROR_TEMPLATE)
@@ -504,9 +554,11 @@ class _ByteProducer:
min_chunk_size = 1024
def __init__(
- self, request: Request, iterator: Iterator[bytes],
+ self,
+ request: Request,
+ iterator: Iterator[bytes],
):
- self._request = request
+ self._request = request # type: Optional[Request]
self._iterator = iterator
self._paused = False
@@ -518,7 +570,7 @@ class _ByteProducer:
"""
Send a list of bytes as a chunk of a response.
"""
- if not data:
+ if not data or not self._request:
return
self._request.write(b"".join(data))
@@ -624,7 +676,10 @@ def respond_with_json(
def respond_with_json_bytes(
- request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
+ request: Request,
+ code: int,
+ json_bytes: bytes,
+ send_cors: bool = False,
):
"""Sends encoded JSON in response to the given request.
@@ -731,8 +786,15 @@ def set_clickjacking_protection_headers(request: Request):
request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
+def respond_with_redirect(request: Request, url: bytes) -> None:
+ """Write a 302 response to the request, if it is still alive."""
+ logger.debug("Redirect to %s", url.decode("utf-8"))
+ request.redirect(url)
+ finish_request(request)
+
+
def finish_request(request: Request):
- """ Finish writing the response to the request.
+ """Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
response was written but doesn't provide a convenient or reliable way to
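HttpServer above is now a typing Protocol, so any object with a matching register_paths() satisfies it structurally. A minimal sketch of that idea (the class and function names below are invented for illustration):

    from typing import Iterable, Pattern
    from typing_extensions import Protocol

    class HttpServerLike(Protocol):
        def register_paths(
            self, method: str, path_patterns: Iterable[Pattern], callback, servlet_classname: str
        ) -> None:
            ...

    class FakeServer:
        def register_paths(self, method, path_patterns, callback, servlet_classname):
            print("registered", method, servlet_classname)

    def install(server: HttpServerLike) -> None:
        server.register_paths("GET", [], lambda request: (200, {}), "demo")

    install(FakeServer())  # accepted: the shape matches, no explicit inheritance needed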
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 9bfe151b5f..839d58d0d4 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -309,7 +309,7 @@ def assert_params_in_dict(body, required):
class RestServlet:
- """ A Synapse REST Servlet.
+ """A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc3f..47754aff43 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,13 +14,17 @@
import contextlib
import logging
import time
-from typing import Optional, Union
+from typing import Optional, Type, Union
+import attr
+from zope.interface import implementer
+
+from twisted.internet.interfaces import IAddress
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
-from synapse.http import redact_uri
+from synapse.http import get_request_user_agent, redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.types import Requester
@@ -53,7 +57,7 @@ class SynapseRequest(Request):
def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
- self.site = channel.site
+ self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -92,44 +96,43 @@ class SynapseRequest(Request):
def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq)
- def get_redacted_uri(self):
- uri = self.uri
+ def get_redacted_uri(self) -> str:
+ """Gets the redacted URI associated with the request (or placeholder if the URI
+ has not yet been received).
+
+ Note: This is necessary as the placeholder value in twisted is str
+ rather than bytes, so we need to sanitise `self.uri`.
+
+ Returns:
+ The redacted URI as a string.
+ """
+ uri = self.uri # type: Union[bytes, str]
if isinstance(uri, bytes):
- uri = self.uri.decode("ascii", errors="replace")
+ uri = uri.decode("ascii", errors="replace")
return redact_uri(uri)
- def get_method(self):
- """Gets the method associated with the request (or placeholder if not
- method has yet been received).
+ def get_method(self) -> str:
+ """Gets the method associated with the request (or placeholder if method
+ has not yet been received).
Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.method`.
Returns:
- str
+ The request method as a string.
"""
- method = self.method
+ method = self.method # type: Union[bytes, str]
if isinstance(method, bytes):
- method = self.method.decode("ascii")
+ return self.method.decode("ascii")
return method
- def get_user_agent(self, default: str) -> str:
- """Return the last User-Agent header, or the given default.
- """
- user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
- if user_agent is None:
- return default
-
- return user_agent.decode("ascii", "replace")
-
def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
# create a LogContext for this request
request_id = self.get_request_id()
- logcontext = self.logcontext = LoggingContext(request_id)
- logcontext.request = request_id
+ self.logcontext = LoggingContext(request_id, request=request_id)
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
@@ -259,8 +262,7 @@ class SynapseRequest(Request):
)
def _finished_processing(self):
- """Log the completion of this request and update the metrics
- """
+ """Log the completion of this request and update the metrics"""
assert self.logcontext is not None
usage = self.logcontext.get_resource_usage()
@@ -286,19 +288,15 @@ class SynapseRequest(Request):
# authenticated (e.g. and admin is puppetting a user) then we log both.
if self.requester.user.to_string() != authenticated_entity:
authenticated_entity = "{},{}".format(
- authenticated_entity, self.requester.user.to_string(),
+ authenticated_entity,
+ self.requester.user.to_string(),
)
elif self.requester is not None:
# This shouldn't happen, but we log it so we don't lose information
# and can see that we're doing something wrong.
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
- # ...or could be raw utf-8 bytes in the User-Agent header.
- # N.B. if you don't do this, the logger explodes cryptically
- # with maximum recursion trying to log errors about
- # the charset problem.
- # c.f. https://github.com/matrix-org/synapse/issues/3471
- user_agent = self.get_user_agent("-")
+ user_agent = get_request_user_agent(self, "-")
code = str(self.code)
if not self.finished:
@@ -337,8 +335,7 @@ class SynapseRequest(Request):
logger.warning("Failed to stop metrics: %r", e)
def _should_log_request(self) -> bool:
- """Whether we should log at INFO that we processed the request.
- """
+ """Whether we should log at INFO that we processed the request."""
if self.path == b"/health":
return False
@@ -349,26 +346,77 @@ class SynapseRequest(Request):
class XForwardedForRequest(SynapseRequest):
- def __init__(self, *args, **kw):
- SynapseRequest.__init__(self, *args, **kw)
+ """Request object which honours proxy headers
+ Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
+ information from request headers.
"""
- Add a layer on top of another request that only uses the value of an
- X-Forwarded-For header as the result of C{getClientIP}.
- """
- def getClientIP(self):
+ # the client IP and ssl flag, as extracted from the headers.
+ _forwarded_for = None # type: Optional[_XForwardedForAddress]
+ _forwarded_https = False # type: bool
+
+ def requestReceived(self, command, path, version):
+ # this method is called by the Channel once the full request has been
+ # received, to dispatch the request to a resource.
+ # We can use it to set the IP address and protocol according to the
+ # headers.
+ self._process_forwarded_headers()
+ return super().requestReceived(command, path, version)
+
+ def _process_forwarded_headers(self):
+ headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
+ if not headers:
+ return
+
+ # for now, we just use the first x-forwarded-for header. Really, we ought
+ # to start from the client IP address, and check whether it is trusted; if it
+ # is, work backwards through the headers until we find an untrusted address.
+ # see https://github.com/matrix-org/synapse/issues/9471
+ self._forwarded_for = _XForwardedForAddress(
+ headers[0].split(b",")[0].strip().decode("ascii")
+ )
+
+ # if we got an x-forwarded-for header, also look for an x-forwarded-proto header
+ header = self.getHeader(b"x-forwarded-proto")
+ if header is not None:
+ self._forwarded_https = header.lower() == b"https"
+ else:
+ # this is done largely for backwards-compatibility so that people that
+ # haven't set an x-forwarded-proto header don't get a redirect loop.
+ logger.warning(
+ "forwarded request lacks an x-forwarded-proto header: assuming https"
+ )
+ self._forwarded_https = True
+
+ def isSecure(self):
+ if self._forwarded_https:
+ return True
+ return super().isSecure()
+
+ def getClientIP(self) -> str:
"""
- @return: The client address (the first address) in the value of the
- I{X-Forwarded-For header}. If the header is not present, return
- C{b"-"}.
+ Return the IP address of the client who submitted this request.
+
+ This method is deprecated. Use getClientAddress() instead.
"""
- return (
- self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
- .split(b",")[0]
- .strip()
- .decode("ascii")
- )
+ if self._forwarded_for is not None:
+ return self._forwarded_for.host
+ return super().getClientIP()
+
+ def getClientAddress(self) -> IAddress:
+ """
+ Return the address of the client who submitted this request.
+ """
+ if self._forwarded_for is not None:
+ return self._forwarded_for
+ return super().getClientAddress()
+
+
+@implementer(IAddress)
+@attr.s(frozen=True, slots=True)
+class _XForwardedForAddress:
+ host = attr.ib(type=str)
class SynapseSite(Site):
@@ -393,7 +441,9 @@ class SynapseSite(Site):
assert config.http_options is not None
proxied = config.http_options.x_forwarded
- self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+ self.requestFactory = (
+ XForwardedForRequest if proxied else SynapseRequest
+ ) # type: Type[Request]
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
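The forwarded-header handling added above can be summarised as a small pure function (illustrative only; the real logic lives in XForwardedForRequest):

    def parse_forwarded_headers(x_forwarded_for, x_forwarded_proto):
        """Return (client_ip, is_https) from raw header values, mirroring the patch."""
        if not x_forwarded_for:
            return None, False
        # Only the first hop of the first header is used for now.
        client_ip = x_forwarded_for[0].split(b",")[0].strip().decode("ascii")
        if x_forwarded_proto is not None:
            is_https = x_forwarded_proto.lower() == b"https"
        else:
            is_https = True  # backwards-compatible default when the proxy omits it
        return client_ip, is_https

    assert parse_forwarded_headers([b"10.1.2.3, 10.0.0.1"], b"https") == ("10.1.2.3", True)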
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index fb937b3f28..643492ceaf 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
-from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string.
"""
- transport = attr.ib(type=ITransport)
+ # This is essentially ITCPTransport, but that is missing certain fields
+ # (connected and registerProducer) which are part of the implementation.
+ transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@@ -121,7 +124,9 @@ class RemoteHandler(logging.Handler):
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
- endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
+ endpoint = TCP4ClientEndpoint(
+ _reactor, self.host, self.port
+ ) # type: IStreamClientEndpoint
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else:
@@ -147,8 +152,6 @@ class RemoteHandler(logging.Handler):
if self._connection_waiter:
return
- self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
@@ -161,9 +164,13 @@ class RemoteHandler(logging.Handler):
self._connect()
def writer(result: Protocol) -> None:
+ # Force recognising transport as a Connection and not the more
+ # generic ITransport.
+ transport = result.transport # type: Connection # type: ignore
+
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
- if self._producer and result.transport is self._producer.transport:
+ if self._producer and transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@@ -174,13 +181,17 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it.
self._producer = LogProducer(
- buffer=self._buffer, transport=result.transport, format=self.format,
+ buffer=self._buffer,
+ transport=transport,
+ format=self.format,
)
- result.transport.registerProducer(self._producer, True)
+ transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
- self._connection_waiter.addCallbacks(writer, fail)
+ deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
+ deferred.addCallbacks(writer, fail)
+ self._connection_waiter = deferred
def _handle_pressure(self) -> None:
"""
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 14d9c104c2..3e054f615c 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -60,7 +60,10 @@ def parse_drain_configs(
)
# Either use the default formatter or the tersejson one.
- if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
+ if logging_type in (
+ DrainType.CONSOLE_JSON,
+ DrainType.FILE_JSON,
+ ):
formatter = "json" # type: Optional[str]
elif logging_type in (
DrainType.CONSOLE_JSON_TERSE,
@@ -131,7 +134,9 @@ def parse_drain_configs(
)
-def setup_structured_logging(log_config: dict,) -> dict:
+def setup_structured_logging(
+ log_config: dict,
+) -> dict:
"""
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
to one compatible with the new standard library handlers.
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..03cf3c2b8e 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
def copy_to(self, record):
pass
- def copy_to_twisted_log_entry(self, record):
- record["request"] = None
- record["scope"] = None
-
def start(self, rusage: "Optional[resource._RUsage]"):
pass
@@ -256,7 +252,12 @@ class LoggingContext:
"scope",
]
- def __init__(self, name=None, parent_context=None, request=None) -> None:
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ parent_context: "Optional[LoggingContext]" = None,
+ request: Optional[str] = None,
+ ) -> None:
self.previous_context = current_context()
self.name = name
@@ -337,7 +338,10 @@ class LoggingContext:
if self.previous_context != old_context:
logcontext_error(
"Expected previous context %r, found %r"
- % (self.previous_context, old_context,)
+ % (
+ self.previous_context,
+ old_context,
+ )
)
return self
@@ -372,13 +376,6 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
- def copy_to_twisted_log_entry(self, record) -> None:
- """
- Copy logging fields from this context to a Twisted log record.
- """
- record["request"] = self.request
- record["scope"] = self.scope
-
def start(self, rusage: "Optional[resource._RUsage]") -> None:
"""
Record that this logcontext is currently running.
@@ -542,28 +539,25 @@ class LoggingContext:
class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each
record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
"""
- def __init__(self, **defaults) -> None:
- self.defaults = defaults
+ def __init__(self, request: str = ""):
+ self._default_request = request
- def filter(self, record) -> Literal[True]:
+ def filter(self, record: logging.LogRecord) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
Returns:
True to include the record in the log output.
"""
context = current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
+ record.request = self._default_request # type: ignore
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
- context.copy_to(record)
+ # Logging is interested in the request.
+ record.request = context.request # type: ignore
return True
@@ -571,7 +565,7 @@ class LoggingContextFilter(logging.Filter):
class PreserveLoggingContext:
"""Context manager which replaces the logging context
- The previous logging context is restored on exit."""
+ The previous logging context is restored on exit."""
__slots__ = ["_old_context", "_new_context"]
@@ -594,7 +588,10 @@ class PreserveLoggingContext:
else:
logcontext_error(
"Expected logging context %s but found %s"
- % (self._new_context, context,)
+ % (
+ self._new_context,
+ context,
+ )
)
@@ -630,9 +627,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
return current
-def nested_logging_context(
- suffix: str, parent_context: Optional[LoggingContext] = None
-) -> LoggingContext:
+def nested_logging_context(suffix: str) -> LoggingContext:
"""Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's
@@ -646,20 +641,23 @@ def nested_logging_context(
# ... do stuff
Args:
- suffix (str): suffix to add to the parent context's 'request'.
- parent_context (LoggingContext|None): parent context. Will use the current context
- if None.
+ suffix: suffix to add to the parent context's 'request'.
Returns:
LoggingContext: new logging context.
"""
- if parent_context is not None:
- context = parent_context # type: LoggingContextOrSentinel
+ curr_context = current_context()
+ if not curr_context:
+ logger.warning(
+ "Starting nested logging context from sentinel context: metrics will be lost"
+ )
+ parent_context = None
+ prefix = ""
else:
- context = current_context()
- return LoggingContext(
- parent_context=context, request=str(context.request) + "-" + suffix
- )
+ assert isinstance(curr_context, LoggingContext)
+ parent_context = curr_context
+ prefix = str(parent_context.request)
+ return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
def preserve_fn(f):
@@ -671,7 +669,7 @@ def preserve_fn(f):
return g
-def run_in_background(f, *args, **kwargs):
+def run_in_background(f, *args, **kwargs) -> defer.Deferred:
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -691,7 +689,7 @@ def run_in_background(f, *args, **kwargs):
current = current_context()
try:
res = f(*args, **kwargs)
- except: # noqa: E722
+ except Exception:
# the assumption here is that the caller doesn't want to be disturbed
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()
@@ -699,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
if isinstance(res, types.CoroutineType):
res = defer.ensureDeferred(res)
+ # At this point we should have a Deferred, if not then f was a synchronous
+ # function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
- return res
+ return defer.succeed(res)
if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
@@ -836,10 +836,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
- logcontext = current_context()
+ curr_context = current_context()
+ if not curr_context:
+ logger.warning(
+ "Calling defer_to_threadpool from sentinel context: metrics will be lost"
+ )
+ parent_context = None
+ else:
+ assert isinstance(curr_context, LoggingContext)
+ parent_context = curr_context
def g():
- with LoggingContext(parent_context=logcontext):
+ with LoggingContext(parent_context=parent_context):
return f(*args, **kwargs)
return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
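With the simplified LoggingContextFilter above, wiring the request id into standard library logging looks roughly like this (the handler setup is an assumption for illustration; only the filter class and the %(request)s field come from the patch):

    import logging

    from synapse.logging.context import LoggingContextFilter

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s - %(request)s - %(message)s"))
    handler.addFilter(LoggingContextFilter(request="main"))
    logging.getLogger("synapse").addHandler(handler)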
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index ab586c318c..aa146e8bb8 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,7 @@ import inspect
import logging
import re
from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Dict, Optional, Pattern, Type
import attr
@@ -238,8 +238,7 @@ try:
@attr.s(slots=True, frozen=True)
class _WrappedRustReporter:
- """Wrap the reporter to ensure `report_span` never throws.
- """
+ """Wrap the reporter to ensure `report_span` never throws."""
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
@@ -263,7 +262,7 @@ logger = logging.getLogger(__name__)
# Block everything by default
# A regex which matches the server_names to expose traces for.
# None means 'block everything'.
-_homeserver_whitelist = None
+_homeserver_whitelist = None # type: Optional[Pattern[str]]
# Util methods
@@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs):
def init_tracer(hs: "HomeServer"):
- """Set the whitelists and initialise the JaegerClient tracer
- """
+ """Set the whitelists and initialise the JaegerClient tracer"""
global opentracing
if not hs.config.opentracer_enabled:
# We don't have a tracer
@@ -384,7 +382,7 @@ def whitelisted_homeserver(destination):
Args:
destination (str)
- """
+ """
if _homeserver_whitelist:
return _homeserver_whitelist.match(destination)
@@ -791,7 +789,7 @@ def tag_args(func):
@wraps(func)
def _tag_args_inner(*args, **kwargs):
- argspec = inspect.getargspec(func)
+ argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i])
set_tag("args", args[len(argspec.args) :])
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index becf66dd86..fd3543ab04 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args):
def log_function(f):
- """ Function decorator that logs every call to that function.
- """
+ """Function decorator that logs every call to that function."""
func_name = f.__name__
@wraps(f)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index cbf0dbb871..3b499efc07 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -155,8 +155,7 @@ class InFlightGauge:
self._registrations.setdefault(key, set()).add(callback)
def unregister(self, key, callback):
- """Registers that we've exited a block with labels `key`.
- """
+ """Registers that we've exited a block with labels `key`."""
with self._lock:
self._registrations.setdefault(key, set()).discard(callback)
@@ -402,7 +401,9 @@ class PyPyGCStats:
# Total time spent in GC: 0.073 # s.total_gc_time
pypy_gc_time = CounterMetricFamily(
- "pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[],
+ "pypy_gc_time_seconds_total",
+ "Total time spent in PyPy GC",
+ labels=[],
)
pypy_gc_time.add_metric([], s.total_gc_time / 1000)
yield pypy_gc_time
@@ -526,7 +527,7 @@ class ReactorLastSeenMetric:
REGISTRY.register(ReactorLastSeenMetric())
-def runUntilCurrentTimer(func):
+def runUntilCurrentTimer(reactor, func):
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
@@ -589,13 +590,14 @@ def runUntilCurrentTimer(func):
try:
# Ensure the reactor has all the attributes we expect
- reactor.runUntilCurrent
- reactor._newTimedCalls
- reactor.threadCallQueue
+ reactor.seconds # type: ignore
+ reactor.runUntilCurrent # type: ignore
+ reactor._newTimedCalls # type: ignore
+ reactor.threadCallQueue # type: ignore
# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
- reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
+ reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore
# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC,
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index 734271e765..71320a1402 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -216,7 +216,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
@classmethod
def factory(cls, registry):
"""Returns a dynamic MetricsHandler class tied
- to the passed registry.
+ to the passed registry.
"""
# This implementation relies on MetricsHandler.registry
# (defined above and defaulted to REGISTRY).
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 658f6ecd72..b56986d8e7 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import threading
from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@@ -199,22 +199,17 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
- with BackgroundProcessLoggingContext(desc) as context:
- context.request = "%s-%i" % (desc, count)
+ with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
try:
ctx = noop_context_manager()
if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request})
with ctx:
- result = func(*args, **kwargs)
-
- if inspect.isawaitable(result):
- result = await result
-
- return result
+ return await maybe_awaitable(func(*args, **kwargs))
except Exception:
logger.exception(
- "Background process '%s' threw an exception", desc,
+ "Background process '%s' threw an exception",
+ desc,
)
finally:
_background_process_in_flight_count.labels(desc).dec()
@@ -249,14 +244,13 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"]
- def __init__(self, name: str):
- super().__init__(name)
+ def __init__(self, name: str, request: Optional[str] = None):
+ super().__init__(name, request=request)
self._proc = _BackgroundProcess(name, self)
def start(self, rusage: "Optional[resource._RUsage]"):
- """Log context has started running (again).
- """
+ """Log context has started running (again)."""
super().start(rusage)
@@ -267,8 +261,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
_background_processes_active_since_last_scrape.add(self._proc)
def __exit__(self, type, value, traceback) -> None:
- """Log context has finished.
- """
+ """Log context has finished."""
super().__exit__(type, value, traceback)
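maybe_awaitable, imported above to replace the inspect.isawaitable branch, plausibly has this shape (a sketch; the real helper lives in synapse.util.async_helpers):

    from inspect import isawaitable

    async def maybe_awaitable(value):
        """Await `value` if it is awaitable, otherwise return it unchanged."""
        if isawaitable(value):
            return await value
        return value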
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 72ab5750cc..781e02fbbb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
from twisted.internet import defer
@@ -203,11 +203,26 @@ class ModuleApi:
)
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
+ auth_provider_id: str = "",
) -> str:
- """Generate a login token suitable for m.login.token authentication"""
+ """Generate a login token suitable for m.login.token authentication
+
+ Args:
+ user_id: gives the ID of the user that the token is for
+
+ duration_in_ms: the time that the token will be valid for
+
+ auth_provider_id: the ID of the SSO IdP that the user used to authenticate
+ to get this token, if any. This is encoded in the token so that
+ /login can report stats on number of successful logins by IdP.
+ """
return self._hs.get_macaroon_generator().generate_short_term_login_token(
- user_id, duration_in_ms
+ user_id,
+ auth_provider_id,
+ duration_in_ms,
)
@defer.inlineCallbacks
@@ -275,11 +290,19 @@ class ModuleApi:
redirect them directly if whitelisted).
"""
self._auth_handler._complete_sso_login(
- registered_user_id, request, client_redirect_url,
+ registered_user_id,
+ "<unknown>",
+ request,
+ client_redirect_url,
)
async def complete_sso_login_async(
- self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
+ self,
+ registered_user_id: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ new_user: bool = False,
+ auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
@@ -291,15 +314,23 @@ class ModuleApi:
request: The request to respond to.
client_redirect_url: The URL to which to offer to redirect the user (or to
redirect them directly if whitelisted).
+ new_user: set to true to use wording for the consent appropriate to a user
+ who has just registered.
+ auth_provider_id: the ID of the SSO IdP which was used to log in. This
+ is used to track counts of successful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url,
+ registered_user_id,
+ auth_provider_id,
+ request,
+ client_redirect_url,
+ new_user=new_user,
)
@defer.inlineCallbacks
def get_state_events_in_room(
self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
- ) -> defer.Deferred:
+ ) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""Gets current state events for the given room.
(This is exposed for compatibility with the old SpamCheckerApi. We should
@@ -346,7 +377,10 @@ class ModuleApi:
event,
_,
) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
- requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
+ requester,
+ event_dict,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
return event
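Example use of the extended module API call above (the module_api variable and the literal values are illustrative):

    token = module_api.generate_short_term_login_token(
        "@alice:example.com",
        duration_in_ms=2 * 60 * 1000,
        auth_provider_id="oidc-example",  # recorded so /login can attribute the login
    )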
diff --git a/synapse/notifier.py b/synapse/notifier.py
index a17352ef46..1374aae490 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -34,7 +34,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.server
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -75,7 +75,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int:
class _NotificationListener:
- """ This represents a single client connection to the events stream.
+ """This represents a single client connection to the events stream.
The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred.
"""
@@ -119,7 +119,10 @@ class _NotifierUserStream:
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(
- self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
+ self,
+ stream_key: str,
+ stream_id: Union[int, RoomStreamToken],
+ time_now_ms: int,
):
"""Notify any listeners for this user of a new event from an
event source.
@@ -140,7 +143,7 @@ class _NotifierUserStream:
noify_deferred.callback(self.current_token)
def remove(self, notifier: "Notifier"):
- """ Remove this listener from all the indexes in the Notifier
+ """Remove this listener from all the indexes in the Notifier
it knows about.
"""
@@ -186,7 +189,7 @@ class _PendingRoomEventEntry:
class Notifier:
- """ This class is responsible for notifying any listeners when there are
+ """This class is responsible for notifying any listeners when there are
new events available for it.
Primarily used from the /events stream.
@@ -265,8 +268,7 @@ class Notifier:
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
- """Unwraps event and calls `on_new_room_event_args`.
- """
+ """Unwraps event and calls `on_new_room_event_args`."""
self.on_new_room_event_args(
event_pos=event_pos,
room_id=event.room_id,
@@ -341,7 +343,10 @@ class Notifier:
if users or rooms:
self.on_new_event(
- "room_key", max_room_stream_token, users=users, rooms=rooms,
+ "room_key",
+ max_room_stream_token,
+ users=users,
+ rooms=rooms,
)
self._on_updated_room_token(max_room_stream_token)
@@ -392,35 +397,36 @@ class Notifier:
users: Collection[Union[str, UserID]] = [],
rooms: Collection[str] = [],
):
- """ Used to inform listeners that something has happened event wise.
+ """Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms.
"""
- with PreserveLoggingContext():
- with Measure(self.clock, "on_new_event"):
- user_streams = set()
+ with Measure(self.clock, "on_new_event"):
+ user_streams = set()
- for user in users:
- user_stream = self.user_to_user_stream.get(str(user))
- if user_stream is not None:
- user_streams.add(user_stream)
+ for user in users:
+ user_stream = self.user_to_user_stream.get(str(user))
+ if user_stream is not None:
+ user_streams.add(user_stream)
- for room in rooms:
- user_streams |= self.room_to_user_streams.get(room, set())
+ for room in rooms:
+ user_streams |= self.room_to_user_streams.get(room, set())
- time_now_ms = self.clock.time_msec()
- for user_stream in user_streams:
- try:
- user_stream.notify(stream_key, new_token, time_now_ms)
- except Exception:
- logger.exception("Failed to notify listener")
+ time_now_ms = self.clock.time_msec()
+ for user_stream in user_streams:
+ try:
+ user_stream.notify(stream_key, new_token, time_now_ms)
+ except Exception:
+ logger.exception("Failed to notify listener")
- self.notify_replication()
+ self.notify_replication()
- # Notify appservices
- self._notify_app_services_ephemeral(
- stream_key, new_token, users,
- )
+ # Notify appservices
+ self._notify_app_services_ephemeral(
+ stream_key,
+ new_token,
+ users,
+ )
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened
@@ -503,7 +509,7 @@ class Notifier:
is_guest: bool = False,
explicit_room_id: str = None,
) -> EventStreamResult:
- """ For the given user and rooms, return any new events for them. If
+ """For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
@@ -611,7 +617,9 @@ class Notifier:
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
- return state.content["history_visibility"] == "world_readable"
+ return (
+ state.content["history_visibility"] == HistoryVisibility.WORLD_READABLE
+ )
else:
return False
@@ -650,8 +658,7 @@ class Notifier:
cb()
def notify_remote_server_up(self, server: str):
- """Notify any replication that a remote server has come back up
- """
+ """Notify any replication that a remote server has come back up"""
# We call federation_sender directly rather than registering as a
# callback as a) we already have a reference to it and b) it introduces
# circular dependencies.
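# --- Editor's note (not part of the diff): a minimal, self-contained sketch of the
# check the notifier.py hunk above rewrites, i.e. comparing a room's
# m.room.history_visibility content against HistoryVisibility.WORLD_READABLE
# instead of the bare "world_readable" string. The constant value and the state
# dict below are assumptions for illustration only.

WORLD_READABLE = "world_readable"  # stands in for HistoryVisibility.WORLD_READABLE

def _is_world_readable(history_visibility_content: dict) -> bool:
    """Return True if the history-visibility state marks the room world readable."""
    return history_visibility_content.get("history_visibility") == WORLD_READABLE

assert _is_world_readable({"history_visibility": "world_readable"})
assert not _is_world_readable({"history_visibility": "shared"})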
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5a437f9810..9fc3da49a2 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -13,7 +13,111 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+import attr
+
+from synapse.types import JsonDict, RoomStreamToken
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+@attr.s(slots=True)
+class PusherConfig:
+ """Parameters necessary to configure a pusher."""
+
+ id = attr.ib(type=Optional[str])
+ user_name = attr.ib(type=str)
+ access_token = attr.ib(type=Optional[int])
+ profile_tag = attr.ib(type=str)
+ kind = attr.ib(type=str)
+ app_id = attr.ib(type=str)
+ app_display_name = attr.ib(type=str)
+ device_display_name = attr.ib(type=str)
+ pushkey = attr.ib(type=str)
+ ts = attr.ib(type=int)
+ lang = attr.ib(type=Optional[str])
+ data = attr.ib(type=Optional[JsonDict])
+ last_stream_ordering = attr.ib(type=int)
+ last_success = attr.ib(type=Optional[int])
+ failing_since = attr.ib(type=Optional[int])
+
+ def as_dict(self) -> Dict[str, Any]:
+ """Information that can be retrieved about a pusher after creation."""
+ return {
+ "app_display_name": self.app_display_name,
+ "app_id": self.app_id,
+ "data": self.data,
+ "device_display_name": self.device_display_name,
+ "kind": self.kind,
+ "lang": self.lang,
+ "profile_tag": self.profile_tag,
+ "pushkey": self.pushkey,
+ }
+
+
+@attr.s(slots=True)
+class ThrottleParams:
+ """Parameters for controlling the rate of sending pushes via email."""
+
+ last_sent_ts = attr.ib(type=int)
+ throttle_ms = attr.ib(type=int)
+
+
+class Pusher(metaclass=abc.ABCMeta):
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+ self.hs = hs
+ self.store = self.hs.get_datastore()
+ self.clock = self.hs.get_clock()
+
+ self.pusher_id = pusher_config.id
+ self.user_id = pusher_config.user_name
+ self.app_id = pusher_config.app_id
+ self.pushkey = pusher_config.pushkey
+
+ self.last_stream_ordering = pusher_config.last_stream_ordering
+
+ # This is the highest stream ordering we know it's safe to process.
+ # When new events arrive, we'll be given a window of new events: we
+ # should honour this rather than just looking for anything higher
+ # because of potential out-of-order event serialisation.
+ self.max_stream_ordering = self.store.get_room_max_stream_ordering()
+
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ max_stream_ordering = max_token.stream
+
+ self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+ self._start_processing()
+
+ @abc.abstractmethod
+ def _start_processing(self):
+ """Start processing push notifications."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+    def on_started(self, should_check_for_notifs: bool) -> None:
+ """Called when this pusher has been started.
+
+ Args:
+ should_check_for_notifs: Whether we should immediately
+ check for push to send. Set to False only if it's known there
+ is nothing to send
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_stop(self) -> None:
+ raise NotImplementedError()
+
class PusherConfigException(Exception):
- def __init__(self, msg):
- super().__init__(msg)
+ """An error occurred when creating a pusher."""
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index fabc9ba126..38a47a600f 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -14,19 +14,22 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure
-from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ActionGenerator:
- def __init__(self, hs):
- self.hs = hs
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
- self.store = hs.get_datastore()
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
@@ -35,6 +38,8 @@ class ActionGenerator:
# event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users).
- async def handle_push_actions_for_event(self, event, context):
+ async def handle_push_actions_for_event(
+ self, event: EventBase, context: EventContext
+ ) -> None:
with Measure(self.clock, "action_for_event_by_user"):
await self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index f5c6c977c7..4d284de133 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -15,16 +15,19 @@
# limitations under the License.
import copy
+from typing import Any, Dict, List
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
-def list_with_base_rules(rawrules, use_new_defaults=False):
+def list_with_base_rules(
+ rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
+) -> List[Dict[str, Any]]:
"""Combine the list of rules set by the user with the default push rules
Args:
- rawrules(list): The rules the user has modified or set.
- use_new_defaults(bool): Whether to use the new experimental default rules when
+ rawrules: The rules the user has modified or set.
+ use_new_defaults: Whether to use the new experimental default rules when
appending or prepending default rules.
Returns:
@@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
return ruleslist
-def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
+def make_base_append_rules(
+ kind: str,
+ modified_base_rules: Dict[str, Dict[str, Any]],
+ use_new_defaults: bool = False,
+) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
@@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
+ assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"])
if modified:
r["actions"] = modified["actions"]
@@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
return rules
-def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
+def make_base_prepend_rules(
+ kind: str,
+ modified_base_rules: Dict[str, Dict[str, Any]],
+ use_new_defaults: bool = False,
+) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
@@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
+ assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"])
if modified:
r["actions"] = modified["actions"]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 82a72dc34f..1897f59153 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Counter
@@ -25,16 +26,16 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import register_cache
+from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-logger = logging.getLogger(__name__)
-
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
-rules_by_room = {}
+logger = logging.getLogger(__name__)
push_rules_invalidation_counter = Counter(
@@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
room at once.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- async def _get_rules_for_event(self, event, context):
+ async def _get_rules_for_event(
+ self, event: EventBase, context: EventContext
+ ) -> Dict[str, List[Dict[str, Any]]]:
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.
@@ -140,12 +143,8 @@ class BulkPushRuleEvaluator:
return rules_by_user
@lru_cache()
- def _get_rules_for_room(self, room_id):
- """Get the current RulesForRoom object for the given room id
-
- Returns:
- RulesForRoom
- """
+ def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
+ """Get the current RulesForRoom object for the given room id"""
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache)
@@ -156,20 +155,21 @@ class BulkPushRuleEvaluator:
self.room_push_rule_cache_metrics,
)
- async def _get_power_levels_and_sender_level(self, event, context):
+ async def _get_power_levels_and_sender_level(
+ self, event: EventBase, context: EventContext
+ ) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
- pl_event = await self.store.get_event(pl_event_id)
- auth_events = {POWER_KEY: pl_event}
+ auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_dict = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -177,7 +177,9 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def action_for_event_by_user(self, event, context) -> None:
+ async def action_for_event_by_user(
+ self, event: EventBase, context: EventContext
+ ) -> None:
"""Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the
event_push_actions_staging table.
@@ -185,7 +187,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
- actions_by_user = {}
+ actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
room_members = await self.store.get_joined_users_from_context(event, context)
@@ -198,16 +200,20 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels
)
- condition_cache = {}
+ condition_cache = {} # type: Dict[str, bool]
+
+ # If the event is not a state event check if any users ignore the sender.
+ if not event.is_state():
+ ignorers = await self.store.ignored_by(event.sender)
+ else:
+ ignorers = set()
for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
- if not event.is_state():
- is_ignored = await self.store.is_ignored_by(event.sender, uid)
- if is_ignored:
- continue
+ if uid in ignorers:
+ continue
display_name = None
profile_info = room_members.get(uid)
@@ -245,11 +251,19 @@ class BulkPushRuleEvaluator:
# notified for this event. (This will then get handled when we persist
# the event)
await self.store.add_push_actions_to_staging(
- event.event_id, actions_by_user, count_as_unread,
+ event.event_id,
+ actions_by_user,
+ count_as_unread,
)
-def _condition_checker(evaluator, conditions, uid, display_name, cache):
+def _condition_checker(
+ evaluator: PushRuleEvaluatorForEvent,
+ conditions: List[dict],
+ uid: str,
+ display_name: str,
+ cache: Dict[str, bool],
+) -> bool:
for cond in conditions:
_id = cond.get("_id", None)
if _id:
@@ -277,15 +291,19 @@ class RulesForRoom:
"""
def __init__(
- self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+ self,
+ hs: "HomeServer",
+ room_id: str,
+ rules_for_room_cache: LruCache,
+ room_push_rule_cache_metrics: CacheMetric,
):
"""
Args:
- hs (HomeServer)
- room_id (str)
+ hs: The HomeServer object.
+ room_id: The room ID.
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
- room_push_rule_cache_metrics (CacheMetric)
+ room_push_rule_cache_metrics: The metrics object
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
@@ -294,8 +312,10 @@ class RulesForRoom:
self.linearizer = Linearizer(name="rules_for_room")
- self.member_map = {} # event_id -> (user_id, state)
- self.rules_by_user = {} # user_id -> rules
+ # event_id -> (user_id, state)
+ self.member_map = {} # type: Dict[str, Tuple[str, str]]
+ # user_id -> rules
+ self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
@@ -315,7 +335,7 @@ class RulesForRoom:
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
- self.uninteresting_user_set = set()
+ self.uninteresting_user_set = set() # type: Set[str]
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
@@ -325,7 +345,9 @@ class RulesForRoom:
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
- async def get_rules(self, event, context):
+ async def get_rules(
+ self, event: EventBase, context: EventContext
+ ) -> Dict[str, List[Dict[str, dict]]]:
"""Given an event context return the rules for all users who are
currently in the room.
"""
@@ -356,6 +378,8 @@ class RulesForRoom:
else:
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
+ # Ensure the state IDs exist.
+ assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids))
@@ -420,18 +444,23 @@ class RulesForRoom:
return ret_rules_by_user
async def _update_rules_with_member_event_ids(
- self, ret_rules_by_user, member_event_ids, state_group, event
- ):
+ self,
+ ret_rules_by_user: Dict[str, list],
+ member_event_ids: Dict[str, str],
+ state_group: Optional[int],
+ event: EventBase,
+ ) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
- ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
+ ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules.
- member_event_ids (dict): Dict of user id to event id for membership events
+ member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
+ event: The event we are currently computing push rules for.
"""
sequence = self.sequence
@@ -449,19 +478,19 @@ class RulesForRoom:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
- user_ids = {
+ joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
- logger.debug("Joined: %r", user_ids)
+ logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
- user_ids = list(filter(self.is_mine_id, user_ids))
+ user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
@@ -473,7 +502,7 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group)
- def invalidate_all(self):
+ def invalidate_all(self) -> None:
# Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use
@@ -485,7 +514,7 @@ class RulesForRoom:
self.rules_by_user = {}
push_rules_invalidation_counter.inc()
- def update_cache(self, sequence, members, rules_by_user, state_group):
+ def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
if sequence == self.sequence:
self.member_map.update(members)
self.rules_by_user = rules_by_user
@@ -496,7 +525,7 @@ class RulesForRoom:
class _Invalidation:
# _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
# which means that it it is stored on the bulk_get_push_rules cache entry. In order
- # to ensure that we don't accumulate lots of redunant callbacks on the cache entry,
+ # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
# we need to ensure that two _Invalidation objects are "equal" if they refer to the
# same `cache` and `room_id`.
#
@@ -506,7 +535,7 @@ class _Invalidation:
cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str)
- def __call__(self):
+ def __call__(self) -> None:
rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
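# --- Editor's note (not part of the diff): a self-contained sketch of the
# behavioural change in action_for_event_by_user above. Rather than calling
# store.is_ignored_by once per recipient, the set of users who ignore the sender
# is fetched once (store.ignored_by) and membership-tested inside the loop. The
# user IDs and rules below are hand-rolled for illustration.

ignorers = {"@carol:example.org"}  # users who have the sender in their ignore list
rules_by_user = {
    "@alice:example.org": ["rule-a"],
    "@bob:example.org": ["rule-b"],
    "@carol:example.org": ["rule-c"],
}
sender = "@alice:example.org"

notified = [
    uid
    for uid in rules_by_user
    if uid != sender  # never notify the sender about their own event
    and uid not in ignorers  # skip anyone who ignores the sender
]
assert notified == ["@bob:example.org"]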
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index a59b639f15..0cadba761a 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -14,24 +14,27 @@
# limitations under the License.
import copy
+from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+from synapse.types import UserID
-def format_push_rules_for_user(user, ruleslist):
+def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
- rules = {"global": {}, "device": {}}
+ rules = {
+ "global": {},
+ "device": {},
+ } # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r in ruleslist:
- rulearray = None
-
template_name = _priority_class_to_template_name(r["priority_class"])
# Remove internal stuff.
@@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
return rules
-def _add_empty_priority_class_arrays(d):
+def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
-def _rule_to_template(rule):
+def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unscoped_rule_id = None
if "rule_id" in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
@@ -82,6 +85,10 @@ def _rule_to_template(rule):
return None
templaterule = {"actions": rule["actions"]}
templaterule["pattern"] = thecond["pattern"]
+ else:
+ # This should not be reached unless this function is not kept in sync
+ # with PRIORITY_CLASS_INVERSE_MAP.
+ raise ValueError("Unexpected template_name: %s" % (template_name,))
if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id
@@ -90,9 +97,9 @@ def _rule_to_template(rule):
return templaterule
-def _rule_id_from_namespaced(in_rule_id):
+def _rule_id_from_namespaced(in_rule_id: str) -> str:
return in_rule_id.split("/")[-1]
-def _priority_class_to_template_name(pc):
+def _priority_class_to_template_name(pc: int) -> str:
return PRIORITY_CLASS_INVERSE_MAP[pc]
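# --- Editor's note (not part of the diff): a sketch of the nested structure that
# format_push_rules_for_user above builds: a "global" scope keyed by priority
# class, each holding a list of template rules, plus a (currently unused)
# "device" scope. The priority-class names are an assumption about
# PRIORITY_CLASS_MAP in rulekinds.py, which is not shown in this diff.

from typing import Any, Dict, List

_priority_classes = ["underride", "sender", "room", "content", "override"]

rules = {"global": {}, "device": {}}  # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
for _pc in _priority_classes:
    rules["global"][_pc] = []

rules["global"]["content"].append(
    {"rule_id": "contains_user_name", "actions": ["notify"], "pattern": "alice"}
)
assert list(rules["global"]) == _priority_classes
assert rules["device"] == {}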
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c6763971ee..c0968dc7a1 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,11 +14,17 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import RoomStreamToken
+from synapse.push import Pusher, PusherConfig, ThrottleParams
+from synapse.push.mailer import Mailer
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -46,7 +52,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
INCLUDE_ALL_UNREAD_NOTIFS = False
-class EmailPusher:
+class EmailPusher(Pusher):
"""
A pusher that sends email notifications about events (approximately)
when they happen.
@@ -54,37 +60,30 @@ class EmailPusher:
factor out the common parts
"""
- def __init__(self, hs, pusherdict, mailer):
- self.hs = hs
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
+ super().__init__(hs, pusher_config)
self.mailer = mailer
self.store = self.hs.get_datastore()
- self.clock = self.hs.get_clock()
- self.pusher_id = pusherdict["id"]
- self.user_id = pusherdict["user_name"]
- self.app_id = pusherdict["app_id"]
- self.email = pusherdict["pushkey"]
- self.last_stream_ordering = pusherdict["last_stream_ordering"]
- self.timed_call = None
- self.throttle_params = None
-
- # See httppusher
- self.max_stream_ordering = None
+ self.email = pusher_config.pushkey
+ self.timed_call = None # type: Optional[IDelayedCall]
+ self.throttle_params = {} # type: Dict[str, ThrottleParams]
+ self._inited = False
self._is_processing = False
- def on_started(self, should_check_for_notifs):
+ def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
- should_check_for_notifs (bool): Whether we should immediately
+ should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
if should_check_for_notifs and self.mailer is not None:
self._start_processing()
- def on_stop(self):
+ def on_stop(self) -> None:
if self.timed_call:
try:
self.timed_call.cancel()
@@ -92,37 +91,23 @@ class EmailPusher:
pass
self.timed_call = None
- def on_new_notifications(self, max_token: RoomStreamToken):
- # We just use the minimum stream ordering and ignore the vector clock
- # component. This is safe to do as long as we *always* ignore the vector
- # clock components.
- max_stream_ordering = max_token.stream
-
- if self.max_stream_ordering:
- self.max_stream_ordering = max(
- max_stream_ordering, self.max_stream_ordering
- )
- else:
- self.max_stream_ordering = max_stream_ordering
- self._start_processing()
-
- def on_new_receipts(self, min_stream_id, max_stream_id):
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
pass
- def on_timer(self):
+ def on_timer(self) -> None:
self.timed_call = None
self._start_processing()
- def _start_processing(self):
+ def _start_processing(self) -> None:
if self._is_processing:
return
run_as_background_process("emailpush.process", self._process)
- def _pause_processing(self):
+ def _pause_processing(self) -> None:
"""Used by tests to temporarily pause processing of events.
Asserts that its not currently processing.
@@ -130,25 +115,26 @@ class EmailPusher:
assert not self._is_processing
self._is_processing = True
- def _resume_processing(self):
- """Used by tests to resume processing of events after pausing.
- """
+ def _resume_processing(self) -> None:
+ """Used by tests to resume processing of events after pausing."""
assert self._is_processing
self._is_processing = False
self._start_processing()
- async def _process(self):
+ async def _process(self) -> None:
# we should never get here if we are already processing
assert not self._is_processing
try:
self._is_processing = True
- if self.throttle_params is None:
+ if not self._inited:
# this is our first loop: load up the throttle params
+ assert self.pusher_id is not None
self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
+ self._inited = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
@@ -163,17 +149,20 @@ class EmailPusher:
finally:
self._is_processing = False
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
"""
Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
- fn = self.store.get_unread_push_actions_for_user_in_range_for_email
- unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
+ unprocessed = (
+ await self.store.get_unread_push_actions_for_user_in_range_for_email(
+ self.user_id, start, self.max_stream_ordering
+ )
+ )
- soonest_due_at = None
+ soonest_due_at = None # type: Optional[int]
if not unprocessed:
await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
@@ -230,46 +219,48 @@ class EmailPusher:
self.seconds_until(soonest_due_at), self.on_timer
)
- async def save_last_stream_ordering_and_success(self, last_stream_ordering):
- if last_stream_ordering is None:
- # This happens if we haven't yet processed anything
- return
-
+ async def save_last_stream_ordering_and_success(
+ self, last_stream_ordering: int
+ ) -> None:
self.last_stream_ordering = last_stream_ordering
- pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id,
- self.email,
- self.user_id,
- last_stream_ordering,
- self.clock.time_msec(),
+ pusher_still_exists = (
+ await self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.email,
+ self.user_id,
+ last_stream_ordering,
+ self.clock.time_msec(),
+ )
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
# lets just stop and return.
self.on_stop()
- def seconds_until(self, ts_msec):
+ def seconds_until(self, ts_msec: int) -> float:
secs = (ts_msec - self.clock.time_msec()) / 1000
return max(secs, 0)
- def get_room_throttle_ms(self, room_id):
+ def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params:
- return self.throttle_params[room_id]["throttle_ms"]
+ return self.throttle_params[room_id].throttle_ms
else:
return 0
- def get_room_last_sent_ts(self, room_id):
+ def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params:
- return self.throttle_params[room_id]["last_sent_ts"]
+ return self.throttle_params[room_id].last_sent_ts
else:
return 0
- def room_ready_to_notify_at(self, room_id):
+ def room_ready_to_notify_at(self, room_id: str) -> int:
"""
Determines whether throttling should prevent us from sending an email
for the given room
- Returns: The timestamp when we are next allowed to send an email notif
- for this room
+
+ Returns:
+ The timestamp when we are next allowed to send an email notif
+ for this room
"""
last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id)
@@ -277,7 +268,9 @@ class EmailPusher:
may_send_at = last_sent_ts + throttle_ms
return may_send_at
- async def sent_notif_update_throttle(self, room_id, notified_push_action):
+ async def sent_notif_update_throttle(
+ self, room_id: str, notified_push_action: dict
+ ) -> None:
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
@@ -307,15 +300,16 @@ class EmailPusher:
new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
)
- self.throttle_params[room_id] = {
- "last_sent_ts": self.clock.time_msec(),
- "throttle_ms": new_throttle_ms,
- }
+ self.throttle_params[room_id] = ThrottleParams(
+ self.clock.time_msec(),
+ new_throttle_ms,
+ )
+ assert self.pusher_id is not None
await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)
- async def send_notification(self, push_actions, reason):
+ async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
logger.info("Sending notif email for user %r", self.user_id)
await self.mailer.send_notification_mail(
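# --- Editor's note (not part of the diff): a self-contained sketch of the per-room
# throttle bookkeeping that emailpusher.py now keeps in ThrottleParams objects
# rather than raw dicts. The multiplier and cap below are illustrative values,
# not the real THROTTLE_MULTIPLIER / THROTTLE_MAX_MS constants.

import attr


@attr.s(slots=True)
class _ThrottleParamsSketch:
    last_sent_ts = attr.ib(type=int)
    throttle_ms = attr.ib(type=int)


_MULTIPLIER = 6
_MAX_MS = 24 * 60 * 60 * 1000  # never wait more than a day between emails


def _bump_throttle(params: _ThrottleParamsSketch, now_ms: int) -> _ThrottleParamsSketch:
    """Escalate a room's throttle after an email has just been sent for it."""
    new_ms = min(params.throttle_ms * _MULTIPLIER, _MAX_MS)
    return _ThrottleParamsSketch(last_sent_ts=now_ms, throttle_ms=new_ms)


_params = _bump_throttle(
    _ThrottleParamsSketch(last_sent_ts=0, throttle_ms=10 * 60 * 1000), now_ms=1_000
)
assert _params.throttle_ms == 60 * 60 * 1000  # 10 minutes * 6
assert _params.last_sent_ts == 1_000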
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 0e845212a9..26af5309c1 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -14,19 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import urllib.parse
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import PusherConfigException
-from synapse.types import RoomStreamToken
+from synapse.push import Pusher, PusherConfig, PusherConfigException
from . import push_rule_evaluator, push_tools
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
http_push_processed_counter = Counter(
@@ -50,91 +56,77 @@ http_badges_failed_counter = Counter(
)
-class HttpPusher:
+class HttpPusher(Pusher):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60
# This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
- def __init__(self, hs, pusherdict):
- self.hs = hs
- self.store = self.hs.get_datastore()
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+ super().__init__(hs, pusher_config)
self.storage = self.hs.get_storage()
- self.clock = self.hs.get_clock()
- self.state_handler = self.hs.get_state_handler()
- self.user_id = pusherdict["user_name"]
- self.app_id = pusherdict["app_id"]
- self.app_display_name = pusherdict["app_display_name"]
- self.device_display_name = pusherdict["device_display_name"]
- self.pushkey = pusherdict["pushkey"]
- self.pushkey_ts = pusherdict["ts"]
- self.data = pusherdict["data"]
- self.last_stream_ordering = pusherdict["last_stream_ordering"]
+ self.app_display_name = pusher_config.app_display_name
+ self.device_display_name = pusher_config.device_display_name
+ self.pushkey_ts = pusher_config.ts
+ self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.failing_since = pusherdict["failing_since"]
- self.timed_call = None
+ self.failing_since = pusher_config.failing_since
+ self.timed_call = None # type: Optional[IDelayedCall]
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
+ self._pusherpool = hs.get_pusherpool()
- # This is the highest stream ordering we know it's safe to process.
- # When new events arrive, we'll be given a window of new events: we
- # should honour this rather than just looking for anything higher
- # because of potential out-of-order event serialisation. This starts
- # off as None though as we don't know any better.
- self.max_stream_ordering = None
-
- if "data" not in pusherdict:
- raise PusherConfigException("No 'data' key for HTTP pusher")
- self.data = pusherdict["data"]
+ self.data = pusher_config.data
+ if self.data is None:
+            raise PusherConfigException("'data' key cannot be null for HTTP pusher")
self.name = "%s/%s/%s" % (
- pusherdict["user_name"],
- pusherdict["app_id"],
- pusherdict["pushkey"],
+ pusher_config.user_name,
+ pusher_config.app_id,
+ pusher_config.pushkey,
)
- if self.data is None:
- raise PusherConfigException("data can not be null for HTTP pusher")
-
+ # Validate that there's a URL and it is of the proper form.
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
- self.url = self.data["url"]
+
+ url = self.data["url"]
+ if not isinstance(url, str):
+ raise PusherConfigException("'url' must be a string")
+ url_parts = urllib.parse.urlparse(url)
+ # Note that the specification also says the scheme must be HTTPS, but
+ # it isn't up to the homeserver to verify that.
+ if url_parts.path != "/_matrix/push/v1/notify":
+ raise PusherConfigException(
+ "'url' must have a path of '/_matrix/push/v1/notify'"
+ )
+
+ self.url = url
self.http_client = hs.get_proxied_blacklisted_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
- def on_started(self, should_check_for_notifs):
+ def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
- should_check_for_notifs (bool): Whether we should immediately
+ should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
if should_check_for_notifs:
self._start_processing()
- def on_new_notifications(self, max_token: RoomStreamToken):
- # We just use the minimum stream ordering and ignore the vector clock
- # component. This is safe to do as long as we *always* ignore the vector
- # clock components.
- max_stream_ordering = max_token.stream
-
- self.max_stream_ordering = max(
- max_stream_ordering, self.max_stream_ordering or 0
- )
- self._start_processing()
-
- def on_new_receipts(self, min_stream_id, max_stream_id):
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
- async def _update_badge(self):
+ async def _update_badge(self) -> None:
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
badge = await push_tools.get_badge_count(
@@ -144,10 +136,10 @@ class HttpPusher:
)
await self._send_badge(badge)
- def on_timer(self):
+ def on_timer(self) -> None:
self._start_processing()
- def on_stop(self):
+ def on_stop(self) -> None:
if self.timed_call:
try:
self.timed_call.cancel()
@@ -155,13 +147,13 @@ class HttpPusher:
pass
self.timed_call = None
- def _start_processing(self):
+ def _start_processing(self) -> None:
if self._is_processing:
return
run_as_background_process("httppush.process", self._process)
- async def _process(self):
+ async def _process(self) -> None:
# we should never get here if we are already processing
assert not self._is_processing
@@ -180,16 +172,16 @@ class HttpPusher:
finally:
self._is_processing = False
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
"""
Looks for unset notifications and dispatch them, in order
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
-
- fn = self.store.get_unread_push_actions_for_user_in_range_for_http
- unprocessed = await fn(
- self.user_id, self.last_stream_ordering, self.max_stream_ordering
+ unprocessed = (
+ await self.store.get_unread_push_actions_for_user_in_range_for_http(
+ self.user_id, self.last_stream_ordering, self.max_stream_ordering
+ )
)
logger.info(
@@ -216,12 +208,14 @@ class HttpPusher:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
- pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id,
- self.pushkey,
- self.user_id,
- self.last_stream_ordering,
- self.clock.time_msec(),
+ pusher_still_exists = (
+ await self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.pushkey,
+ self.user_id,
+ self.last_stream_ordering,
+ self.clock.time_msec(),
+ )
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
@@ -257,17 +251,12 @@ class HttpPusher:
)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
- pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
+ await self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
self.last_stream_ordering,
)
- if not pusher_still_exists:
- # The pusher has been deleted while we were processing, so
- # lets just stop and return.
- self.on_stop()
- return
self.failing_since = None
await self.store.update_pusher_failing_since(
@@ -283,7 +272,7 @@ class HttpPusher:
)
break
- async def _process_one(self, push_action):
+ async def _process_one(self, push_action: dict) -> bool:
if "notify" not in push_action["actions"]:
return True
@@ -301,20 +290,23 @@ class HttpPusher:
if rejected is False:
return False
- if isinstance(rejected, list) or isinstance(rejected, tuple):
+ if isinstance(rejected, (list, tuple)):
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warning(
- ("Ignoring rejected pushkey %s because we didn't send it"), pk,
+ ("Ignoring rejected pushkey %s because we didn't send it"),
+ pk,
)
else:
logger.info("Pushkey %s was rejected: removing", pk)
- await self.hs.remove_pusher(self.app_id, pk, self.user_id)
+ await self._pusherpool.remove_pusher(self.app_id, pk, self.user_id)
return True
- async def _build_notification_dict(self, event, tweaks, badge):
+ async def _build_notification_dict(
+ self, event: EventBase, tweaks: Dict[str, bool], badge: int
+ ) -> Dict[str, Any]:
priority = "low"
if (
event.type == EventTypes.Encrypted
@@ -325,6 +317,8 @@ class HttpPusher:
# or may do so (i.e. is encrypted so has unknown effects).
priority = "high"
+ # This was checked in the __init__, but mypy doesn't seem to know that.
+ assert self.data is not None
if self.data.get("format") == "event_id_only":
d = {
"notification": {
@@ -344,9 +338,7 @@ class HttpPusher:
}
return d
- ctx = await push_tools.get_context_for_event(
- self.storage, self.state_handler, event, self.user_id
- )
+ ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
d = {
"notification": {
@@ -386,7 +378,9 @@ class HttpPusher:
return d
- async def dispatch_push(self, event, tweaks, badge):
+ async def dispatch_push(
+ self, event: EventBase, tweaks: Dict[str, bool], badge: int
+ ) -> Union[bool, Iterable[str]]:
notification_dict = await self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
return []
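# --- Editor's note (not part of the diff): a self-contained sketch of the stricter
# push URL validation the HttpPusher constructor now performs: the 'url' entry in
# the pusher data must be a string whose path is exactly /_matrix/push/v1/notify.
# PusherConfigException is replaced by ValueError here so the example needs no
# Synapse imports.

import urllib.parse


def _validate_push_url(data: dict) -> str:
    if "url" not in data:
        raise ValueError("'url' required in data for HTTP pusher")
    url = data["url"]
    if not isinstance(url, str):
        raise ValueError("'url' must be a string")
    if urllib.parse.urlparse(url).path != "/_matrix/push/v1/notify":
        raise ValueError("'url' must have a path of '/_matrix/push/v1/notify'")
    return url


assert _validate_push_url({"url": "https://push.example.com/_matrix/push/v1/notify"})
try:
    _validate_push_url({"url": "https://push.example.com/notify"})
except ValueError:
    pass
else:
    raise AssertionError("expected the legacy-style URL to be rejected")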
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 38195c8eea..2e5161de2c 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -19,7 +19,7 @@ import logging
import urllib.parse
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-from typing import Iterable, List, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
import bleach
import jinja2
@@ -27,16 +27,21 @@ import jinja2
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
descriptor_from_member_events,
name_from_member_event,
)
-from synapse.types import UserID
+from synapse.storage.state import StateFilter
+from synapse.types import StateMap, UserID
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -93,13 +98,20 @@ ALLOWED_ATTRS = {
class Mailer:
- def __init__(self, hs, app_name, template_html, template_text):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ app_name: str,
+ template_html: jinja2.Template,
+ template_text: jinja2.Template,
+ ):
self.hs = hs
self.template_html = template_html
self.template_text = template_text
self.sendmail = self.hs.get_sendmail()
self.store = self.hs.get_datastore()
+ self.state_store = self.hs.get_storage().state
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
@@ -108,17 +120,19 @@ class Mailer:
logger.info("Created Mailer for app_name %s" % app_name)
- async def send_password_reset_mail(self, email_address, token, client_secret, sid):
+ async def send_password_reset_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a password reset link to a user
Args:
- email_address (str): Email address we're sending the password
+ email_address: Email address we're sending the password
reset to
- token (str): Unique token generated by the server to verify
+ token: Unique token generated by the server to verify
the email was received
- client_secret (str): Unique token generated by the client to
+ client_secret: Unique token generated by the client to
group together multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -136,17 +150,19 @@ class Mailer:
template_vars,
)
- async def send_registration_mail(self, email_address, token, client_secret, sid):
+ async def send_registration_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a registration confirmation link to a user
Args:
- email_address (str): Email address we're sending the registration
+ email_address: Email address we're sending the registration
link to
- token (str): Unique token generated by the server to verify
+ token: Unique token generated by the server to verify
the email was received
- client_secret (str): Unique token generated by the client to
+ client_secret: Unique token generated by the client to
group together multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -164,18 +180,20 @@ class Mailer:
template_vars,
)
- async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
+ async def send_add_threepid_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a validation link to a user for adding a 3pid to their account
Args:
- email_address (str): Email address we're sending the validation link to
+ email_address: Email address we're sending the validation link to
- token (str): Unique token generated by the server to verify the email was received
+ token: Unique token generated by the server to verify the email was received
- client_secret (str): Unique token generated by the client to group together
+ client_secret: Unique token generated by the client to group together
multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -194,16 +212,31 @@ class Mailer:
)
async def send_notification_mail(
- self, app_id, user_id, email_address, push_actions, reason
- ):
- """Send email regarding a user's room notifications"""
+ self,
+ app_id: str,
+ user_id: str,
+ email_address: str,
+ push_actions: Iterable[Dict[str, Any]],
+ reason: Dict[str, Any],
+ ) -> None:
+ """
+ Send email regarding a user's room notifications
+
+    Args:
+ app_id: The application receiving the notification.
+ user_id: The user receiving the notification.
+ email_address: The email address receiving the notification.
+ push_actions: All outstanding notifications.
+ reason: The notification that was ready and is the cause of an email
+ being sent.
+ """
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
notif_events = await self.store.get_events(
[pa["event_id"] for pa in push_actions]
)
- notifs_by_room = {}
+ notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@@ -220,7 +253,7 @@ class Mailer:
except StoreError:
user_display_name = user_id
- async def _fetch_room_state(room_id):
+ async def _fetch_room_state(room_id: str) -> None:
room_state = await self.store.get_current_state_ids(room_id)
state_by_room[room_id] = room_state
@@ -234,7 +267,7 @@ class Mailer:
rooms = []
for r in rooms_in_order:
- roomvars = await self.get_room_vars(
+ roomvars = await self._get_room_vars(
r, user_id, notifs_by_room[r], notif_events, state_by_room[r]
)
rooms.append(roomvars)
@@ -246,13 +279,25 @@ class Mailer:
fallback_to_members=True,
)
- summary_text = await self.make_summary_text(
- notifs_by_room, state_by_room, notif_events, user_id, reason
- )
+ if len(notifs_by_room) == 1:
+ # Only one room has new stuff
+ room_id = list(notifs_by_room.keys())[0]
+
+ summary_text = await self._make_summary_text_single_room(
+ room_id,
+ notifs_by_room[room_id],
+ state_by_room[room_id],
+ notif_events,
+ user_id,
+ )
+ else:
+ summary_text = await self._make_summary_text(
+ notifs_by_room, state_by_room, notif_events, reason
+ )
template_vars = {
"user_display_name": user_display_name,
- "unsubscribe_link": self.make_unsubscribe_link(
+ "unsubscribe_link": self._make_unsubscribe_link(
user_id, app_id, email_address
),
"summary_text": summary_text,
@@ -262,7 +307,9 @@ class Mailer:
await self.send_email(email_address, summary_text, template_vars)
- async def send_email(self, email_address, subject, extra_template_vars):
+ async def send_email(
+ self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
+ ) -> None:
"""Send an email with the given information and template text"""
try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -314,9 +361,28 @@ class Mailer:
)
)
- async def get_room_vars(
- self, room_id, user_id, notifs, notif_events, room_state_ids
- ):
+ async def _get_room_vars(
+ self,
+ room_id: str,
+ user_id: str,
+ notifs: Iterable[Dict[str, Any]],
+ notif_events: Dict[str, EventBase],
+ room_state_ids: StateMap[str],
+ ) -> Dict[str, Any]:
+ """
+ Generate the variables for notifications on a per-room basis.
+
+ Args:
+ room_id: The room ID
+ user_id: The user receiving the notification.
+ notifs: The outstanding push actions for this room.
+ notif_events: The events related to the above notifications.
+ room_state_ids: The event IDs of the current room state.
+
+ Returns:
+ A dictionary to be added to the template context.
+ """
+
# Check if one of the notifs is an invite event for the user.
is_invite = False
for n in notifs:
@@ -333,12 +399,12 @@ class Mailer:
"hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [],
"invite": is_invite,
- "link": self.make_room_link(room_id),
- }
+ "link": self._make_room_link(room_id),
+ } # type: Dict[str, Any]
if not is_invite:
for n in notifs:
- notifvars = await self.get_notif_vars(
+ notifvars = await self._get_notif_vars(
n, user_id, notif_events[n["event_id"]], room_state_ids
)
@@ -365,7 +431,26 @@ class Mailer:
return room_vars
- async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
+ async def _get_notif_vars(
+ self,
+ notif: Dict[str, Any],
+ user_id: str,
+ notif_event: EventBase,
+ room_state_ids: StateMap[str],
+ ) -> Dict[str, Any]:
+ """
+ Generate the variables for a single notification.
+
+ Args:
+ notif: The outstanding notification for this room.
+ user_id: The user receiving the notification.
+ notif_event: The event related to the above notification.
+ room_state_ids: The event IDs of the current room state.
+
+ Returns:
+ A dictionary to be added to the template context.
+ """
+
results = await self.store.get_events_around(
notif["room_id"],
notif["event_id"],
@@ -374,7 +459,7 @@ class Mailer:
)
ret = {
- "link": self.make_notif_link(notif),
+ "link": self._make_notif_link(notif),
"ts": notif["received_ts"],
"messages": [],
}
@@ -385,20 +470,51 @@ class Mailer:
the_events.append(notif_event)
for event in the_events:
- messagevars = await self.get_message_vars(notif, event, room_state_ids)
+ messagevars = await self._get_message_vars(notif, event, room_state_ids)
if messagevars is not None:
ret["messages"].append(messagevars)
return ret
- async def get_message_vars(self, notif, event, room_state_ids):
+ async def _get_message_vars(
+ self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
+ ) -> Optional[Dict[str, Any]]:
+ """
+ Generate the variables for a single event, if possible.
+
+ Args:
+ notif: The outstanding notification for this room.
+ event: The event under consideration.
+ room_state_ids: The event IDs of the current room state.
+
+ Returns:
+ A dictionary to be added to the template context, or None if the
+ event cannot be processed.
+ """
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return None
- sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
- sender_state_event = await self.store.get_event(sender_state_event_id)
- sender_name = name_from_member_event(sender_state_event)
- sender_avatar_url = sender_state_event.content.get("avatar_url")
+ # Get the sender's name and avatar from the room state.
+ type_state_key = ("m.room.member", event.sender)
+ sender_state_event_id = room_state_ids.get(type_state_key)
+ if sender_state_event_id:
+ sender_state_event = await self.store.get_event(
+ sender_state_event_id
+ ) # type: Optional[EventBase]
+ else:
+ # Attempt to check the historical state for the room.
+ historical_state = await self.state_store.get_state_for_event(
+ event.event_id, StateFilter.from_types((type_state_key,))
+ )
+ sender_state_event = historical_state.get(type_state_key)
+
+ if sender_state_event:
+ sender_name = name_from_member_event(sender_state_event)
+ sender_avatar_url = sender_state_event.content.get("avatar_url")
+ else:
+ # No state could be found, fallback to the MXID.
+ sender_name = event.sender
+ sender_avatar_url = None
# 'hash' for deterministically picking default images: use
# sender_hash % the number of default images to choose from
@@ -423,16 +539,25 @@ class Mailer:
ret["msgtype"] = msgtype
if msgtype == "m.text":
- self.add_text_message_vars(ret, event)
+ self._add_text_message_vars(ret, event)
elif msgtype == "m.image":
- self.add_image_message_vars(ret, event)
+ self._add_image_message_vars(ret, event)
if "body" in event.content:
ret["body_text_plain"] = event.content["body"]
return ret
- def add_text_message_vars(self, messagevars, event):
+ def _add_text_message_vars(
+ self, messagevars: Dict[str, Any], event: EventBase
+ ) -> None:
+ """
+ Potentially add a sanitised message body to the message variables.
+
+ Args:
+ messagevars: The template context to be modified.
+ event: The event under consideration.
+ """
msgformat = event.content.get("format")
messagevars["format"] = msgformat
@@ -445,142 +570,232 @@ class Mailer:
elif body:
messagevars["body_text_html"] = safe_text(body)
- return messagevars
+ def _add_image_message_vars(
+ self, messagevars: Dict[str, Any], event: EventBase
+ ) -> None:
+ """
+ Potentially add an image URL to the message variables.
- def add_image_message_vars(self, messagevars, event):
- messagevars["image_url"] = event.content["url"]
+ Args:
+ messagevars: The template context to be modified.
+ event: The event under consideration.
+ """
+ if "url" in event.content:
+ messagevars["image_url"] = event.content["url"]
+
+ async def _make_summary_text_single_room(
+ self,
+ room_id: str,
+ notifs: List[Dict[str, Any]],
+ room_state_ids: StateMap[str],
+ notif_events: Dict[str, EventBase],
+ user_id: str,
+ ) -> str:
+ """
+ Make a summary text for the email when only a single room has notifications.
- return messagevars
+ Args:
+ room_id: The ID of the room.
+ notifs: The push actions for this room.
+ room_state_ids: The state map for the room.
+ notif_events: A map of event ID -> notification event.
+ user_id: The user receiving the notification.
+
+ Returns:
+ The summary text.
+ """
+ # If the room has some kind of name, use it, but we don't
+ # want the generated-from-names one here otherwise we'll
+ # end up with, "new message from Bob in the Bob room"
+ room_name = await calculate_room_name(
+ self.store, room_state_ids, user_id, fallback_to_members=False
+ )
- async def make_summary_text(
- self, notifs_by_room, room_state_ids, notif_events, user_id, reason
- ):
- if len(notifs_by_room) == 1:
- # Only one room has new stuff
- room_id = list(notifs_by_room.keys())[0]
+ # See if one of the notifs is an invite event for the user
+ invite_event = None
+ for n in notifs:
+ ev = notif_events[n["event_id"]]
+ if ev.type == EventTypes.Member and ev.state_key == user_id:
+ if ev.content.get("membership") == Membership.INVITE:
+ invite_event = ev
+ break
- # If the room has some kind of name, use it, but we don't
- # want the generated-from-names one here otherwise we'll
- # end up with, "new message from Bob in the Bob room"
- room_name = await calculate_room_name(
- self.store, room_state_ids[room_id], user_id, fallback_to_members=False
+ if invite_event:
+ inviter_member_event_id = room_state_ids.get(
+ ("m.room.member", invite_event.sender)
)
-
- # See if one of the notifs is an invite event for the user
- invite_event = None
- for n in notifs_by_room[room_id]:
- ev = notif_events[n["event_id"]]
- if ev.type == EventTypes.Member and ev.state_key == user_id:
- if ev.content.get("membership") == Membership.INVITE:
- invite_event = ev
- break
-
- if invite_event:
- inviter_member_event_id = room_state_ids[room_id].get(
- ("m.room.member", invite_event.sender)
+ inviter_name = invite_event.sender
+ if inviter_member_event_id:
+ inviter_member_event = await self.store.get_event(
+ inviter_member_event_id, allow_none=True
)
- inviter_name = invite_event.sender
- if inviter_member_event_id:
- inviter_member_event = await self.store.get_event(
- inviter_member_event_id, allow_none=True
- )
- if inviter_member_event:
- inviter_name = name_from_member_event(inviter_member_event)
-
- if room_name is None:
- return self.email_subjects.invite_from_person % {
- "person": inviter_name,
- "app": self.app_name,
- }
- else:
- return self.email_subjects.invite_from_person_to_room % {
- "person": inviter_name,
- "room": room_name,
- "app": self.app_name,
- }
+ if inviter_member_event:
+ inviter_name = name_from_member_event(inviter_member_event)
+ if room_name is None:
+ return self.email_subjects.invite_from_person % {
+ "person": inviter_name,
+ "app": self.app_name,
+ }
+
+ return self.email_subjects.invite_from_person_to_room % {
+ "person": inviter_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
+
+ if len(notifs) == 1:
+ # There is just the one notification, so give some detail
sender_name = None
- if len(notifs_by_room[room_id]) == 1:
- # There is just the one notification, so give some detail
- event = notif_events[notifs_by_room[room_id][0]["event_id"]]
- if ("m.room.member", event.sender) in room_state_ids[room_id]:
- state_event_id = room_state_ids[room_id][
- ("m.room.member", event.sender)
- ]
- state_event = await self.store.get_event(state_event_id)
- sender_name = name_from_member_event(state_event)
-
- if sender_name is not None and room_name is not None:
- return self.email_subjects.message_from_person_in_room % {
- "person": sender_name,
- "room": room_name,
- "app": self.app_name,
- }
- elif sender_name is not None:
- return self.email_subjects.message_from_person % {
- "person": sender_name,
- "app": self.app_name,
- }
- else:
- # There's more than one notification for this room, so just
- # say there are several
- if room_name is not None:
- return self.email_subjects.messages_in_room % {
- "room": room_name,
- "app": self.app_name,
- }
- else:
- # If the room doesn't have a name, say who the messages
- # are from explicitly to avoid, "messages in the Bob room"
- sender_ids = list(
- {
- notif_events[n["event_id"]].sender
- for n in notifs_by_room[room_id]
- }
- )
-
- member_events = await self.store.get_events(
- [
- room_state_ids[room_id][("m.room.member", s)]
- for s in sender_ids
- ]
- )
-
- return self.email_subjects.messages_from_person % {
- "person": descriptor_from_member_events(member_events.values()),
- "app": self.app_name,
- }
- else:
- # Stuff's happened in multiple different rooms
+ event = notif_events[notifs[0]["event_id"]]
+ if ("m.room.member", event.sender) in room_state_ids:
+ state_event_id = room_state_ids[("m.room.member", event.sender)]
+ state_event = await self.store.get_event(state_event_id)
+ sender_name = name_from_member_event(state_event)
+
+ if sender_name is not None and room_name is not None:
+ return self.email_subjects.message_from_person_in_room % {
+ "person": sender_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
+ elif sender_name is not None:
+ return self.email_subjects.message_from_person % {
+ "person": sender_name,
+ "app": self.app_name,
+ }
- # ...but we still refer to the 'reason' room which triggered the mail
- if reason["room_name"] is not None:
- return self.email_subjects.messages_in_room_and_others % {
- "room": reason["room_name"],
+ # The sender is unknown, just use the room name (or ID).
+ return self.email_subjects.messages_in_room % {
+ "room": room_name or room_id,
+ "app": self.app_name,
+ }
+ else:
+ # There's more than one notification for this room, so just
+ # say there are several
+ if room_name is not None:
+ return self.email_subjects.messages_in_room % {
+ "room": room_name,
"app": self.app_name,
}
+
+ return await self._make_summary_text_from_member_events(
+ room_id, notifs, room_state_ids, notif_events
+ )
+
+ async def _make_summary_text(
+ self,
+ notifs_by_room: Dict[str, List[Dict[str, Any]]],
+ room_state_ids: Dict[str, StateMap[str]],
+ notif_events: Dict[str, EventBase],
+ reason: Dict[str, Any],
+ ) -> str:
+ """
+ Make a summary text for the email when multiple rooms have notifications.
+
+ Args:
+ notifs_by_room: A map of room ID to the push actions for that room.
+ room_state_ids: A map of room ID to the state map for that room.
+ notif_events: A map of event ID -> notification event.
+ reason: The reason this notification is being sent.
+
+ Returns:
+ The summary text.
+ """
+ # Stuff's happened in multiple different rooms
+ # ...but we still refer to the 'reason' room which triggered the mail
+ if reason["room_name"] is not None:
+ return self.email_subjects.messages_in_room_and_others % {
+ "room": reason["room_name"],
+ "app": self.app_name,
+ }
+
+ room_id = reason["room_id"]
+ return await self._make_summary_text_from_member_events(
+ room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events
+ )
+
+ async def _make_summary_text_from_member_events(
+ self,
+ room_id: str,
+ notifs: List[Dict[str, Any]],
+ room_state_ids: StateMap[str],
+ notif_events: Dict[str, EventBase],
+ ) -> str:
+ """
+        Make a summary text for the email based on the senders of a room's notifications.
+
+ Args:
+ room_id: The ID of the room.
+ notifs: The push actions for this room.
+ room_state_ids: The state map for the room.
+ notif_events: A map of event ID -> notification event.
+
+ Returns:
+ The summary text.
+ """
+ # If the room doesn't have a name, say who the messages
+ # are from explicitly to avoid, "messages in the Bob room"
+
+        # Find the latest event ID for each sender. Note that the notifications
+        # are already sorted by descending received_ts.
+ sender_ids = {}
+ for n in notifs:
+ sender = notif_events[n["event_id"]].sender
+ if sender not in sender_ids:
+ sender_ids[sender] = n["event_id"]
+
+ # Get the actual member events (in order to calculate a pretty name for
+ # the room).
+ member_event_ids = []
+ member_events = {}
+ for sender_id, event_id in sender_ids.items():
+ type_state_key = ("m.room.member", sender_id)
+ sender_state_event_id = room_state_ids.get(type_state_key)
+ if sender_state_event_id:
+ member_event_ids.append(sender_state_event_id)
else:
- # If the reason room doesn't have a name, say who the messages
- # are from explicitly to avoid, "messages in the Bob room"
- room_id = reason["room_id"]
-
- sender_ids = list(
- {
- notif_events[n["event_id"]].sender
- for n in notifs_by_room[room_id]
- }
+ # Attempt to check the historical state for the room.
+ historical_state = await self.state_store.get_state_for_event(
+ event_id, StateFilter.from_types((type_state_key,))
)
+ sender_state_event = historical_state.get(type_state_key)
+ if sender_state_event:
+ member_events[event_id] = sender_state_event
+ member_events.update(await self.store.get_events(member_event_ids))
+
+ if not member_events:
+ # No member events were found! Maybe the room is empty?
+ # Fallback to the room ID (note that if there was a room name this
+ # would already have been used previously).
+ return self.email_subjects.messages_in_room % {
+ "room": room_id,
+ "app": self.app_name,
+ }
+
+ # There was a single sender.
+ if len(member_events) == 1:
+ return self.email_subjects.messages_from_person % {
+ "person": descriptor_from_member_events(member_events.values()),
+ "app": self.app_name,
+ }
+
+ # There was more than one sender, use the first one and a tweaked template.
+ return self.email_subjects.messages_from_person_and_others % {
+ "person": descriptor_from_member_events(list(member_events.values())[:1]),
+ "app": self.app_name,
+ }
- member_events = await self.store.get_events(
- [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
- )
+ def _make_room_link(self, room_id: str) -> str:
+ """
+ Generate a link to open a room in the web client.
- return self.email_subjects.messages_from_person_and_others % {
- "person": descriptor_from_member_events(member_events.values()),
- "app": self.app_name,
- }
+ Args:
+ room_id: The room ID to generate a link to.
- def make_room_link(self, room_id):
+ Returns:
+ A link to open a room in the web client.
+ """
if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
elif self.app_name == "Vector":
@@ -590,7 +805,16 @@ class Mailer:
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
- def make_notif_link(self, notif):
+ def _make_notif_link(self, notif: Dict[str, str]) -> str:
+ """
+ Generate a link to open an event in the web client.
+
+ Args:
+ notif: The notification to generate a link for.
+
+ Returns:
+ A link to open the notification in the web client.
+ """
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
@@ -606,7 +830,20 @@ class Mailer:
else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
- def make_unsubscribe_link(self, user_id, app_id, email_address):
+ def _make_unsubscribe_link(
+ self, user_id: str, app_id: str, email_address: str
+ ) -> str:
+ """
+ Generate a link to unsubscribe from email notifications.
+
+ Args:
+ user_id: The user receiving the notification.
+ app_id: The application receiving the notification.
+ email_address: The email address receiving the notification.
+
+ Returns:
+ A link to unsubscribe from email notifications.
+ """
params = {
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id,
@@ -620,7 +857,16 @@ class Mailer:
)
-def safe_markup(raw_html):
+def safe_markup(raw_html: str) -> jinja2.Markup:
+ """
+ Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
+
+    Args:
+ raw_html: Unsafe HTML.
+
+ Returns:
+ A Markup object ready to safely use in a Jinja template.
+ """
return jinja2.Markup(
bleach.linkify(
bleach.clean(
@@ -635,10 +881,15 @@ def safe_markup(raw_html):
)
-def safe_text(raw_text):
+def safe_text(raw_text: str) -> jinja2.Markup:
"""
- Process text: treat it as HTML but escape any tags (ie. just escape the
- HTML) then linkify it.
+ Sanitise text (escape any HTML tags), and then linkify any bare URLs.
+
+    Args:
+ raw_text: Unsafe text which might include HTML markup.
+
+ Returns:
+ A Markup object ready to safely use in a Jinja template.
"""
return jinja2.Markup(
bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
@@ -655,7 +906,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
return ret
-def string_ordinal_total(s):
+def string_ordinal_total(s: str) -> int:
tot = 0
for c in s:
tot += ord(c)
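For context, the email_subjects entries used throughout the Mailer above are plain "%(name)s"-style format strings, so building a subject is a single dict substitution. A minimal standalone sketch with hypothetical template text (the real defaults live in the email configuration, not in this excerpt):

# Minimal sketch (not Synapse code): subjects are old-style format strings
# with named placeholders, filled in with a dict of values.
messages_from_person_and_others = "[%(app)s] You have messages from %(person)s and others..."

subject = messages_from_person_and_others % {
    "person": "Alice",   # hypothetical descriptor_from_member_events() output
    "app": "Matrix",     # hypothetical app_name
}
print(subject)  # [Matrix] You have messages from Alice and others...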
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index d8f4a453cd..04c2c1482c 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -15,8 +15,14 @@
import logging
import re
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.types import StateMap
+
+if TYPE_CHECKING:
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -28,32 +34,36 @@ ALL_ALONE = "Empty Room"
async def calculate_room_name(
- store,
- room_state_ids,
- user_id,
- fallback_to_members=True,
- fallback_to_single_member=True,
-):
+ store: "DataStore",
+ room_state_ids: StateMap[str],
+ user_id: str,
+ fallback_to_members: bool = True,
+ fallback_to_single_member: bool = True,
+) -> Optional[str]:
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
Does not yet support internationalisation.
Args:
- room_state: Dictionary of the room's state
+ store: The data store to query.
+ room_state_ids: Dictionary of the room's state IDs.
user_id: The ID of the user to whom the room name is being presented
fallback_to_members: If False, return None instead of generating a name
based on the room's members if the room has no
title or aliases.
+ fallback_to_single_member: If False, return None instead of generating a
+ name based on the user who invited this user to the room if the room
+ has no title or aliases.
Returns:
- (string or None) A human readable name for the room.
+ A human readable name for the room, if possible.
"""
# does it have a name?
if (EventTypes.Name, "") in room_state_ids:
m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True
)
- if m_room_name and m_room_name.content and m_room_name.content["name"]:
+ if m_room_name and m_room_name.content and m_room_name.content.get("name"):
return m_room_name.content["name"]
# does it have a canonical alias?
@@ -64,15 +74,11 @@ async def calculate_room_name(
if (
canon_alias
and canon_alias.content
- and canon_alias.content["alias"]
+ and canon_alias.content.get("alias")
and _looks_like_an_alias(canon_alias.content["alias"])
):
return canon_alias.content["alias"]
- # at this point we're going to need to search the state by all state keys
- # for an event type, so rearrange the data structure
- room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
-
if not fallback_to_members:
return None
@@ -84,7 +90,7 @@ async def calculate_room_name(
if (
my_member_event is not None
- and my_member_event.content["membership"] == "invite"
+ and my_member_event.content.get("membership") == Membership.INVITE
):
if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = await store.get_event(
@@ -97,10 +103,14 @@ async def calculate_room_name(
name_from_member_event(inviter_member_event),
)
else:
- return
+ return None
else:
return "Room Invite"
+ # at this point we're going to need to search the state by all state keys
+ # for an event type, so rearrange the data structure
+ room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
+
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids:
@@ -110,8 +120,8 @@ async def calculate_room_name(
all_members = [
ev
for ev in member_events.values()
- if ev.content["membership"] == "join"
- or ev.content["membership"] == "invite"
+ if ev.content.get("membership") == Membership.JOIN
+ or ev.content.get("membership") == Membership.INVITE
]
         # Sort the member events oldest-first so that we name people in the
         # order they joined (it should at least be deterministic rather than
@@ -150,19 +160,19 @@ async def calculate_room_name(
else:
return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member:
- return
- else:
- return descriptor_from_member_events(other_members)
+ return None
+ return descriptor_from_member_events(other_members)
-def descriptor_from_member_events(member_events):
+
+def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
"""Get a description of the room based on the member events.
Args:
- member_events (Iterable[FrozenEvent])
+ member_events: The events of a room.
Returns:
- str
+ The room description
"""
member_events = list(member_events)
@@ -183,22 +193,18 @@ def descriptor_from_member_events(member_events):
)
-def name_from_member_event(member_event):
- if (
- member_event.content
- and "displayname" in member_event.content
- and member_event.content["displayname"]
- ):
+def name_from_member_event(member_event: EventBase) -> str:
+ if member_event.content and member_event.content.get("displayname"):
return member_event.content["displayname"]
return member_event.state_key
-def _state_as_two_level_dict(state):
- ret = {}
+def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
+ ret = {} # type: Dict[str, Dict[str, str]]
for k, v in state.items():
ret.setdefault(k[0], {})[k[1]] = v
return ret
-def _looks_like_an_alias(string):
+def _looks_like_an_alias(string: str) -> bool:
return ALIAS_RE.match(string) is not None
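For reference, _state_as_two_level_dict simply regroups the (event type, state key) -> event ID state map by event type, so the member-based fallback can look up every m.room.member entry at once. A minimal standalone sketch with made-up IDs:

# Minimal sketch (made-up IDs): StateMap[str] maps (event type, state key) -> event ID.
state_ids = {
    ("m.room.name", ""): "$name_event",
    ("m.room.member", "@alice:example.org"): "$alice_member",
    ("m.room.member", "@bob:example.org"): "$bob_member",
}

by_type = {}
for (ev_type, state_key), event_id in state_ids.items():
    by_type.setdefault(ev_type, {})[state_key] = event_id

# by_type["m.room.member"] now maps user ID -> membership event ID, which is
# exactly what the member-based fallback name generation needs.
assert by_type["m.room.member"]["@bob:example.org"] == "$bob_member"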
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2ce9e444ab..ba1877adcd 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
-def _room_member_count(ev, condition, room_member_count):
+def _room_member_count(
+ ev: EventBase, condition: Dict[str, Any], room_member_count: int
+) -> bool:
return _test_ineq_condition(condition, room_member_count)
-def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
+def _sender_notification_permission(
+ ev: EventBase,
+ condition: Dict[str, Any],
+ sender_power_level: int,
+ power_levels: Dict[str, Union[int, Dict[str, int]]],
+) -> bool:
notif_level_key = condition.get("key")
if notif_level_key is None:
return False
notif_levels = power_levels.get("notifications", {})
+ assert isinstance(notif_levels, dict)
room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level
-def _test_ineq_condition(condition, number):
+def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
if "is" not in condition:
return False
m = INEQUALITY_EXPR.match(condition["is"])
@@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
event: EventBase,
room_member_count: int,
sender_power_level: int,
- power_levels: dict,
+ power_levels: Dict[str, Union[int, Dict[str, int]]],
):
self._event = event
self._room_member_count = room_member_count
@@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
- def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
+ def matches(
+ self, condition: Dict[str, Any], user_id: str, display_name: str
+ ) -> bool:
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name":
@@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
return r"(^|\W)%s(\W|$)" % (r,)
-def _flatten_dict(d, prefix=[], result=None):
+def _flatten_dict(
+ d: Union[EventBase, dict],
+ prefix: Optional[List[str]] = None,
+ result: Optional[Dict[str, str]] = None,
+) -> Dict[str, str]:
+ if prefix is None:
+ prefix = []
if result is None:
result = {}
for key, value in d.items():
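The _flatten_dict signature change above replaces a mutable default argument (prefix=[]) with None. A standalone sketch of the pitfall that pattern avoids (plain Python, not Synapse code):

# A mutable default is created once at function definition time and shared
# between calls, so state can leak from one call to the next.
def bad_append(item, acc=[]):
    acc.append(item)
    return acc

print(bad_append(1))  # [1]
print(bad_append(2))  # [1, 2]  <- leaked state from the first call

# The None-plus-initialise idiom used in the new signature avoids this.
def good_append(item, acc=None):
    if acc is None:
        acc = []
    acc.append(item)
    return acc

print(good_append(1))  # [1]
print(good_append(2))  # [2]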
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 6e7c880dc0..df34103224 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict
+
+from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
@@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
return badge
-async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
+async def get_context_for_event(
+ storage: Storage, ev: EventBase, user_id: str
+) -> Dict[str, str]:
ctx = {}
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 2a52e226e3..cb94127850 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -14,25 +14,31 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Callable, Dict, Optional
+from synapse.push import Pusher, PusherConfig
from synapse.push.emailpusher import EmailPusher
+from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer
-from .httppusher import HttpPusher
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class PusherFactory:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.config = hs.config
- self.pusher_types = {"http": HttpPusher}
+ self.pusher_types = {
+ "http": HttpPusher
+ } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
- self.mailers = {} # app_name -> Mailer
+ self.mailers = {} # type: Dict[str, Mailer]
self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text
@@ -41,16 +47,18 @@ class PusherFactory:
logger.info("defined email pusher type")
- def create_pusher(self, pusherdict):
- kind = pusherdict["kind"]
+ def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
+ kind = pusher_config.kind
f = self.pusher_types.get(kind, None)
if not f:
return None
- logger.debug("creating %s pusher for %r", kind, pusherdict)
- return f(self.hs, pusherdict)
+ logger.debug("creating %s pusher for %r", kind, pusher_config)
+ return f(self.hs, pusher_config)
- def _create_email_pusher(self, _hs, pusherdict):
- app_name = self._app_name_from_pusherdict(pusherdict)
+ def _create_email_pusher(
+ self, _hs: "HomeServer", pusher_config: PusherConfig
+ ) -> EmailPusher:
+ app_name = self._app_name_from_pusherdict(pusher_config)
mailer = self.mailers.get(app_name)
if not mailer:
mailer = Mailer(
@@ -60,10 +68,10 @@ class PusherFactory:
template_text=self._notif_template_text,
)
self.mailers[app_name] = mailer
- return EmailPusher(self.hs, pusherdict, mailer)
+ return EmailPusher(self.hs, pusher_config, mailer)
- def _app_name_from_pusherdict(self, pusherdict):
- data = pusherdict["data"]
+ def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
+ data = pusher_config.data
if isinstance(data, dict):
brand = data.get("brand")
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 32794af6ca..2534ecb1d4 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,19 +15,19 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Union
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
from prometheus_client import Gauge
+from synapse.api.errors import Codes, SynapseError
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.push import PusherConfigException
-from synapse.push.emailpusher import EmailPusher
-from synapse.push.httppusher import HttpPusher
+from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.pusher import PusherFactory
-from synapse.types import RoomStreamToken
+from synapse.replication.http.push import ReplicationRemovePusherRestServlet
+from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
@@ -60,7 +60,6 @@ class PusherPool:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.pusher_factory = PusherFactory(hs)
- self._should_start_pushers = hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
@@ -69,6 +68,16 @@ class PusherPool:
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
+ self._should_start_pushers = (
+ self._instance_name in self._pusher_shard_config.instances
+ )
+
+ # We can only delete pushers on master.
+ self._remove_pusher_client = None
+ if hs.config.worker.worker_app:
+ self._remove_pusher_client = ReplicationRemovePusherRestServlet.make_client(
+ hs
+ )
# Record the last stream ID that we were poked about so we can get
# changes since then. We set this to the current max stream ID on
@@ -77,11 +86,10 @@ class PusherPool:
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
# map from user id to app_id:pushkey to pusher
- self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
+ self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
- def start(self):
- """Starts the pushers off in a background process.
- """
+ def start(self) -> None:
+ """Starts the pushers off in a background process."""
if not self._should_start_pushers:
logger.info("Not starting pushers because they are disabled in the config")
return
@@ -89,52 +97,58 @@ class PusherPool:
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- lang,
- data,
- profile_tag="",
- ):
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ lang: Optional[str],
+ data: JsonDict,
+ profile_tag: str = "",
+ ) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
Returns:
- EmailPusher|HttpPusher
+ The newly created pusher.
"""
+ if kind == "email":
+ email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
+ if email_owner != user_id:
+ raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
+
time_now_msec = self.clock.time_msec()
+ # create the pusher setting last_stream_ordering to the current maximum
+ # stream ordering, so it will process pushes from this point onwards.
+ last_stream_ordering = self.store.get_room_max_stream_ordering()
+
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self.pusher_factory.create_pusher(
- {
- "id": None,
- "user_name": user_id,
- "kind": kind,
- "app_id": app_id,
- "app_display_name": app_display_name,
- "device_display_name": device_display_name,
- "pushkey": pushkey,
- "ts": time_now_msec,
- "lang": lang,
- "data": data,
- "last_stream_ordering": None,
- "last_success": None,
- "failing_since": None,
- }
+ PusherConfig(
+ id=None,
+ user_name=user_id,
+ access_token=access_token,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=time_now_msec,
+ lang=lang,
+ data=data,
+ last_stream_ordering=last_stream_ordering,
+ last_success=None,
+ failing_since=None,
+ )
)
- # create the pusher setting last_stream_ordering to the current maximum
- # stream ordering in event_push_actions, so it will process
- # pushes from this point onwards.
- last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
-
await self.store.add_pusher(
user_id=user_id,
access_token=access_token,
@@ -154,43 +168,41 @@ class PusherPool:
return pusher
async def remove_pushers_by_app_id_and_pushkey_not_user(
- self, app_id, pushkey, not_user_id
- ):
+ self, app_id: str, pushkey: str, not_user_id: str
+ ) -> None:
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
- if p["user_name"] != not_user_id:
+ if p.user_name != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
app_id,
pushkey,
- p["user_name"],
+ p.user_name,
)
- await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+ await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
- async def remove_pushers_by_access_token(self, user_id, access_tokens):
+ async def remove_pushers_by_access_token(
+ self, user_id: str, access_tokens: Iterable[int]
+ ) -> None:
"""Remove the pushers for a given user corresponding to a set of
access_tokens.
Args:
- user_id (str): user to remove pushers for
- access_tokens (Iterable[int]): access token *ids* to remove pushers
- for
+ user_id: user to remove pushers for
+ access_tokens: access token *ids* to remove pushers for
"""
- if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
- return
-
tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id):
- if p["access_token"] in tokens:
+ if p.access_token in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
- p["app_id"],
- p["pushkey"],
- p["user_name"],
+ p.app_id,
+ p.pushkey,
+ p.user_name,
)
- await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+ await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
- def on_new_notifications(self, max_token: RoomStreamToken):
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
if not self.pushers:
# nothing to do here.
return
@@ -209,7 +221,7 @@ class PusherPool:
self._on_new_notifications(max_token)
@wrap_as_background_process("on_new_notifications")
- async def _on_new_notifications(self, max_token: RoomStreamToken):
+ async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@@ -239,7 +251,9 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+ async def on_new_receipts(
+ self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
+ ) -> None:
if not self.pushers:
# nothing to do here.
return
@@ -267,34 +281,35 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_receipts")
- async def start_pusher_by_id(self, app_id, pushkey, user_id):
+ async def start_pusher_by_id(
+ self, app_id: str, pushkey: str, user_id: str
+ ) -> Optional[Pusher]:
"""Look up the details for the given pusher, and start it
Returns:
- EmailPusher|HttpPusher|None: The pusher started, if any
+ The pusher started, if any
"""
if not self._should_start_pushers:
- return
+ return None
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
- return
+ return None
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
- pusher_dict = None
+ pusher_config = None
for r in resultlist:
- if r["user_name"] == user_id:
- pusher_dict = r
+ if r.user_name == user_id:
+ pusher_config = r
pusher = None
- if pusher_dict:
- pusher = await self._start_pusher(pusher_dict)
+ if pusher_config:
+ pusher = await self._start_pusher(pusher_config)
return pusher
async def _start_pushers(self) -> None:
- """Start all the pushers
- """
+ """Start all the pushers"""
pushers = await self.store.get_all_pushers()
# Stagger starting up the pushers so we don't completely drown the
@@ -303,44 +318,45 @@ class PusherPool:
logger.info("Started pushers")
- async def _start_pusher(self, pusherdict):
+ async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
"""Start the given pusher
Args:
- pusherdict (dict): dict with the values pulled from the db table
+ pusher_config: The pusher configuration with the values pulled from the db table
Returns:
- EmailPusher|HttpPusher
+ The newly created pusher or None.
"""
if not self._pusher_shard_config.should_handle(
- self._instance_name, pusherdict["user_name"]
+ self._instance_name, pusher_config.user_name
):
- return
+ return None
try:
- p = self.pusher_factory.create_pusher(pusherdict)
+ p = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e:
logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
- pusherdict["id"],
- pusherdict.get("user_name"),
- pusherdict.get("app_id"),
- pusherdict.get("pushkey"),
+ pusher_config.id,
+ pusher_config.user_name,
+ pusher_config.app_id,
+ pusher_config.pushkey,
e,
)
- return
+ return None
except Exception:
logger.exception(
- "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+ "Couldn't start pusher id %i: caught Exception",
+ pusher_config.id,
)
- return
+ return None
if not p:
- return
+ return None
- appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
+ appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
- byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+ byuser = self.pushers.setdefault(pusher_config.user_name, {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
@@ -350,8 +366,8 @@ class PusherPool:
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
- user_id = pusherdict["user_name"]
- last_stream_ordering = pusherdict["last_stream_ordering"]
+ user_id = pusher_config.user_name
+ last_stream_ordering = pusher_config.last_stream_ordering
if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
@@ -365,7 +381,7 @@ class PusherPool:
return p
- async def remove_pusher(self, app_id, pushkey, user_id):
+ async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
@@ -377,6 +393,12 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
- await self.store.delete_pusher_by_app_id_pushkey_user_id(
- app_id, pushkey, user_id
- )
+ # We can only delete pushers on master.
+ if self._remove_pusher_client:
+ await self._remove_pusher_client(
+ app_id=app_id, pushkey=pushkey, user_id=user_id
+ )
+ else:
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id, pushkey, user_id
+ )
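The PusherConfig object that replaces the raw pusher dicts above is defined in synapse/push/__init__.py, which is not part of this excerpt. A rough attrs-style sketch of what such a config class might look like, with the fields assumed from the attributes used in pusherpool.py:

# Illustrative sketch only - not the real PusherConfig. Field names are taken
# from the attribute accesses in the diff above (pusher_config.app_id, etc.).
from typing import Optional

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class PusherConfigSketch:
    id: Optional[int]
    user_name: str
    access_token: Optional[int]
    profile_tag: str
    kind: str
    app_id: str
    app_display_name: str
    device_display_name: str
    pushkey: str
    ts: int
    lang: Optional[str]
    data: Optional[dict]
    last_stream_ordering: int
    last_success: Optional[int]
    failing_since: Optional[int]

# Attribute access replaces the old dict lookups, so a typo such as
# config.app_idd fails loudly instead of silently returning None.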
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 8e5d3e3c52..2a1c925ee8 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
from typing import List, Set
@@ -82,19 +83,26 @@ REQUIREMENTS = [
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
+ # We enforce that we have a `cryptography` version that bundles an `openssl`
+ # with the latest security patches.
+ "cryptography>=3.4.7;python_version>='3.6'",
]
CONDITIONAL_REQUIREMENTS = {
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
- # we use execute_batch, which arrived in psycopg 2.7.
- "postgres": ["psycopg2>=2.7"],
+ "postgres": [
+ # we use execute_values with the fetch param, which arrived in psycopg 2.8.
+ "psycopg2>=2.8 ; platform_python_implementation != 'PyPy'",
+ "psycopg2cffi>=2.8 ; platform_python_implementation == 'PyPy'",
+ "psycopg2cffi-compat==1.1 ; platform_python_implementation == 'PyPy'",
+ ],
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": [
"txacme>=0.9.2",
# txacme depends on eliot. Eliot 1.8.0 is incompatible with
# python 3.5.2, as per https://github.com/itamarst/eliot/issues/418
- 'eliot<1.8.0;python_version<"3.5.3"',
+ "eliot<1.8.0;python_version<'3.5.3'",
],
"saml2": [
# pysaml2 6.4.0 is incompatible with Python 3.5 (see https://github.com/IdentityPython/pysaml2/issues/749)
@@ -102,6 +110,9 @@ CONDITIONAL_REQUIREMENTS = {
"pysaml2>=4.5.0;python_version>='3.6'",
],
"oidc": ["authlib>=0.14.0"],
+ # systemd-python is necessary for logging to the systemd journal via
+ # `systemd.journal.JournalHandler`, as is documented in
+ # `contrib/systemd/log_config.yaml`.
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
"sentry": ["sentry-sdk>=0.7.2"],
@@ -112,8 +123,6 @@ CONDITIONAL_REQUIREMENTS = {
"redis": ["txredisapi>=1.4.7", "hiredis"],
}
-CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
-
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
@@ -123,6 +132,18 @@ for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
+# ensure there are no double-quote characters in any of the deps (otherwise the
+# 'pip install' incantation in DependencyException will break)
+for dep in itertools.chain(
+ REQUIREMENTS,
+ *CONDITIONAL_REQUIREMENTS.values(),
+):
+ if '"' in dep:
+ raise Exception(
+ "Dependency `%s` contains double-quote; use single-quotes instead" % (dep,)
+ )
+
+
def list_requirements():
return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS)
@@ -142,7 +163,7 @@ class DependencyException(Exception):
@property
def dependencies(self):
for i in self.args[0]:
- yield "'" + i + "'"
+ yield '"' + i + '"'
def check_requirements(for_feature=None):
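The import-time double-quote check above exists because DependencyException.dependencies now wraps each missing dependency in double quotes, so the resulting pip install hint can be pasted straight into a shell. A small standalone sketch of that quoting (illustrative only):

# Why nested double quotes would break the hint: each missing dep is wrapped
# in double quotes for shell safety, so the dep string itself must not contain one.
missing = ["psycopg2>=2.8 ; platform_python_implementation != 'PyPy'"]

hint = "pip install " + " ".join('"' + dep + '"' for dep in missing)
print(hint)
# pip install "psycopg2>=2.8 ; platform_python_implementation != 'PyPy'"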
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a84a064c8d..cb4a52dbe9 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -15,11 +15,13 @@
from synapse.http.server import JsonResource
from synapse.replication.http import (
+ account_data,
devices,
federation,
login,
membership,
presence,
+ push,
register,
send_event,
streams,
@@ -40,6 +42,8 @@ class ReplicationRestResource(JsonResource):
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
streams.register_servlets(hs, self)
+ account_data.register_servlets(hs, self)
+ push.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2b3972cb14..b7aa0c280f 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -18,7 +18,7 @@ import logging
import re
import urllib
from inspect import signature
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
_pending_outgoing_requests = Gauge(
@@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
CACHE = True
RETRY_ON_TIMEOUT = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache = ResponseCache(
- hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
+ hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we
@@ -106,6 +109,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
assert self.METHOD in ("PUT", "POST", "GET")
+ self._replication_secret = None
+ if hs.config.worker.worker_replication_secret:
+ self._replication_secret = hs.config.worker.worker_replication_secret
+
+ def _check_auth(self, request) -> None:
+ # Get the authorization header.
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if not auth_headers:
+            raise RuntimeError("Missing Authorization header.")
+        if len(auth_headers) > 1:
+            raise RuntimeError("Too many Authorization headers.")
+ parts = auth_headers[0].split(b" ")
+ if parts[0] == b"Bearer" and len(parts) == 2:
+ received_secret = parts[1].decode("ascii")
+ if self._replication_secret == received_secret:
+ # Success!
+ return
+
+ raise RuntimeError("Invalid Authorization header.")
+
@abc.abstractmethod
async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
@@ -150,9 +172,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
+ replication_secret = None
+ if hs.config.worker.worker_replication_secret:
+ replication_secret = hs.config.worker.worker_replication_secret.encode(
+ "ascii"
+ )
+
@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
- async def send_request(instance_name="master", **kwargs):
+ async def send_request(*, instance_name="master", **kwargs):
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
@@ -202,6 +230,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# the master, and so whether we should clean up or not.
while True:
headers = {} # type: Dict[bytes, List[bytes]]
+ # Add an authorization header, if configured.
+ if replication_secret:
+ headers[b"Authorization"] = [b"Bearer " + replication_secret]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
@@ -236,21 +267,22 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""
url_args = list(self.PATH_ARGS)
- handler = self._handle_request
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(
- method, [pattern], handler, self.__class__.__name__,
+ method,
+ [pattern],
+ self._check_auth_and_handle,
+ self.__class__.__name__,
)
- def _cached_handler(self, request, txn_id, **kwargs):
+ def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@@ -258,6 +290,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well.
- assert self.CACHE
+ # Check the authorization headers before handling the request.
+ if self._replication_secret:
+ self._check_auth(request)
+
+ if self.CACHE:
+ txn_id = kwargs.pop("txn_id")
+
+ return self.response_cache.wrap(
+ txn_id, self._handle_request, request, **kwargs
+ )
- return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
+ return self._handle_request(request, **kwargs)
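The replication secret introduced above is carried as a standard bearer token: the outgoing request helper adds "Authorization: Bearer <secret>" and _check_auth compares it with the configured value. A self-contained sketch of that round trip (illustrative names, not Synapse APIs):

# Sketch of the header round-trip performed by the code above.
from typing import Dict, List, Optional


def build_headers(replication_secret: Optional[bytes]) -> Dict[bytes, List[bytes]]:
    headers: Dict[bytes, List[bytes]] = {}
    if replication_secret:
        headers[b"Authorization"] = [b"Bearer " + replication_secret]
    return headers


def check_auth(headers: Dict[bytes, List[bytes]], expected_secret: str) -> None:
    auth_headers = headers.get(b"Authorization")
    if not auth_headers:
        raise RuntimeError("Missing Authorization header.")
    if len(auth_headers) > 1:
        raise RuntimeError("Too many Authorization headers.")
    parts = auth_headers[0].split(b" ")
    if parts[0] == b"Bearer" and len(parts) == 2:
        if parts[1].decode("ascii") == expected_secret:
            return
    raise RuntimeError("Invalid Authorization header.")


check_auth(build_headers(b"s3kr1t"), "s3kr1t")  # passes silently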
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
new file mode 100644
index 0000000000..60899b6ad6
--- /dev/null
+++ b/synapse/replication/http/account_data.py
@@ -0,0 +1,191 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+ """Add user account data on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+ {
+ "content": { ... },
+ }
+
+ """
+
+ NAME = "add_user_account_data"
+ PATH_ARGS = ("user_id", "account_data_type")
+ CACHE = False
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_id, account_data_type, content):
+ payload = {
+ "content": content,
+ }
+
+ return payload
+
+ async def _handle_request(self, request, user_id, account_data_type):
+ content = parse_json_object_from_request(request)
+
+ max_stream_id = await self.handler.add_account_data_for_user(
+ user_id, account_data_type, content["content"]
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+ """Add room account data on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+ {
+ "content": { ... },
+ }
+
+ """
+
+ NAME = "add_room_account_data"
+ PATH_ARGS = ("user_id", "room_id", "account_data_type")
+ CACHE = False
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_id, room_id, account_data_type, content):
+ payload = {
+ "content": content,
+ }
+
+ return payload
+
+ async def _handle_request(self, request, user_id, room_id, account_data_type):
+ content = parse_json_object_from_request(request)
+
+ max_stream_id = await self.handler.add_account_data_to_room(
+ user_id, room_id, account_data_type, content["content"]
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+ """Add tag on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+ {
+ "content": { ... },
+ }
+
+ """
+
+ NAME = "add_tag"
+ PATH_ARGS = ("user_id", "room_id", "tag")
+ CACHE = False
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_id, room_id, tag, content):
+ payload = {
+ "content": content,
+ }
+
+ return payload
+
+ async def _handle_request(self, request, user_id, room_id, tag):
+ content = parse_json_object_from_request(request)
+
+ max_stream_id = await self.handler.add_tag_to_room(
+ user_id, room_id, tag, content["content"]
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+ """Remove tag on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+ {}
+
+ """
+
+ NAME = "remove_tag"
+ PATH_ARGS = (
+ "user_id",
+ "room_id",
+ "tag",
+ )
+ CACHE = False
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_id, room_id, tag):
+
+ return {}
+
+ async def _handle_request(self, request, user_id, room_id, tag):
+ max_stream_id = await self.handler.remove_tag_from_room(
+ user_id,
+ room_id,
+ tag,
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+ ReplicationUserAccountDataRestServlet(hs).register(http_server)
+ ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+ ReplicationAddTagRestServlet(hs).register(http_server)
+ ReplicationRemoveTagRestServlet(hs).register(http_server)
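Workers are expected to reach these servlets through ReplicationEndpoint.make_client, the same mechanism the pusher code above uses for ReplicationRemovePusherRestServlet. A rough sketch of how a worker-side handler might proxy a user account data write (the handler class is illustrative, not the actual AccountDataHandler):

# Illustrative sketch only: assumes a HomeServer `hs` and the servlet defined
# above. make_client() returns a coroutine whose keyword arguments match the
# servlet's _serialize_payload signature.
class AccountDataHandlerSketch:
    def __init__(self, hs):
        self._add_user_account_data = ReplicationUserAccountDataRestServlet.make_client(hs)

    async def add_account_data_for_user(self, user_id, account_data_type, content):
        # Proxies the write to the account data writer; the response body is
        # whatever _handle_request returned, i.e. {"max_stream_id": ...}.
        response = await self._add_user_account_data(
            user_id=user_id,
            account_data_type=account_data_type,
            content=content,
        )
        return response["max_stream_id"]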
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 7a0dbb5b1a..82ea3b895f 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -40,6 +40,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
+ "outlier": true|false,
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
}],
@@ -84,6 +85,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
+ "outlier": event.internal_metadata.is_outlier(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
}
@@ -116,6 +118,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event = make_event_from_dict(
event_dict, room_ver, internal_metadata, rejected_reason
)
+ event.internal_metadata.outlier = event_payload["outlier"]
context = EventContext.deserialize(
self.storage, event_payload["context"]
@@ -213,8 +216,9 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
content = parse_json_object_from_request(request)
args = content["args"]
+ args["origin"] = content["origin"]
- logger.info("Got %r query", query_type)
+ logger.info("Got %r query from %s", query_type, args["origin"])
result = await self.registry.on_query(query_type, args)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 4c81e2d784..4ec1bfa6ea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+ async def _serialize_payload(
+ user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
+ ):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
@@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"device_id": device_id,
"initial_display_name": initial_display_name,
"is_guest": is_guest,
+ "is_appservice_ghost": is_appservice_ghost,
}
async def _handle_request(self, request, user_id):
@@ -56,12 +59,17 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
+ is_appservice_ghost = content["is_appservice_ghost"]
- device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest
+ res = await self.registration_handler.register_device_inner(
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest,
+ is_appservice_ghost=is_appservice_ghost,
)
- return 200, {"device_id": device_id, "access_token": access_token}
+ return 200, res
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 84afc4c1e4..2812ac12fc 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -15,9 +15,10 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_left_room
@@ -78,7 +79,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore
- self, request: Request, room_id: str, user_id: str
+ self, request: SynapseRequest, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
@@ -86,7 +87,6 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
-
request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id)
@@ -145,7 +145,10 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore
- self, request: Request, room_id: str, user_id: str,
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ user_id: str,
):
content = parse_json_object_from_request(request)
@@ -214,7 +217,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore
- self, request: Request, invite_event_id: str
+ self, request: SynapseRequest, invite_event_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
@@ -222,12 +225,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
-
request.requester = requester
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite(
- invite_event_id, txn_id, requester, event_content,
+ invite_event_id,
+ txn_id,
+ requester,
+ event_content,
)
return 200, {"event_id": event_id, "stream_id": stream_id}
@@ -278,7 +283,9 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore
- self, request: Request, knock_event_id: str,
+ self,
+ request: SynapseRequest,
+ knock_event_id: str,
):
content = parse_json_object_from_request(request)
@@ -291,7 +298,10 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_rescind_knock(
- knock_event_id, txn_id, requester, event_content,
+ knock_event_id,
+ txn_id,
+ requester,
+ event_content,
)
return 200, {"event_id": event_id, "stream_id": stream_id}
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
new file mode 100644
index 0000000000..054ed64d34
--- /dev/null
+++ b/synapse/replication/http/push.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
+ """Deletes the given pusher.
+
+ Request format:
+
+ POST /_synapse/replication/remove_pusher/:user_id
+
+ {
+ "app_id": "<some_id>",
+ "pushkey": "<some_key>"
+ }
+
+ """
+
+ NAME = "add_user_account_data"
+ PATH_ARGS = ("user_id",)
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.pusher_pool = hs.get_pusherpool()
+
+ @staticmethod
+ async def _serialize_payload(app_id, pushkey, user_id):
+ payload = {
+ "app_id": app_id,
+ "pushkey": pushkey,
+ }
+
+ return payload
+
+ async def _handle_request(self, request, user_id):
+ content = parse_json_object_from_request(request)
+
+ app_id = content["app_id"]
+ pushkey = content["pushkey"]
+
+ await self.pusher_pool.remove_pusher(app_id, pushkey, user_id)
+
+ return 200, {}
+
+
+def register_servlets(hs, http_server):
+ ReplicationRemovePusherRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 7b12ec9060..d005f38767 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
class ReplicationRegisterServlet(ReplicationEndpoint):
- """Register a new user
- """
+ """Register a new user"""
NAME = "register_user"
PATH_ARGS = ("user_id",)
@@ -97,8 +96,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
- """Run any post registration actions
- """
+ """Run any post registration actions"""
NAME = "post_register"
PATH_ARGS = ("user_id",)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 8fa104c8d3..a4c5b44292 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -40,6 +40,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
+ "outlier": true|false,
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
"requester": { .. serialized requester .. },
@@ -79,7 +80,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
"""
-
serialized_context = await context.serialize(event, store)
payload = {
@@ -87,6 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
+ "outlier": event.internal_metadata.is_outlier(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
"requester": requester.serialize(),
@@ -108,6 +109,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event = make_event_from_dict(
event_dict, room_ver, internal_metadata, rejected_reason
)
+ event.internal_metadata.outlier = content["outlier"]
requester = Requester.deserialize(self.store, content["requester"])
context = EventContext.deserialize(self.storage, content["context"])
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d0089fe06c..693c9ab901 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
database,
stream_name="caches",
instance_name=hs.get_instance_name(),
- table="cache_invalidation_stream_by_instance",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[
+ (
+ "cache_invalidation_stream_by_instance",
+ "instance_name",
+ "stream_id",
+ )
+ ],
sequence_name="cache_invalidation_stream_seq",
writers=[],
) # type: Optional[MultiWriterIdGenerator]
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d68..0d39a93ed2 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -12,21 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional, Tuple
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker:
- def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+ def __init__(
+ self,
+ db_conn: Connection,
+ table: str,
+ column: str,
+ extra_tables: Optional[List[Tuple[str, str]]] = None,
+ step: int = 1,
+ ):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
- for table, column in extra_tables:
- self.advance(None, _load_current_id(db_conn, table, column))
+ if extra_tables:
+ for table, column in extra_tables:
+ self.advance(None, _load_current_id(db_conn, table, column))
- def advance(self, instance_name, new_id):
+ def advance(self, instance_name: Optional[str], new_id: int):
self._current = (max if self.step > 0 else min)(self._current, new_id)
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""
Returns:
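A short usage sketch of the now-typed tracker; the table names mirror the pushers store later in this diff, and `db_conn` stands for an open database connection:

```python
# Sketch: extra_tables is now Optional and may be omitted entirely; when given,
# the tracker takes the maximum stream id across all listed tables.
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.types import Connection


def load_pushers_token(db_conn: Connection) -> int:
    tracker = SlavedIdTracker(
        db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
    )
    return tracker.get_current_token()
```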
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 4268565fc8..21afe5f155 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- self._account_data_id_gen = SlavedIdTracker(
- db_conn,
- "account_data",
- "stream_id",
- extra_tables=[
- ("room_account_data", "stream_id"),
- ("room_tags_revisions", "stream_id"),
- ],
- )
-
- super().__init__(database, db_conn, hs)
-
- def get_max_account_data_stream_id(self):
- return self._account_data_id_gen.get_current_token()
-
- def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- for row in rows:
- self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- elif stream_name == AccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- for row in rows:
- if not row.room_id:
- self.get_global_account_data_by_type_for_user.invalidate(
- (row.data_type, row.user_id)
- )
- self.get_account_data_for_user.invalidate((row.user_id,))
- self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
- self.get_account_data_for_room_and_type.invalidate(
- (row.user_id, row.room_id, row.data_type)
- )
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+ pass
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 5b045bed02..1260f6d141 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -14,46 +14,8 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
- self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_inbox", "stream_id"
- )
- self._device_inbox_stream_cache = StreamChangeCache(
- "DeviceInboxStreamChangeCache",
- self._device_inbox_id_gen.get_current_token(),
- )
- self._device_federation_outbox_stream_cache = StreamChangeCache(
- "DeviceFederationOutboxStreamChangeCache",
- self._device_inbox_id_gen.get_current_token(),
- )
-
- self._last_device_delete_cache = ExpiringCache(
- cache_name="last_device_delete_cache",
- clock=self._clock,
- max_len=10000,
- expiry_ms=30 * 60 * 1000,
- )
-
- def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == ToDeviceStream.NAME:
- self._device_inbox_id_gen.advance(instance_name, token)
- for row in rows:
- if row.entity.startswith("@"):
- self._device_inbox_stream_cache.entity_has_changed(
- row.entity, token
- )
- else:
- self._device_federation_outbox_stream_cache.entity_has_changed(
- row.entity, token
- )
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+ pass
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..93161c3dfb 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- self._pushers_id_gen = SlavedIdTracker(
+ self._pushers_id_gen = SlavedIdTracker( # type: ignore
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token, rows
+ ) -> None:
if stream_name == PushersStream.NAME:
- self._pushers_id_gen.advance(instance_name, token)
+ self._pushers_id_gen.advance(instance_name, token) # type: ignore
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6195917376..3dfdd9961d 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,43 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- # We instantiate this first as the ReceiptsWorkerStore constructor
- # needs to be able to call get_max_receipt_stream_id
- self._receipts_id_gen = SlavedIdTracker(
- db_conn, "receipts_linearized", "stream_id"
- )
-
- super().__init__(database, db_conn, hs)
-
- def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_current_token()
-
- def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
- self.get_receipts_for_user.invalidate((user_id, receipt_type))
- self._get_linearized_receipts_for_room.invalidate_many((room_id,))
- self.get_last_receipt_event_id_for_user.invalidate(
- (user_id, room_id, receipt_type)
- )
- self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
- self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
- def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == ReceiptsStream.NAME:
- self._receipts_id_gen.advance(instance_name, token)
- for row in rows:
- self.invalidate_caches_for_receipt(
- row.room_id, row.receipt_type, row.user_id
- )
- self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+ pass
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2618eb1e53..3455839d67 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -108,9 +108,7 @@ class ReplicationDataHandler:
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
- self._streams_to_waiters = (
- {}
- ) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+ self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]]
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ac532ed588..8abed1f52d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -196,8 +196,7 @@ class ErrorCommand(_SimpleCommand):
class PingCommand(_SimpleCommand):
- """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
- """
+ """Sent by either side as a keep alive. The data is arbitrary (often timestamp)"""
NAME = "PING"
@@ -313,44 +312,19 @@ class FederationAckCommand(Command):
NAME = "FEDERATION_ACK"
- def __init__(self, instance_name, token):
+ def __init__(self, instance_name: str, token: int):
self.instance_name = instance_name
self.token = token
@classmethod
- def from_line(cls, line):
+ def from_line(cls, line: str) -> "FederationAckCommand":
instance_name, token = line.split(" ")
return cls(instance_name, int(token))
- def to_line(self):
+ def to_line(self) -> str:
return "%s %s" % (self.instance_name, self.token)
-class RemovePusherCommand(Command):
- """Sent by the client to request the master remove the given pusher.
-
- Format::
-
- REMOVE_PUSHER <app_id> <push_key> <user_id>
- """
-
- NAME = "REMOVE_PUSHER"
-
- def __init__(self, app_id, push_key, user_id):
- self.user_id = user_id
- self.app_id = app_id
- self.push_key = push_key
-
- @classmethod
- def from_line(cls, line):
- app_id, push_key, user_id = line.split(" ", 2)
-
- return cls(app_id, push_key, user_id)
-
- def to_line(self):
- return " ".join((self.app_id, self.push_key, self.user_id))
-
-
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -417,7 +391,6 @@ _COMMANDS = (
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
- RemovePusherCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@@ -444,7 +417,6 @@ VALID_CLIENT_COMMANDS = (
UserSyncCommand.NAME,
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
- RemovePusherCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
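With REMOVE_PUSHER gone, pusher removal travels over the HTTP replication endpoint added earlier in this diff. For the commands that remain, the newly annotated `from_line`/`to_line` pair round-trips the argument portion of the wire line (the command NAME is prepended separately); a small sketch with made-up values:

```python
from synapse.replication.tcp.commands import FederationAckCommand

cmd = FederationAckCommand("worker1", 1234)
assert cmd.to_line() == "worker1 1234"  # args only; the wire line is "FEDERATION_ACK worker1 1234"

parsed = FederationAckCommand.from_line("worker1 1234")
assert parsed.instance_name == "worker1" and parsed.token == 1234
```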
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..d89a36f25a
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+set_counter = Counter(
+ "synapse_external_cache_set",
+ "Number of times we set a cache",
+ labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+ "synapse_external_cache_get",
+ "Number of times we get a cache",
+ labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+ """A cache backed by an external Redis. Does nothing if no Redis is
+ configured.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._redis_connection = hs.get_outbound_redis_connection()
+
+ def _get_redis_key(self, cache_name: str, key: str) -> str:
+ return "cache_v1:%s:%s" % (cache_name, key)
+
+ def is_enabled(self) -> bool:
+ """Whether the external cache is used or not.
+
+ It's safe to use the cache when this returns false; the methods will
+ just no-op, but checking first avoids doing unnecessary work.
+ """
+ return self._redis_connection is not None
+
+ async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+ """Add the key/value to the named cache, with the expiry time given."""
+
+ if self._redis_connection is None:
+ return
+
+ set_counter.labels(cache_name).inc()
+
+ # txredisapi requires the value to be string, bytes or numbers, so we
+ # encode stuff in JSON.
+ encoded_value = json_encoder.encode(value)
+
+ logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+ return await make_deferred_yieldable(
+ self._redis_connection.set(
+ self._get_redis_key(cache_name, key),
+ encoded_value,
+ pexpire=expiry_ms,
+ )
+ )
+
+ async def get(self, cache_name: str, key: str) -> Optional[Any]:
+ """Look up a key/value in the named cache."""
+
+ if self._redis_connection is None:
+ return None
+
+ result = await make_deferred_yieldable(
+ self._redis_connection.get(self._get_redis_key(cache_name, key))
+ )
+
+ logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+ get_counter.labels(cache_name, result is not None).inc()
+
+ if not result:
+ return None
+
+ # txredisapi's number conversion can hand the value back as an int already,
+ # in which case there is nothing left to JSON-decode.
+ if isinstance(result, int):
+ return result
+
+ return json_decoder.decode(result)
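A hedged usage sketch of the new cache: the `hs.get_external_cache()` accessor and the cache/key names below are assumptions for illustration only.

```python
# Store and retrieve a computed value via Redis so other workers can reuse it.
async def share_result(hs, state_group: int, joined_hosts: list) -> None:
    cache = hs.get_external_cache()
    if not cache.is_enabled():
        # No Redis configured: set/get would just no-op, but we can skip the work.
        return
    await cache.set(
        "get_joined_hosts", str(state_group), joined_hosts, expiry_ms=60 * 1000
    )


async def fetch_result(hs, state_group: int):
    cache = hs.get_external_cache()
    # Returns None on a miss or when Redis is not configured.
    return await cache.get("get_joined_hosts", str(state_group))
```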
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 95e5502bf2..a8894beadf 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Dict,
@@ -43,22 +44,28 @@ from synapse.replication.tcp.commands import (
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
- RemovePusherCommand,
ReplicateCommand,
UserIpCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
+ AccountDataStream,
BackfillStream,
CachesStream,
EventsStream,
FederationStream,
+ ReceiptsStream,
Stream,
+ TagAccountDataStream,
+ ToDeviceStream,
TypingStream,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -75,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
- Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+ Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]
@@ -84,7 +91,7 @@ class ReplicationCommandHandler:
back out to connections.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
@@ -115,6 +122,14 @@ class ReplicationCommandHandler:
continue
+ if isinstance(stream, ToDeviceStream):
+ # Only add ToDeviceStream as a source on instances in charge of
+ # sending to device messages.
+ if hs.get_instance_name() in hs.config.worker.writers.to_device:
+ self._streams_to_replicate.append(stream)
+
+ continue
+
if isinstance(stream, TypingStream):
# Only add TypingStream as a source on the instance in charge of
# typing.
@@ -123,6 +138,22 @@ class ReplicationCommandHandler:
continue
+ if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+ # Only add AccountDataStream and TagAccountDataStream as a source on the
+ # instance in charge of account_data persistence.
+ if hs.get_instance_name() in hs.config.worker.writers.account_data:
+ self._streams_to_replicate.append(stream)
+
+ continue
+
+ if isinstance(stream, ReceiptsStream):
+ # Only add ReceiptsStream as a source on the instance in charge of
+ # receipts.
+ if hs.get_instance_name() in hs.config.worker.writers.receipts:
+ self._streams_to_replicate.append(stream)
+
+ continue
+
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue
@@ -143,7 +174,7 @@ class ReplicationCommandHandler:
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
- self._connections = [] # type: List[AbstractConnection]
+ self._connections = [] # type: List[IReplicationConnection]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@@ -166,7 +197,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+ self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
@@ -189,7 +220,7 @@ class ReplicationCommandHandler:
self._server_notices_sender = hs.get_server_notices_sender()
def _add_command_to_stream_queue(
- self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@@ -236,7 +267,7 @@ class ReplicationCommandHandler:
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
- conn: AbstractConnection,
+ conn: IReplicationConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
@@ -254,13 +285,6 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
- lazyConnection,
- )
-
- logger.info(
- "Connecting to redis (host=%r port=%r)",
- hs.config.redis_host,
- hs.config.redis_port,
)
# First let's ensure that we have a ReplicationStreamer started.
@@ -271,42 +295,36 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = lazyConnection(
- reactor=hs.get_reactor(),
- host=hs.config.redis_host,
- port=hs.config.redis_port,
- password=hs.config.redis.redis_password,
- reconnect=True,
- )
+ outbound_redis_connection = hs.get_outbound_redis_connection()
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
- hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
+ hs.config.redis.redis_host.encode(),
+ hs.config.redis.redis_port,
+ self._factory,
)
else:
client_name = hs.get_instance_name()
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- hs.get_reactor().connectTCP(host, port, self._factory)
+ hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
- """Get a map from stream name to all streams.
- """
+ """Get a map from stream name to all streams."""
return self._streams
def get_streams_to_replicate(self) -> List[Stream]:
- """Get a list of streams that this instances replicates.
- """
+ """Get a list of streams that this instances replicates."""
return self._streams_to_replicate
- def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
- def send_positions_to_connection(self, conn: AbstractConnection):
+ def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
@@ -321,12 +339,15 @@ class ReplicationCommandHandler:
current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(
- stream.NAME, self._instance_name, current_token, current_token,
+ stream.NAME,
+ self._instance_name,
+ current_token,
+ current_token,
)
)
def on_USER_SYNC(
- self, conn: AbstractConnection, cmd: UserSyncCommand
+ self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
@@ -338,38 +359,23 @@ class ReplicationCommandHandler:
return None
def on_CLEAR_USER_SYNC(
- self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]:
if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
- def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+ def on_FEDERATION_ACK(
+ self, conn: IReplicationConnection, cmd: FederationAckCommand
+ ):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
- def on_REMOVE_PUSHER(
- self, conn: AbstractConnection, cmd: RemovePusherCommand
- ) -> Optional[Awaitable[None]]:
- remove_pusher_counter.inc()
-
- if self._is_master:
- return self._handle_remove_pusher(cmd)
- else:
- return None
-
- async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
- await self._store.delete_pusher_by_app_id_pushkey_user_id(
- app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
- )
-
- self._notifier.on_new_replication_data()
-
def on_USER_IP(
- self, conn: AbstractConnection, cmd: UserIpCommand
+ self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
@@ -391,7 +397,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
- def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -408,7 +414,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
- self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
@@ -482,7 +488,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
- def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
@@ -492,7 +498,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
- self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
@@ -549,7 +555,9 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
- def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+ def on_REMOTE_SERVER_UP(
+ self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+ ):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -572,9 +580,8 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
- def new_connection(self, connection: AbstractConnection):
- """Called when we have a new connection.
- """
+ def new_connection(self, connection: IReplicationConnection):
+ """Called when we have a new connection."""
self._connections.append(connection)
# If we are connected to replication as a client (rather than a server)
@@ -600,9 +607,8 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now)
)
- def lost_connection(self, connection: AbstractConnection):
- """Called when a connection is closed/lost.
- """
+ def lost_connection(self, connection: IReplicationConnection):
+ """Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
if streams:
@@ -622,7 +628,7 @@ class ReplicationCommandHandler:
return bool(self._connections)
def send_command(
- self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+ self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
"""Send a command to all connected connections.
@@ -659,18 +665,11 @@ class ReplicationCommandHandler:
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
- """Poke the master that a user has started/stopped syncing.
- """
+ """Poke the master that a user has started/stopped syncing."""
self.send_command(
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
)
- def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
- """Poke the master to remove a pusher for a user
- """
- cmd = RemovePusherCommand(app_id, push_key, user_id)
- self.send_command(cmd)
-
def send_user_ip(
self,
user_id: str,
@@ -680,8 +679,7 @@ class ReplicationCommandHandler:
device_id: str,
last_seen: int,
):
- """Tell the master that the user made a request.
- """
+ """Tell the master that the user made a request."""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..e829add257 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-import abc
import fcntl
import logging
import struct
@@ -54,8 +53,10 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from zope.interface import Interface, implementer
from twisted.internet import task
+from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -103,7 +104,7 @@ tcp_outbound_commands_counter = Counter(
# A list of all connected protocols. This allows us to send metrics about the
# connections.
-connected_connections = []
+connected_connections = [] # type: List[BaseReplicationStreamProtocol]
logger = logging.getLogger(__name__)
@@ -121,6 +122,14 @@ class ConnectionStates:
CLOSED = "closed"
+class IReplicationConnection(Interface):
+ """An interface for replication connections."""
+
+ def send_command(cmd: Command):
+ """Send the command down the connection"""
+
+
+@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server.
@@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""
+ # The transport is going to be an ITCPTransport, but that doesn't have the
+ # (un)registerProducer methods, those are only on the implementation.
+ transport = None # type: Connection
+
delimiter = b"\n"
# Valid commands we expect to receive
@@ -172,8 +185,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
ctx_name = "replication-conn-%s" % self.conn_id
- self._logging_context = BackgroundProcessLoggingContext(ctx_name)
- self._logging_context.request = ctx_name
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -182,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics
+ assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands()
@@ -206,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
+ assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@@ -223,8 +237,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.send_error("ping timeout")
def lineReceived(self, line: bytes):
- """Called when we've received a line
- """
+ """Called when we've received a line"""
with PreserveLoggingContext(self._logging_context):
self._parse_and_dispatch_line(line)
@@ -296,12 +309,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
+ assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
def send_error(self, error_string, *args):
- """Send an error to remote and close the connection.
- """
+ """Send an error to remote and close the connection."""
self.send_command(ErrorCommand(error_string % args))
self.close()
@@ -342,8 +355,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_sent_command = self.clock.time_msec()
def _queue_command(self, cmd):
- """Queue the command until the connection is ready to write to again.
- """
+ """Queue the command until the connection is ready to write to again."""
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
@@ -356,8 +368,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.close()
def _send_pending_commands(self):
- """Send any queued commandes
- """
+ """Send any queued commandes"""
pending = self.pending_commands
self.pending_commands = []
for cmd in pending:
@@ -381,8 +392,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.PAUSED
def resumeProducing(self):
- """The remote has caught up after we started buffering!
- """
+ """The remote has caught up after we started buffering!"""
logger.info("[%s] Resume producing", self.id())
self.state = ConnectionStates.ESTABLISHED
self._send_pending_commands()
@@ -397,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
+ assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
@@ -441,8 +452,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return "%s-%s" % (self.name, self.conn_id)
def lineLengthExceeded(self, line):
- """Called when we receive a line that is above the maximum line length
- """
+ """Called when we receive a line that is above the maximum line length"""
self.send_error("Line length exceeded")
@@ -496,29 +506,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_error("Wrong remote")
def replicate(self):
- """Send the subscription request to the server
- """
+ """Send the subscription request to the server"""
logger.info("[%s] Subscribing to replication streams", self.id())
self.send_command(ReplicateCommand())
-class AbstractConnection(abc.ABC):
- """An interface for replication connections.
- """
-
- @abc.abstractmethod
- def send_command(self, cmd: Command):
- """Send the command down the connection
- """
- pass
-
-
-# This tells python that `BaseReplicationStreamProtocol` implements the
-# interface.
-AbstractConnection.register(BaseReplicationStreamProtocol)
-
-
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
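The abc-based AbstractConnection is replaced here by a zope.interface declaration, matching the interface style Twisted itself uses. A standalone sketch (not Synapse code) of the pattern: interface methods are declared without `self`, implementations are marked with `@implementer`, and conformance can be checked explicitly.

```python
from zope.interface import Interface, implementer
from zope.interface.verify import verifyObject


class IGreeter(Interface):
    def greet(name: str) -> str:
        """Return a greeting for `name`."""


@implementer(IGreeter)
class Greeter:
    def greet(self, name: str) -> str:
        return "hello %s" % (name,)


verifyObject(IGreeter, Greeter())  # raises if the implementation is incomplete
```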
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..2f4d407f94 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,14 +15,21 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
+import attr
import txredisapi
+from zope.interface import implementer
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import IAddress, IConnector
+from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
+ wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
Command,
@@ -30,7 +37,7 @@ from synapse.replication.tcp.commands import (
parse_command_from_line,
)
from synapse.replication.tcp.protocol import (
- AbstractConnection,
+ IReplicationConnection,
tcp_inbound_commands_counter,
tcp_outbound_commands_counter,
)
@@ -41,8 +48,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+T = TypeVar("T")
+V = TypeVar("V")
+
+
+@attr.s
+class ConstantProperty(Generic[T, V]):
+ """A descriptor that returns the given constant, ignoring attempts to set
+ it.
+ """
+
+ constant = attr.ib() # type: V
-class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+ def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V:
+ return self.constant
+
+ def __set__(self, obj: Optional[T], value: V):
+ pass
+
+
+@implementer(IReplicationConnection)
+class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream.
This class fulfils two functions:
@@ -51,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
- (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+ (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
@@ -59,16 +85,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation.
Attributes:
- handler: The command handler to handle incoming commands.
- stream_name: The *redis* stream name to subscribe to and publish from
- (not anything to do with Synapse replication streams).
- outbound_redis_connection: The connection to redis to use to send
+ synapse_handler: The command handler to handle incoming commands.
+ synapse_stream_name: The *redis* stream name to subscribe to and publish
+ from (not anything to do with Synapse replication streams).
+ synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
- handler = None # type: ReplicationCommandHandler
- stream_name = None # type: str
- outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ synapse_handler = None # type: ReplicationCommandHandler
+ synapse_stream_name = None # type: str
+ synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -88,23 +114,22 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
- logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
- await make_deferred_yieldable(self.subscribe(self.stream_name))
+ logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+ await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
- self.handler.new_connection(self)
+ self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# otherside won't know we've connected and so won't issue a REPLICATE.
- self.handler.send_positions_to_connection(self)
+ self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
- """Received a message from redis.
- """
+ """Received a message from redis."""
with PreserveLoggingContext(self._logging_context):
self._parse_and_dispatch_message(message)
@@ -117,7 +142,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd = parse_command_from_line(message)
except Exception:
logger.exception(
- "Failed to parse replication line: %r", message,
+ "Failed to parse replication line: %r",
+ message,
)
return
@@ -137,7 +163,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd: received command
"""
- cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+ cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
@@ -155,7 +181,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
- self.handler.lost_connection(self)
+ self.synapse_handler.lost_connection(self)
# mark the logging context as finished
self._logging_context.__exit__(None, None, None)
@@ -183,11 +209,89 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
await make_deferred_yieldable(
- self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+ self.synapse_outbound_redis_connection.publish(
+ self.synapse_stream_name, encoded_string
+ )
+ )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+ """A subclass of RedisFactory that periodically sends pings to ensure that
+ we detect dead connections.
+ """
+
+ # We want to *always* retry connecting, txredisapi will stop if there is a
+ # failure during certain operations, e.g. during AUTH.
+ continueTrying = cast(bool, ConstantProperty(True))
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ uuid: str,
+ dbid: Optional[int],
+ poolsize: int,
+ isLazy: bool = False,
+ handler: Type = txredisapi.ConnectionHandler,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ replyTimeout: int = 30,
+ convertNumbers: Optional[int] = True,
+ ):
+ super().__init__(
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=poolsize,
+ isLazy=isLazy,
+ handler=handler,
+ charset=charset,
+ password=password,
+ replyTimeout=replyTimeout,
+ convertNumbers=convertNumbers,
)
+ hs.get_clock().looping_call(self._send_ping, 30 * 1000)
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+ @wrap_as_background_process("redis_ping")
+ async def _send_ping(self):
+ for connection in self.pool:
+ try:
+ await make_deferred_yieldable(connection.ping())
+ except Exception:
+ logger.warning("Failed to send ping to a redis connection")
+
+ # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
+ # it's rubbish. We add our own here.
+
+ def startedConnecting(self, connector: IConnector):
+ logger.info(
+ "Connecting to redis server %s", format_address(connector.getDestination())
+ )
+ super().startedConnecting(connector)
+
+ def clientConnectionFailed(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s failed: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s lost: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionLost(connector, reason)
+
+
+def format_address(address: IAddress) -> str:
+ if isinstance(address, (IPv4Address, IPv6Address)):
+ return "%s:%i" % (address.host, address.port)
+ return str(address)
+
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
@@ -199,72 +303,68 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
"""
maxDelay = 5
- continueTrying = True
protocol = RedisSubscriber
def __init__(
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):
- super().__init__()
-
- # This sets the password on the RedisFactory base class (as
- # SubscriberFactory constructor doesn't pass it through).
- self.password = hs.config.redis.redis_password
+ super().__init__(
+ hs,
+ uuid="subscriber",
+ dbid=None,
+ poolsize=1,
+ replyTimeout=30,
+ password=hs.config.redis.redis_password,
+ )
- self.handler = hs.get_tcp_replication()
- self.stream_name = hs.hostname
+ self.synapse_handler = hs.get_tcp_replication()
+ self.synapse_stream_name = hs.hostname
- self.outbound_redis_connection = outbound_redis_connection
+ self.synapse_outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
- p = super().buildProtocol(addr) # type: RedisSubscriber
+ p = super().buildProtocol(addr)
+ p = cast(RedisSubscriber, p)
# We do this here rather than add to the constructor of `RedisSubcriber`
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
- p.handler = self.handler
- p.outbound_redis_connection = self.outbound_redis_connection
- p.stream_name = self.stream_name
- p.password = self.password
+ p.synapse_handler = self.synapse_handler
+ p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+ p.synapse_stream_name = self.synapse_stream_name
return p
def lazyConnection(
- reactor,
+ hs: "HomeServer",
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
reconnect: bool = True,
- charset: str = "utf-8",
password: Optional[str] = None,
- connectTimeout: Optional[int] = None,
- replyTimeout: Optional[int] = None,
- convertNumbers: bool = True,
+ replyTimeout: int = 30,
) -> txredisapi.RedisProtocol:
- """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
- reactor.
+ """Creates a connection to Redis that is lazily set up and reconnects if the
+ connection is lost.
"""
- isLazy = True
- poolsize = 1
-
uuid = "%s:%d" % (host, port)
- factory = txredisapi.RedisFactory(
- uuid,
- dbid,
- poolsize,
- isLazy,
- txredisapi.ConnectionHandler,
- charset,
- password,
- replyTimeout,
- convertNumbers,
+ factory = SynapseRedisFactory(
+ hs,
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=1,
+ isLazy=True,
+ handler=txredisapi.ConnectionHandler,
+ password=password,
+ replyTimeout=replyTimeout,
)
factory.continueTrying = reconnect
- for x in range(poolsize):
- reactor.connectTCP(host, port, factory, connectTimeout)
+
+ reactor = hs.get_reactor()
+ reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
return factory.handler
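The handler changes above fetch the shared outbound connection via `hs.get_outbound_redis_connection()`; that accessor is not part of this hunk, but a plausible implementation on top of the new `lazyConnection` helper would look like this (a sketch, not the actual HomeServer wiring):

```python
from synapse.replication.tcp.redis import lazyConnection


def build_outbound_redis_connection(hs):
    # A single lazily-established, auto-reconnecting connection, shared by the
    # replication command handler and the ExternalCache.
    return lazyConnection(
        hs,
        host=hs.config.redis.redis_host,
        port=hs.config.redis.redis_port,
        password=hs.config.redis.redis_password,
        reconnect=True,
    )
```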
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 1d4ceac0f1..2018f9f29e 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -36,8 +36,7 @@ logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory):
- """Factory for new replication connections.
- """
+ """Factory for new replication connections."""
def __init__(self, hs):
self.command_handler = hs.get_tcp_replication()
@@ -181,7 +180,8 @@ class ReplicationStreamer:
raise
logger.debug(
- "Sending %d updates", len(updates),
+ "Sending %d updates",
+ len(updates),
)
if updates:
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 61b282ab2d..3dfee76743 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,7 +33,7 @@ import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -183,7 +183,10 @@ class Stream:
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
- instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+ instance_name,
+ from_token,
+ upto_token,
+ _STREAM_UPDATE_TARGET_ROW_COUNT,
)
return updates, upto_token, limited
@@ -296,20 +299,23 @@ class TypingStream(Stream):
NAME = "typing"
ROW_TYPE = TypingStreamRow
- def __init__(self, hs):
- typing_handler = hs.get_typing_handler()
-
+ def __init__(self, hs: "HomeServer"):
writer_instance = hs.config.worker.writers.typing
if writer_instance == hs.get_instance_name():
# On the writer, query the typing handler
- update_function = typing_handler.get_all_typing_updates
+ typing_writer_handler = hs.get_typing_writer_handler()
+ update_function = (
+ typing_writer_handler.get_all_typing_updates
+ ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+ current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
+ current_token_function = hs.get_typing_handler().get_current_token
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(typing_handler.get_current_token),
+ current_token_without_instance(current_token_function),
update_function,
)
@@ -339,8 +345,7 @@ class ReceiptsStream(Stream):
class PushRulesStream(Stream):
- """A user has changed their push rules
- """
+ """A user has changed their push rules"""
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
@@ -362,8 +367,7 @@ class PushRulesStream(Stream):
class PushersStream(Stream):
- """A user has added/changed/removed a pusher
- """
+ """A user has added/changed/removed a pusher"""
PushersStreamRow = namedtuple(
"PushersStreamRow",
@@ -416,8 +420,7 @@ class CachesStream(Stream):
class PublicRoomsStream(Stream):
- """The public rooms list changed
- """
+ """The public rooms list changed"""
PublicRoomsStreamRow = namedtuple(
"PublicRoomsStreamRow",
@@ -463,8 +466,7 @@ class DeviceListsStream(Stream):
class ToDeviceStream(Stream):
- """New to_device messages for a client
- """
+ """New to_device messages for a client"""
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
@@ -481,8 +483,7 @@ class ToDeviceStream(Stream):
class TagAccountDataStream(Stream):
- """Someone added/removed a tag for a room
- """
+ """Someone added/removed a tag for a room"""
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
@@ -501,18 +502,17 @@ class TagAccountDataStream(Stream):
class AccountDataStream(Stream):
- """Global or per room account data was changed
- """
+ """Global or per room account data was changed"""
AccountDataStreamRow = namedtuple(
- "AccountDataStream",
+ "AccountDataStreamRow",
("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -589,8 +589,7 @@ class GroupServerStream(Stream):
class UserSignatureStream(Stream):
- """A user has signed their own device with their user-signing key
- """
+ """A user has signed their own device with their user-signing key"""
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 86a62b71eb..fa5e37ba7b 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -113,8 +113,7 @@ TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream):
- """We received a new event, or an event went from being an outlier to not
- """
+ """We received a new event, or an event went from being an outlier to not"""
NAME = "events"
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9bcd13b009..9bb8e9e177 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple
from synapse.replication.tcp.streams._base import (
Stream,
@@ -21,6 +22,9 @@ from synapse.replication.tcp.streams._base import (
make_http_update_function,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class FederationStream(Stream):
"""Data to be sent over federation. Only available when master has federation
@@ -38,7 +42,7 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
if hs.config.worker_app is None:
# master process: get updates from the FederationRemoteSendQueue.
# (if the master is configured to send federation itself, federation_sender
@@ -48,7 +52,9 @@ class FederationStream(Stream):
current_token = current_token_without_instance(
federation_sender.get_current_token
)
- update_function = federation_sender.get_replication_rows
+ update_function = (
+ federation_sender.get_replication_rows
+ ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
elif hs.should_send_federation():
# federation sender: Query master process
@@ -69,5 +75,7 @@ class FederationStream(Stream):
return 0
@staticmethod
- async def _stub_update_function(instance_name, from_token, upto_token, limit):
+ async def _stub_update_function(
+ instance_name: str, from_token: int, upto_token: int, limit: int
+ ) -> Tuple[list, int, bool]:
return [], upto_token, False
diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html
index 6d76064d13..0aaef97df8 100644
--- a/synapse/res/templates/notif.html
+++ b/synapse/res/templates/notif.html
@@ -29,7 +29,7 @@
{{ message.body_text_html }}
{%- elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
- {%- elif message.msgtype == "m.image" %}
+ {%- elif message.msgtype == "m.image" and message.image_url %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{%- elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
diff --git a/synapse/res/templates/sso.css b/synapse/res/templates/sso.css
new file mode 100644
index 0000000000..338214f5d0
--- /dev/null
+++ b/synapse/res/templates/sso.css
@@ -0,0 +1,129 @@
+body, input, select, textarea {
+ font-family: "Inter", "Helvetica", "Arial", sans-serif;
+ font-size: 14px;
+ color: #17191C;
+}
+
+header, footer {
+ max-width: 480px;
+ width: 100%;
+ margin: 24px auto;
+ text-align: center;
+}
+
+@media screen and (min-width: 800px) {
+ header {
+ margin-top: 90px;
+ }
+}
+
+header {
+ min-height: 60px;
+}
+
+header p {
+ color: #737D8C;
+ line-height: 24px;
+}
+
+h1 {
+ font-size: 24px;
+}
+
+a {
+ color: #418DED;
+}
+
+.error_page h1 {
+ color: #FE2928;
+}
+
+h2 {
+ font-size: 14px;
+}
+
+h2 img {
+ vertical-align: middle;
+ margin-right: 8px;
+ width: 24px;
+ height: 24px;
+}
+
+label {
+ cursor: pointer;
+}
+
+main {
+ max-width: 360px;
+ width: 100%;
+ margin: 24px auto;
+}
+
+.primary-button {
+ border: none;
+ -webkit-appearance: none;
+ -moz-appearance: none;
+ appearance: none;
+ text-decoration: none;
+ padding: 12px;
+ color: white;
+ background-color: #418DED;
+ font-weight: bold;
+ display: block;
+ border-radius: 12px;
+ width: 100%;
+ box-sizing: border-box;
+ margin: 16px 0;
+ cursor: pointer;
+ text-align: center;
+}
+
+.profile {
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: center;
+ margin: 24px;
+ padding: 13px;
+ border: 1px solid #E9ECF1;
+ border-radius: 4px;
+}
+
+.profile.with-avatar {
+ margin-top: 42px; /* (36px / 2) + 24px*/
+}
+
+.profile .avatar {
+ width: 36px;
+ height: 36px;
+ border-radius: 100%;
+ display: block;
+ margin-top: -32px;
+ margin-bottom: 8px;
+}
+
+.profile .display-name {
+ font-weight: bold;
+ margin-bottom: 4px;
+ font-size: 15px;
+ line-height: 18px;
+}
+.profile .user-id {
+ color: #737D8C;
+ font-size: 12px;
+ line-height: 12px;
+}
+
+footer {
+ margin-top: 80px;
+}
+
+footer svg {
+ display: block;
+ width: 46px;
+ margin: 0px auto 12px auto;
+}
+
+footer p {
+ color: #737D8C;
+}
\ No newline at end of file
diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html
index 4eb8db9fb4..c3e4deed93 100644
--- a/synapse/res/templates/sso_account_deactivated.html
+++ b/synapse/res/templates/sso_account_deactivated.html
@@ -1,10 +1,25 @@
<!DOCTYPE html>
<html lang="en">
-<head>
- <meta charset="UTF-8">
- <title>SSO account deactivated</title>
-</head>
- <body>
- <p>This account has been deactivated.</p>
+ <head>
+ <meta charset="UTF-8">
+ <title>SSO account deactivated</title>
+ <meta name="viewport" content="width=device-width, user-scalable=no">
+ <style type="text/css">
+ {% include "sso.css" without context %}
+ </style>
+ </head>
+ <body class="error_page">
+ <header>
+ <h1>Your account has been deactivated</h1>
+ <p>
+ <strong>No account found</strong>
+ </p>
+ <p>
+ Your account might have been deactivated by the server administrator.
+ You can either try to create a new account or contact the server’s
+ administrator.
+ </p>
+ </header>
+ {% include "sso_footer.html" without context %}
</body>
</html>
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
new file mode 100644
index 0000000000..00e1dcdbb8
--- /dev/null
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -0,0 +1,188 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <title>Create your account</title>
+ <meta charset="utf-8">
+ <meta name="viewport" content="width=device-width, user-scalable=no">
+ <script type="text/javascript">
+ let wasKeyboard = false;
+ document.addEventListener("mousedown", function() { wasKeyboard = false; });
+ document.addEventListener("keydown", function() { wasKeyboard = true; });
+ document.addEventListener("focusin", function() {
+ if (wasKeyboard) {
+ document.body.classList.add("keyboard-focus");
+ } else {
+ document.body.classList.remove("keyboard-focus");
+ }
+ });
+ </script>
+ <style type="text/css">
+ {% include "sso.css" without context %}
+
+ body.keyboard-focus :focus, body.keyboard-focus .username_input:focus-within {
+ outline: 3px solid #17191C;
+ outline-offset: 4px;
+ }
+
+ .username_input {
+ display: flex;
+ border: 2px solid #418DED;
+ border-radius: 8px;
+ padding: 12px;
+ position: relative;
+ margin: 16px 0;
+ align-items: center;
+ font-size: 12px;
+ }
+
+ .username_input.invalid {
+ border-color: #FE2928;
+ }
+
+ .username_input.invalid input, .username_input.invalid label {
+ color: #FE2928;
+ }
+
+ .username_input div, .username_input input {
+ line-height: 18px;
+ font-size: 14px;
+ }
+
+ .username_input label {
+ position: absolute;
+ top: -5px;
+ left: 14px;
+ font-size: 10px;
+ line-height: 10px;
+ background: white;
+ padding: 0 2px;
+ }
+
+ .username_input input {
+ flex: 1;
+ display: block;
+ min-width: 0;
+ border: none;
+ }
+
+ /* only clear the outline if we know it will be shown on the parent div using :focus-within */
+ @supports selector(:focus-within) {
+ .username_input input {
+ outline: none !important;
+ }
+ }
+
+ .username_input div {
+ color: #8D99A5;
+ }
+
+ .idp-pick-details {
+ border: 1px solid #E9ECF1;
+ border-radius: 8px;
+ margin: 24px 0;
+ }
+
+ .idp-pick-details h2 {
+ margin: 0;
+ padding: 8px 12px;
+ }
+
+ .idp-pick-details .idp-detail {
+ border-top: 1px solid #E9ECF1;
+ padding: 12px;
+ display: block;
+ }
+ .idp-pick-details .check-row {
+ display: flex;
+ align-items: center;
+ }
+
+ .idp-pick-details .check-row .name {
+ flex: 1;
+ }
+
+ .idp-pick-details .use, .idp-pick-details .idp-value {
+ color: #737D8C;
+ }
+
+ .idp-pick-details .idp-value {
+ margin: 0;
+ margin-top: 8px;
+ }
+
+ .idp-pick-details .avatar {
+ width: 53px;
+ height: 53px;
+ border-radius: 100%;
+ display: block;
+ margin-top: 8px;
+ }
+
+ output {
+ padding: 0 14px;
+ display: block;
+ }
+
+ output.error {
+ color: #FE2928;
+ }
+ </style>
+ </head>
+ <body>
+ <header>
+ <h1>Your account is nearly ready</h1>
+ <p>Check your details before creating an account on {{ server_name }}</p>
+ </header>
+ <main>
+ <form method="post" class="form__input" id="form">
+ <div class="username_input" id="username_input">
+ <label for="field-username">Username</label>
+ <div class="prefix">@</div>
+ <input type="text" name="username" id="field-username" autofocus>
+ <div class="postfix">:{{ server_name }}</div>
+ </div>
+ <output for="username_input" id="field-username-output"></output>
+ <input type="submit" value="Continue" class="primary-button">
+ {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %}
+ <section class="idp-pick-details">
+ <h2>{% if idp.idp_icon %}<img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>{% endif %}Information from {{ idp.idp_name }}</h2>
+ {% if user_attributes.avatar_url %}
+ <label class="idp-detail idp-avatar" for="idp-avatar">
+ <div class="check-row">
+ <span class="name">Avatar</span>
+ <span class="use">Use</span>
+ <input type="checkbox" name="use_avatar" id="idp-avatar" value="true" checked>
+ </div>
+ <img src="{{ user_attributes.avatar_url }}" class="avatar" />
+ </label>
+ {% endif %}
+ {% if user_attributes.display_name %}
+ <label class="idp-detail" for="idp-displayname">
+ <div class="check-row">
+ <span class="name">Display name</span>
+ <span class="use">Use</span>
+ <input type="checkbox" name="use_display_name" id="idp-displayname" value="true" checked>
+ </div>
+ <p class="idp-value">{{ user_attributes.display_name }}</p>
+ </label>
+ {% endif %}
+ {% for email in user_attributes.emails %}
+ <label class="idp-detail" for="idp-email{{ loop.index }}">
+ <div class="check-row">
+ <span class="name">E-mail</span>
+ <span class="use">Use</span>
+ <input type="checkbox" name="use_email" id="idp-email{{ loop.index }}" value="{{ email }}" checked>
+ </div>
+ <p class="idp-value">{{ email }}</p>
+ </label>
+ {% endfor %}
+ </section>
+ {% endif %}
+ </form>
+ </main>
+ {% include "sso_footer.html" without context %}
+ <script type="text/javascript">
+ {% include "sso_auth_account_details.js" without context %}
+ </script>
+ </body>
+</html>
diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js
new file mode 100644
index 0000000000..3c45df9078
--- /dev/null
+++ b/synapse/res/templates/sso_auth_account_details.js
@@ -0,0 +1,116 @@
+const usernameField = document.getElementById("field-username");
+const usernameOutput = document.getElementById("field-username-output");
+const form = document.getElementById("form");
+
+// needed to validate on change event when no input was changed
+let needsValidation = true;
+let isValid = false;
+
+function throttle(fn, wait) {
+ let timeout;
+ const throttleFn = function() {
+ const args = Array.from(arguments);
+ if (timeout) {
+ clearTimeout(timeout);
+ }
+ timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait);
+ };
+ throttleFn.cancelQueued = function() {
+ clearTimeout(timeout);
+ };
+ return throttleFn;
+}
+
+function checkUsernameAvailable(username) {
+ let check_uri = 'check?username=' + encodeURIComponent(username);
+ return fetch(check_uri, {
+ // include the cookie
+ "credentials": "same-origin",
+ }).then(function(response) {
+ if(!response.ok) {
+ // for non-200 responses, raise the body of the response as an exception
+ return response.text().then((text) => { throw new Error(text); });
+ } else {
+ return response.json();
+ }
+ }).then(function(json) {
+ if(json.error) {
+ return {message: json.error};
+ } else if(json.available) {
+ return {available: true};
+ } else {
+ return {message: username + " is not available, please choose another."};
+ }
+ });
+}
+
+const allowedUsernameCharacters = new RegExp("^[a-z0-9\\.\\_\\-\\/\\=]+$");
+const allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
+
+function reportError(error) {
+ throttledCheckUsernameAvailable.cancelQueued();
+ usernameOutput.innerText = error;
+ usernameOutput.classList.add("error");
+ usernameField.parentElement.classList.add("invalid");
+ usernameField.focus();
+}
+
+function validateUsername(username) {
+ isValid = false;
+ needsValidation = false;
+ usernameOutput.innerText = "";
+ usernameField.parentElement.classList.remove("invalid");
+ usernameOutput.classList.remove("error");
+ if (!username) {
+ return reportError("Please provide a username");
+ }
+ if (username.length > 255) {
+ return reportError("Too long, please choose something shorter");
+ }
+ if (!allowedUsernameCharacters.test(username)) {
+ return reportError("Invalid username, please only use " + allowedCharactersString);
+ }
+ usernameOutput.innerText = "Checking if username is available …";
+ throttledCheckUsernameAvailable(username);
+}
+
+const throttledCheckUsernameAvailable = throttle(function(username) {
+ const handleError = function(err) {
+ // don't prevent form submission on error
+ usernameOutput.innerText = "";
+ isValid = true;
+ };
+ try {
+ checkUsernameAvailable(username).then(function(result) {
+ if (!result.available) {
+ reportError(result.message);
+ } else {
+ isValid = true;
+ usernameOutput.innerText = "";
+ }
+ }, handleError);
+ } catch (err) {
+ handleError(err);
+ }
+}, 500);
+
+form.addEventListener("submit", function(evt) {
+ if (needsValidation) {
+ validateUsername(usernameField.value);
+ evt.preventDefault();
+ return;
+ }
+ if (!isValid) {
+ evt.preventDefault();
+ usernameField.focus();
+ return;
+ }
+});
+usernameField.addEventListener("input", function(evt) {
+ validateUsername(usernameField.value);
+});
+usernameField.addEventListener("change", function(evt) {
+ if (needsValidation) {
+ validateUsername(usernameField.value);
+ }
+});
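
The script above talks to a relative "check?username=" endpoint and branches purely on the JSON it gets back: an "error" string, "available": true, or anything else meaning the name is taken. A hedged sketch of a response builder satisfying that contract is shown below; the function name and the taken_usernames set are hypothetical, only the JSON shapes come from the script.

    from typing import Dict, Set, Union

    def check_username_response(
        username: str, taken_usernames: Set[str]
    ) -> Dict[str, Union[bool, str]]:
        # Mirror the three cases the JavaScript handles.
        if not username:
            return {"error": "Please provide a username"}
        if username in taken_usernames:
            return {"available": False}
        return {"available": True}
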
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
new file mode 100644
index 0000000000..c061698a21
--- /dev/null
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -0,0 +1,26 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <title>Authentication failed</title>
+ <meta name="viewport" content="width=device-width, user-scalable=no">
+ <style type="text/css">
+ {% include "sso.css" without context %}
+ </style>
+ </head>
+ <body class="error_page">
+ <header>
+ <h1>That doesn't look right</h1>
+ <p>
+ <strong>We were unable to validate your {{ server_name }} account</strong>
+ via single sign‑on (SSO), because the SSO Identity
+ Provider returned different details than when you logged in.
+ </p>
+ <p>
+ Try the operation again, and ensure that you use the same details on
+ the Identity Provider as when you log into your account.
+ </p>
+ </header>
+ {% include "sso_footer.html" without context %}
+ </body>
+</html>
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index 0d9de9d465..f9d0456f0a 100644
--- a/synapse/res/templates/sso_auth_confirm.html
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -1,14 +1,29 @@
-<html>
-<head>
- <title>Authentication</title>
-</head>
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <title>Confirm it's you</title>
+ <meta name="viewport" content="width=device-width, user-scalable=no">
+ <style type="text/css">
+ {% include "sso.css" without context %}
+ </style>
+ </head>
<body>
- <div>
+ <header>
+ <h1>Confirm it's you to continue</h1>
<p>
- A client is trying to {{ description | e }}. To confirm this action,
- <a href="{{ redirect_url | e }}">re-authenticate with single sign-on</a>.
- If you did not expect this, your account may be compromised!
+ A client is trying to {{ description }}. To confirm this action,
+ re-authorize your account with single sign-on.
</p>
- </div>
+ <p><strong>
+ If you did not expect this, your account may be compromised.
+ </strong></p>
+ </header>
+ <main>
+ <a href="{{ redirect_url }}" class="primary-button">
+ Continue with {{ idp.idp_name }}
+ </a>
+ </main>
+ {% include "sso_footer.html" without context %}
</body>
</html>
diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html
index 03f1419467..1ed3967e87 100644
--- a/synapse/res/templates/sso_auth_success.html
+++ b/synapse/res/templates/sso_auth_success.html
@@ -1,18 +1,28 @@
-<html>
-<head>
- <title>Authentication Successful</title>
- <script>
- if (window.onAuthDone) {
- window.onAuthDone();
- } else if (window.opener && window.opener.postMessage) {
- window.opener.postMessage("authDone", "*");
- }
- </script>
-</head>
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <title>Authentication successful</title>
+ <meta name="viewport" content="width=device-width, user-scalable=no">
+ <style type="text/css">
+ {% include "sso.css" without context %}
+ </style>
+ <script>
+ if (window.onAuthDone) {
+ window.onAuthDone();
+ } else if (window.opener && window.opener.postMessage) {
+ window.opener.postMessage("authDone", "*");
+ }
+ </script>
+ </head>
<body>
- <div>
- <p>Thank you</p>
- <p>You may now close this window and return to the application</p>
- </div>
+ <header>
+ <h1>Thank you</h1>
+ <p>
+ Now we know it’s you, you can close this window and return to the
+ application.
+ </p>
+ </header>
+ {% include "sso_footer.html" without context %}
</body>
</html>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 944bc9c9ca..472309c350 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -1,53 +1,69 @@
<!DOCTYPE html>
<html lang="en">
-<head>
- <meta charset="UTF-8">
- <title>SSO error</title>
-</head>
-<body>
+ <head>
+ <meta charset="UTF-8">
+ <title>Authentication failed</title>
+ <meta name="viewport" content="width=device-width, user-scalable=no">
+ <style type="text/css">
+ {% include "sso.css" without context %}
+
+ #error_code {
+ margin-top: 56px;
+ }
+ </style>
+ </head>
+ <body class="error_page">
{# If an error of unauthorised is returned it means we have actively rejected their login #}
{% if error == "unauthorised" %}
- <p>You are not allowed to log in here.</p>
+ <header>
+ <p>You are not allowed to log in here.</p>
+ </header>
{% else %}
- <p>
- There was an error during authentication:
- </p>
- <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
- <p>
- If you are seeing this page after clicking a link sent to you via email, make
- sure you only click the confirmation link once, and that you open the
- validation link in the same client you're logging in from.
- </p>
- <p>
- Try logging in again from your Matrix client and if the problem persists
- please contact the server's administrator.
- </p>
- <p>Error: <code>{{ error }}</code></p>
+ <header>
+ <h1>There was an error</h1>
+ <p>
+ <strong id="errormsg">{{ error_description }}</strong>
+ </p>
+ <p>
+ If you are seeing this page after clicking a link sent to you via email,
+ make sure you only click the confirmation link once, and that you open
+ the validation link in the same client you're logging in from.
+ </p>
+ <p>
+ Try logging in again from your Matrix client and if the problem persists
+ please contact the server's administrator.
+ </p>
+ <div id="error_code">
+ <p><strong>Error code</strong></p>
+ <p>{{ error }}</p>
+ </div>
+ </header>
+ {% include "sso_footer.html" without context %}
- <script type="text/javascript">
- // Error handling to support Auth0 errors that we might get through a GET request
- // to the validation endpoint. If an error is provided, it's either going to be
- // located in the query string or in a query string-like URI fragment.
- // We try to locate the error from any of these two locations, but if we can't
- // we just don't print anything specific.
- let searchStr = "";
- if (window.location.search) {
- // window.location.searchParams isn't always defined when
- // window.location.search is, so it's more reliable to parse the latter.
- searchStr = window.location.search;
- } else if (window.location.hash) {
- // Replace the # with a ? so that URLSearchParams does the right thing and
- // doesn't parse the first parameter incorrectly.
- searchStr = window.location.hash.replace("#", "?");
- }
+ <script type="text/javascript">
+ // Error handling to support Auth0 errors that we might get through a GET request
+ // to the validation endpoint. If an error is provided, it's either going to be
+ // located in the query string or in a query string-like URI fragment.
+ // We try to locate the error from any of these two locations, but if we can't
+ // we just don't print anything specific.
+ let searchStr = "";
+ if (window.location.search) {
+ // window.location.searchParams isn't always defined when
+ // window.location.search is, so it's more reliable to parse the latter.
+ searchStr = window.location.search;
+ } else if (window.location.hash) {
+ // Replace the # with a ? so that URLSearchParams does the right thing and
+ // doesn't parse the first parameter incorrectly.
+ searchStr = window.location.hash.replace("#", "?");
+ }
- // We might end up with no error in the URL, so we need to check if we have one
- // to print one.
- let errorDesc = new URLSearchParams(searchStr).get("error_description")
- if (errorDesc) {
- document.getElementById("errormsg").innerText = errorDesc;
- }
- </script>
+ // We might end up with no error in the URL, so we need to check if we have one
+ // to print one.
+ let errorDesc = new URLSearchParams(searchStr).get("error_description")
+ if (errorDesc) {
+ document.getElementById("errormsg").innerText = errorDesc;
+ }
+ </script>
{% endif %}
</body>
</html>
diff --git a/synapse/res/templates/sso_footer.html b/synapse/res/templates/sso_footer.html
new file mode 100644
index 0000000000..588a3d508d
--- /dev/null
+++ b/synapse/res/templates/sso_footer.html
@@ -0,0 +1,19 @@
+<footer>
+ <svg role="img" aria-label="[Matrix logo]" viewBox="0 0 200 85" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+ <g id="parent" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
+ <g id="child" transform="translate(-122.000000, -6.000000)" fill="#000000" fill-rule="nonzero">
+ <g id="matrix-logo" transform="translate(122.000000, 6.000000)">
+ <polygon id="left-bracket" points="2.24708861 1.93811009 2.24708861 82.7268844 8.10278481 82.7268844 8.10278481 84.6652459 0 84.6652459 0 0 8.10278481 0 8.10278481 1.93811009"></polygon>
+ <path d="M24.8073418,27.5493174 L24.8073418,31.6376991 L24.924557,31.6376991 C26.0227848,30.0814294 27.3455696,28.8730642 28.8951899,28.0163743 C30.4437975,27.1611927 32.2189873,26.7318422 34.218481,26.7318422 C36.1394937,26.7318422 37.8946835,27.102622 39.4825316,27.8416679 C41.0708861,28.5819706 42.276962,29.8856073 43.1005063,31.7548404 C44.0017722,30.431345 45.2270886,29.2629486 46.7767089,28.2506569 C48.3253165,27.2388679 50.158481,26.7318422 52.2764557,26.7318422 C53.8843038,26.7318422 55.3736709,26.9269101 56.7473418,27.3162917 C58.1189873,27.7056734 59.295443,28.3285835 60.2759494,29.185022 C61.255443,30.0422147 62.02,31.1615927 62.5701266,32.5426532 C63.1187342,33.9262275 63.3936709,35.5898349 63.3936709,37.5372459 L63.3936709,57.7443688 L55.0410127,57.7441174 L55.0410127,40.6319376 C55.0410127,39.6201486 55.0020253,38.6661761 54.9232911,37.7700202 C54.8440506,36.8751211 54.6293671,36.0968606 54.2764557,35.4339817 C53.9232911,34.772611 53.403038,34.2464807 52.7177215,33.8568477 C52.0313924,33.4689743 51.0997468,33.2731523 49.9235443,33.2731523 C48.7473418,33.2731523 47.7962025,33.4983853 47.0706329,33.944578 C46.344557,34.393033 45.7764557,34.9774826 45.3650633,35.6969211 C44.9534177,36.4181193 44.6787342,37.2353431 44.5417722,38.150855 C44.4037975,39.0653615 44.3356962,39.9904257 44.3356962,40.9247908 L44.3356962,57.7443688 L35.9835443,57.7443688 L35.9835443,40.8079009 C35.9835443,39.9124991 35.963038,39.0263982 35.9253165,38.150855 C35.8853165,37.2743064 35.7192405,36.4666349 35.424557,35.7263321 C35.1303797,34.9872862 34.64,34.393033 33.9539241,33.944578 C33.2675949,33.4983853 32.2579747,33.2731523 30.9248101,33.2731523 C30.5321519,33.2731523 30.0126582,33.3608826 29.3663291,33.5365945 C28.7192405,33.7118037 28.0913924,34.0433688 27.4840506,34.5292789 C26.875443,35.0164459 26.3564557,35.7172826 25.9250633,36.6315376 C25.4934177,37.5470495 25.2779747,38.7436 25.2779747,40.2229486 L25.2779747,57.7441174 L16.9260759,57.7443688 L16.9260759,27.5493174 L24.8073418,27.5493174 Z" id="m"></path>
+ <path d="M68.7455696,31.9886202 C69.6075949,30.7033339 70.7060759,29.672189 72.0397468,28.8926716 C73.3724051,28.1141596 74.8716456,27.5596239 76.5387342,27.2283101 C78.2050633,26.8977505 79.8817722,26.7315908 81.5678481,26.7315908 C83.0974684,26.7315908 84.6458228,26.8391798 86.2144304,27.0525982 C87.7827848,27.2675248 89.2144304,27.6865688 90.5086076,28.3087248 C91.8025316,28.9313835 92.8610127,29.7983798 93.6848101,30.9074514 C94.5083544,32.0170257 94.92,33.4870734 94.92,35.3173431 L94.92,51.026844 C94.92,52.3913138 94.998481,53.6941963 95.1556962,54.9400165 C95.3113924,56.1865908 95.5863291,57.120956 95.9787342,57.7436147 L87.5091139,57.7436147 C87.3518987,57.276055 87.2240506,56.7996972 87.1265823,56.3125303 C87.0278481,55.8266202 86.9592405,55.3301523 86.9207595,54.8236294 C85.5873418,56.1865908 84.0182278,57.1405633 82.2156962,57.6857982 C80.4113924,58.2295248 78.5683544,58.503022 76.6860759,58.503022 C75.2346835,58.503022 73.8817722,58.3275615 72.6270886,57.9776459 C71.3718987,57.6269761 70.2744304,57.082244 69.3334177,56.3411872 C68.3921519,55.602644 67.656962,54.6680275 67.1275949,53.5390972 C66.5982278,52.410167 66.3331646,51.065556 66.3331646,49.5087835 C66.3331646,47.7961578 66.6367089,46.384178 67.2455696,45.2756092 C67.8529114,44.1652807 68.6367089,43.2799339 69.5987342,42.6173064 C70.5589873,41.9556844 71.6567089,41.4592165 72.8924051,41.1284055 C74.1273418,40.7978459 75.3721519,40.5356606 76.6270886,40.3398385 C77.8820253,40.1457761 79.116962,39.9896716 80.3329114,39.873033 C81.5483544,39.7558917 82.6270886,39.5804312 83.5681013,39.3469028 C84.5093671,39.1133743 85.2536709,38.7732624 85.8032911,38.3250587 C86.3513924,37.8773578 86.6063291,37.2252881 86.5678481,36.3680954 C86.5678481,35.4731963 86.4210127,34.7620532 86.1268354,34.2366771 C85.8329114,33.7113009 85.4405063,33.3018092 84.9506329,33.0099615 C84.4602532,32.7181138 83.8916456,32.5232972 83.2450633,32.4255119 C82.5977215,32.3294862 81.9010127,32.2797138 81.156962,32.2797138 C79.5098734,32.2797138 78.2159494,32.6303835 77.2746835,33.3312202 C76.3339241,34.0320569 75.7837975,35.2007046 75.6275949,36.8354037 L67.275443,36.8354037 C67.3924051,34.8892495 67.8817722,33.2726495 68.7455696,31.9886202 Z M85.2440506,43.6984752 C84.7149367,43.873433 84.1460759,44.0189798 83.5387342,44.1361211 C82.9306329,44.253011 82.2936709,44.350545 81.6270886,44.4279688 C80.96,44.5066495 80.2934177,44.6034294 79.6273418,44.7203193 C78.9994937,44.8362037 78.3820253,44.9933138 77.7749367,45.1871248 C77.1663291,45.3829468 76.636962,45.6451321 76.1865823,45.9759431 C75.7349367,46.3070055 75.3724051,46.7263009 75.0979747,47.2313156 C74.8232911,47.7375872 74.6863291,48.380356 74.6863291,49.1588679 C74.6863291,49.8979138 74.8232911,50.5218294 75.0979747,51.026844 C75.3724051,51.5338697 75.7455696,51.9328037 76.2159494,52.2246514 C76.6863291,52.5164991 77.2349367,52.7213706 77.8632911,52.8375064 C78.4898734,52.9546477 79.136962,53.012967 79.8037975,53.012967 C81.4506329,53.012967 82.724557,52.740978 83.6273418,52.1952404 C84.5288608,51.6507596 85.1949367,50.9981872 85.6270886,50.2382771 C86.0579747,49.4793725 86.323038,48.7119211 86.4212658,47.9321523 C86.518481,47.1536404 86.5681013,46.5304789 86.5681013,46.063422 L86.5681013,42.9677248 C86.2146835,43.2799339 85.7736709,43.5230147 85.2440506,43.6984752 Z" id="a"></path>
+ <path d="M116.917975,27.5493174 L116.917975,33.0976917 L110.801266,33.0976917 L110.801266,48.0492936 C110.801266,49.4502128 111.036203,50.3850807 111.507089,50.8518862 C111.976962,51.3191945 112.918734,51.5527229 114.33038,51.5527229 C114.801013,51.5527229 115.251392,51.5336183 115.683038,51.4944037 C116.114177,51.4561945 116.526076,51.3968697 116.917975,51.3194459 L116.917975,57.7438661 C116.212152,57.860756 115.427595,57.9381798 114.565316,57.9778972 C113.702785,58.0153523 112.859747,58.0357138 112.036203,58.0357138 C110.742278,58.0357138 109.516456,57.9477321 108.36,57.7722716 C107.202785,57.5975651 106.183544,57.2577046 105.301519,56.7509303 C104.418987,56.2454128 103.722785,55.5242147 103.213418,54.5898495 C102.703038,53.6562385 102.448608,52.4292716 102.448608,50.9099541 L102.448608,33.0976917 L97.3903797,33.0976917 L97.3903797,27.5493174 L102.448608,27.5493174 L102.448608,18.4967596 L110.801013,18.4967596 L110.801013,27.5493174 L116.917975,27.5493174 Z" id="t"></path>
+ <path d="M128.857975,27.5493174 L128.857975,33.1565138 L128.975696,33.1565138 C129.367089,32.2213945 129.896203,31.3559064 130.563544,30.557033 C131.23038,29.7596679 131.99443,29.0776844 132.857215,28.5130936 C133.719241,27.9495083 134.641266,27.5113596 135.622532,27.1988991 C136.601772,26.8879468 137.622025,26.7315908 138.681013,26.7315908 C139.229873,26.7315908 139.836962,26.8296275 140.504304,27.0239413 L140.504304,34.7336477 C140.111646,34.6552183 139.641013,34.586844 139.092658,34.5290275 C138.543291,34.4704569 138.014177,34.4410459 137.504304,34.4410459 C135.974937,34.4410459 134.681013,34.6949358 133.622785,35.2004532 C132.564051,35.7067248 131.711392,36.397255 131.064051,37.2735523 C130.417215,38.1501009 129.955443,39.1714422 129.681266,40.3398385 C129.407089,41.5074807 129.269873,42.7736624 129.269873,44.1361211 L129.269873,57.7438661 L120.917722,57.7438661 L120.917722,27.5493174 L128.857975,27.5493174 Z" id="r"></path>
+ <path d="M144.033165,22.8767376 L144.033165,16.0435798 L152.386076,16.0435798 L152.386076,22.8767376 L144.033165,22.8767376 Z M152.386076,27.5493174 L152.386076,57.7438661 L144.033165,57.7438661 L144.033165,27.5493174 L152.386076,27.5493174 Z" id="i"></path>
+ <polygon id="x" points="156.738228 27.5493174 166.266582 27.5493174 171.619494 35.4337303 176.913418 27.5493174 186.147848 27.5493174 176.148861 41.6831927 187.383544 57.7441174 177.85443 57.7441174 171.501772 48.2245028 165.148861 57.7441174 155.797468 57.7441174 166.737468 41.8589046"></polygon>
+ <polygon id="right-bracket" points="197.580759 82.7268844 197.580759 1.93811009 191.725063 1.93811009 191.725063 0 199.828354 0 199.828354 84.6652459 191.725063 84.6652459 191.725063 82.7268844"></polygon>
+ </g>
+ </g>
+ </g>
+ </svg>
+ <p>An open network for secure, decentralized communication.<br>© 2021 The Matrix.org Foundation C.I.C.</p>
+</footer>
\ No newline at end of file
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
new file mode 100644
index 0000000000..53b82db84e
--- /dev/null
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -0,0 +1,61 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <title>Choose identity provider</title>
+ <style type="text/css">
+ {% include "sso.css" without context %}
+
+ .providers {
+ list-style: none;
+ padding: 0;
+ }
+
+ .providers li {
+ margin: 12px;
+ }
+
+ .providers a {
+ display: block;
+ border-radius: 4px;
+ border: 1px solid #17191C;
+ padding: 8px;
+ text-align: center;
+ text-decoration: none;
+ color: #17191C;
+ display: flex;
+ align-items: center;
+ font-weight: bold;
+ }
+
+ .providers a img {
+ width: 24px;
+ height: 24px;
+ }
+ .providers a span {
+ flex: 1;
+ }
+ </style>
+ </head>
+ <body>
+ <header>
+ <h1>Log in to {{ server_name }} </h1>
+ <p>Choose an identity provider to log in</p>
+ </header>
+ <main>
+ <ul class="providers">
+ {% for p in providers %}
+ <li>
+ <a href="pick_idp?idp={{ p.idp_id }}&redirectUrl={{ redirect_url | urlencode }}">
+ {% if p.idp_icon %}
+ <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/>
+ {% endif %}
+ <span>{{ p.idp_name }}</span>
+ </a>
+ </li>
+ {% endfor %}
+ </ul>
+ </main>
+ {% include "sso_footer.html" without context %}
+ </body>
+</html>
diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html
new file mode 100644
index 0000000000..68c8b9f33a
--- /dev/null
+++ b/synapse/res/templates/sso_new_user_consent.html
@@ -0,0 +1,32 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <title>Agree to terms and conditions</title>
+ <meta name="viewport" content="width=device-width, user-scalable=no">
+ <style type="text/css">
+ {% include "sso.css" without context %}
+
+ #consent_form {
+ margin-top: 56px;
+ }
+ </style>
+</head>
+ <body>
+ <header>
+ <h1>Your account is nearly ready</h1>
+ <p>Agree to the terms to create your account.</p>
+ </header>
+ <main>
+ {% include "sso_partial_profile.html" %}
+ <form method="post" action="{{my_url}}" id="consent_form">
+ <p>
+ <input id="accepted_version" type="checkbox" name="accepted_version" value="{{ consent_version }}" required>
+ <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank" rel="noopener">terms and conditions</a>.</label>
+ </p>
+ <input type="submit" class="primary-button" value="Continue"/>
+ </form>
+ </main>
+ {% include "sso_footer.html" without context %}
+ </body>
+</html>
diff --git a/synapse/res/templates/sso_partial_profile.html b/synapse/res/templates/sso_partial_profile.html
new file mode 100644
index 0000000000..c9c76c455e
--- /dev/null
+++ b/synapse/res/templates/sso_partial_profile.html
@@ -0,0 +1,19 @@
+{# html fragment to be included in SSO pages, to show the user's profile #}
+
+<div class="profile{% if user_profile.avatar_url %} with-avatar{% endif %}">
+ {% if user_profile.avatar_url %}
+ <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
+ {% endif %}
+ {# users that signed up with SSO will have a display_name of some sort;
+ however that is not the case for users who signed up via other
+ methods, so we need to handle that.
+ #}
+ {% if user_profile.display_name %}
+ <div class="display-name">{{ user_profile.display_name }}</div>
+ {% else %}
+ {# split the userid on ':', take the part before the first ':',
+ and then remove the leading '@'. #}
+ <div class="display-name">{{ user_id.split(":")[0][1:] }}</div>
+ {% endif %}
+ <div class="user-id">{{ user_id }}</div>
+</div>
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index 20a15e1e74..1b01471ac8 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -2,13 +2,39 @@
<html lang="en">
<head>
<meta charset="UTF-8">
- <title>SSO redirect confirmation</title>
+ <title>Continue to your account</title>
+ <meta name="viewport" content="width=device-width, user-scalable=no">
+ <style type="text/css">
+ {% include "sso.css" without context %}
+
+ .confirm-trust {
+ margin: 34px 0;
+ color: #8D99A5;
+ }
+ .confirm-trust strong {
+ color: #17191C;
+ }
+
+ .confirm-trust::before {
+ content: "";
+ background-image: url('data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMTgiIGhlaWdodD0iMTgiIHZpZXdCb3g9IjAgMCAxOCAxOCIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZmlsbC1ydWxlPSJldmVub2RkIiBjbGlwLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xNi41IDlDMTYuNSAxMy4xNDIxIDEzLjE0MjEgMTYuNSA5IDE2LjVDNC44NTc4NiAxNi41IDEuNSAxMy4xNDIxIDEuNSA5QzEuNSA0Ljg1Nzg2IDQuODU3ODYgMS41IDkgMS41QzEzLjE0MjEgMS41IDE2LjUgNC44NTc4NiAxNi41IDlaTTcuMjUgOUM3LjI1IDkuNDY1OTYgNy41Njg2OSA5Ljg1NzQ4IDggOS45Njg1VjEyLjM3NUM4IDEyLjkyNzMgOC40NDc3MiAxMy4zNzUgOSAxMy4zNzVIMTAuMTI1QzEwLjY3NzMgMTMuMzc1IDExLjEyNSAxMi45MjczIDExLjEyNSAxMi4zNzVDMTEuMTI1IDExLjgyMjcgMTAuNjc3MyAxMS4zNzUgMTAuMTI1IDExLjM3NUgxMFY5QzEwIDguOTY1NDggOS45OTgyNSA4LjkzMTM3IDkuOTk0ODQgOC44OTc3NkM5Ljk0MzYzIDguMzkzNSA5LjUxNzc3IDggOSA4SDguMjVDNy42OTc3MiA4IDcuMjUgOC40NDc3MiA3LjI1IDlaTTkgNy41QzkuNjIxMzIgNy41IDEwLjEyNSA2Ljk5NjMyIDEwLjEyNSA2LjM3NUMxMC4xMjUgNS43NTM2OCA5LjYyMTMyIDUuMjUgOSA1LjI1QzguMzc4NjggNS4yNSA3Ljg3NSA1Ljc1MzY4IDcuODc1IDYuMzc1QzcuODc1IDYuOTk2MzIgOC4zNzg2OCA3LjUgOSA3LjVaIiBmaWxsPSIjQzFDNkNEIi8+Cjwvc3ZnPgoK');
+ background-repeat: no-repeat;
+ width: 24px;
+ height: 24px;
+ display: block;
+ float: left;
+ }
+ </style>
</head>
<body>
- <p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
- <p>If you don't recognise this address, you should ignore this and close this tab.</p>
- <p>
- <a href="{{ redirect_url | e }}">I trust this address</a>
- </p>
+ <header>
+ <h1>Continue to your account</h1>
+ </header>
+ <main>
+ {% include "sso_partial_profile.html" %}
+ <p class="confirm-trust">Continuing will grant <strong>{{ display_url }}</strong> access to your account.</p>
+ <a href="{{ redirect_url }}" class="primary-button">Continue</a>
+ </main>
+ {% include "sso_footer.html" without context %}
</body>
-</html>
\ No newline at end of file
+</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 55ddebb4fe..8457db1e22 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,10 +38,14 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
+ ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
+ MakeRoomAdminRestServlet,
+ RoomEventContextServlet,
RoomMembersRestServlet,
RoomRestServlet,
+ RoomStateRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -50,6 +56,7 @@ from synapse.rest.admin.users import (
PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
+ ShadowBanRestServlet,
UserAdminServlet,
UserMediaRestServlet,
UserMembershipRestServlet,
@@ -208,6 +215,7 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
+ RoomStateRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
@@ -228,6 +236,10 @@ def register_servlets(hs, http_server):
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
+ MakeRoomAdminRestServlet(hs).register(http_server)
+ ShadowBanRestServlet(hs).register(http_server)
+ ForwardExtremitiesRestServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
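
Each servlet registered above maps to a new admin endpoint under the usual /_synapse/admin prefix. As a rough usage sketch, the room state endpoint added by RoomStateRestServlet (see the rooms.py changes below) could be queried as follows; the homeserver URL, room ID and access token are placeholders:

    import requests

    HOMESERVER = "https://matrix.example.com"   # placeholder
    ADMIN_TOKEN = "<admin access token>"        # placeholder

    resp = requests.get(
        f"{HOMESERVER}/_synapse/admin/v1/rooms/!someroom:example.com/state",
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    )
    resp.raise_for_status()
    print(resp.json()["state"])  # the servlet returns {"state": [...]}
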
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index e09234c644..7681e55b58 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -15,10 +15,9 @@
import re
-import twisted.web.server
-
-import synapse.api.auth
+from synapse.api.auth import Auth
from synapse.api.errors import AuthError
+from synapse.http.site import SynapseRequest
from synapse.types import UserID
@@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
return patterns
-async def assert_requester_is_admin(
- auth: synapse.api.auth.Auth, request: twisted.web.server.Request
-) -> None:
+async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
"""Verify that the requester is an admin user
Args:
- auth: api.auth.Auth singleton
+ auth: Auth singleton
request: incoming request
Raises:
@@ -53,11 +50,11 @@ async def assert_requester_is_admin(
await assert_user_is_admin(auth, requester.user)
-async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user
Args:
- auth: api.auth.Auth singleton
+ auth: Auth singleton
user_id: user to check
Raises:
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index ffd3aa38f7..5996de11c3 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -20,8 +21,12 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet):
"/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id, device_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, device_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet):
)
return 200, device
- async def on_DELETE(self, request, user_id, device_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str, device_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(target_user.to_string(), device_id)
return 200, {}
- async def on_PUT(self, request, user_id, device_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, device_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
@@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
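
With the annotations in place, DevicesRestServlet still serves the v2 admin route for listing a user's devices. A hedged example of calling it; the homeserver URL, user ID and token are placeholders, and the response keys are read defensively rather than assumed:

    import requests

    resp = requests.get(
        "https://matrix.example.com/_synapse/admin/v2/users/@alice:example.com/devices",
        headers={"Authorization": "Bearer <admin access token>"},
    )
    resp.raise_for_status()
    for device in resp.json().get("devices", []):
        print(device.get("device_id"), device.get("display_name"))
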
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index fd482f0e32..381c3fe685 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -14,10 +14,16 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
@@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, report_id):
+ async def on_GET(
+ self, request: SynapseRequest, report_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
message = (
"The report_id parameter must be a string representing a positive integer."
)
try:
- report_id = int(report_id)
+ resolved_report_id = int(report_id)
except ValueError:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
- if report_id < 0:
+ if resolved_report_id < 0:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
- ret = await self.store.get_event_report(report_id)
+ ret = await self.store.get_event_report(resolved_report_id)
if not ret:
raise NotFoundError("Event report not found")
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index d0c86b204a..ebc587aa06 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
class DeleteGroupAdminRestServlet(RestServlet):
- """Allows deleting of local groups
- """
+ """Allows deleting of local groups"""
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..40646ef241 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,14 +15,20 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
)
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -39,11 +45,13 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, room_id: str):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -64,11 +72,13 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, user_id: str):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -91,11 +101,13 @@ class QuarantineMediaByID(RestServlet):
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, server_name: str, media_id: str):
+ async def on_POST(
+ self, request: SynapseRequest, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -109,17 +121,41 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
+class ProtectMediaByID(RestServlet):
+ """Protect local media from being quarantined."""
+
+ PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(
+ self, request: SynapseRequest, media_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logger.info("Protecting local media by ID: %s", media_id)
+
+ # Mark this media ID as safe so that future quarantine actions skip it
+ await self.store.mark_local_media_as_safe(media_id)
+
+ return 200, {}
+
+
class ListMediaInRoom(RestServlet):
- """Lists all of the media in a given room.
- """
+ """Lists all of the media in a given room."""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -133,11 +169,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_media_cache")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -149,18 +185,19 @@ class PurgeMediaCacheRestServlet(RestServlet):
class DeleteMediaByID(RestServlet):
- """Delete local media by a given ID. Removes it from this server.
- """
+ """Delete local media by a given ID. Removes it from this server."""
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_DELETE(self, request, server_name: str, media_id: str):
+ async def on_DELETE(
+ self, request: SynapseRequest, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
@@ -182,13 +219,15 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_POST(self, request, server_name: str):
+ async def on_POST(
+ self, request: SynapseRequest, server_name: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +261,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total}
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
"""
Media repo specific APIs.
"""
@@ -230,6 +269,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
+ ProtectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
DeleteMediaByID(hs).register(http_server)
DeleteMediaByDateSize(hs).register(http_server)
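
The new ProtectMediaByID servlet adds a media-protection endpoint alongside the existing quarantine ones. A hedged example of calling it; the media ID, homeserver URL and token are placeholders, and the path follows the "/media/protect/<media_id>" pattern above:

    import requests

    resp = requests.post(
        "https://matrix.example.com/_synapse/admin/v1/media/protect/abcdefg12345",
        headers={"Authorization": "Bearer <admin access token>"},
    )
    resp.raise_for_status()  # the servlet responds 200 with an empty JSON object
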
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index 8b7bb6d44e..49966ee3e0 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,13 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Tuple
+
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class PurgeRoomServlet(RestServlet):
@@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
PATTERNS = admin_patterns("/purge_room$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 9f01cb20cd..cfe1bebb91 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,28 +14,78 @@
# limitations under the License.
import logging
from http import HTTPStatus
+from typing import TYPE_CHECKING, List, Optional, Tuple
+from urllib import parse as urlparse
-from synapse.api.constants import EventTypes, JoinRules
-from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.filtering import Filter
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
- parse_list_from_args,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+from synapse.util import json_decoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
+class ResolveRoomIdMixin:
+ def __init__(self, hs: "HomeServer"):
+ self.room_member_handler = hs.get_room_member_handler()
+
+ async def resolve_room_id(
+ self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
+ ) -> Tuple[str, Optional[List[str]]]:
+ """
+ Resolve a room identifier to a room ID, if necessary.
+
+ This also performs checks to ensure the room ID is of the proper form.
+
+ Args:
+ room_identifier: The room ID or alias.
+ remote_room_hosts: The potential remote room hosts to use.
+
+ Returns:
+ The resolved room ID and the remote room hosts to use, if any.
+
+ Raises:
+ SynapseError if the room ID is of the wrong form.
+ """
+ if RoomID.is_valid(room_identifier):
+ resolved_room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ (
+ room_id,
+ remote_room_hosts,
+ ) = await self.room_member_handler.lookup_room_alias(room_alias)
+ resolved_room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+ if not resolved_room_id:
+ raise SynapseError(
+ 400, "Unknown room ID or room alias %s" % room_identifier
+ )
+ return resolved_room_id, remote_room_hosts
+
+
class ShutdownRoomRestServlet(RestServlet):
"""Shuts down a room by removing all local users from the room and blocking
all future invites and joins to the room. Any local aliases will be repointed
@@ -45,12 +95,14 @@ class ShutdownRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -86,13 +138,15 @@ class DeleteRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -146,12 +200,12 @@ class ListRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -236,19 +290,24 @@ class RoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room_with_stats(room_id)
if not ret:
raise NotFoundError("Room not found")
- return 200, ret
+ members = await self.store.get_users_in_room(room_id)
+ ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+
+ return 200, ret
class RoomMembersRestServlet(RestServlet):
@@ -258,12 +317,14 @@ class RoomMembersRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
@@ -276,18 +337,62 @@ class RoomMembersRestServlet(RestServlet):
return 200, ret
-class JoinRoomAliasServlet(RestServlet):
+class RoomStateRestServlet(RestServlet):
+ """
+ Get full state within a room.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ ret = await self.store.get_room(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ event_ids = await self.store.get_current_state_ids(room_id)
+ events = await self.store.get_events(event_ids.values())
+ now = self.clock.time_msec()
+ room_state = await self._event_serializer.serialize_events(
+ events.values(),
+ now,
+ # We don't bother bundling aggregations in when asked for state
+ # events, as clients won't use them.
+ bundle_aggregations=False,
+ )
+ ret = {"state": room_state}
+
+ return 200, ret
+
+
+class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
- async def on_POST(self, request, room_identifier):
+ async def on_POST(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -302,21 +407,16 @@ class JoinRoomAliasServlet(RestServlet):
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
- if RoomID.is_valid(room_identifier):
- room_id = room_identifier
- try:
- remote_room_hosts = parse_list_from_args(request.args, "server_name")
- except KeyError:
- remote_room_hosts = None
- elif RoomAlias.is_valid(room_identifier):
- handler = self.room_member_handler
- room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
- else:
- raise SynapseError(
- 400, "%s was not legal room ID or room alias" % (room_identifier,)
- )
+ # Get the room ID from the identifier.
+ try:
+ remote_room_hosts = [
+ x.decode("ascii") for x in request.args[b"server_name"]
+ ] # type: Optional[List[str]]
+ except Exception:
+ remote_room_hosts = None
+ room_id, remote_room_hosts = await self.resolve_room_id(
+ room_identifier, remote_room_hosts
+ )
fake_requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
@@ -349,3 +449,249 @@ class JoinRoomAliasServlet(RestServlet):
)
return 200, {"room_id": room_id}
+
+
+class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
+ """Allows a server admin to get power in a room if a local user has power in
+ a room. Will also invite the user if they're not in the room and it's a
+ private room. Can specify another user (rather than the admin user) to be
+ granted power, e.g.:
+
+ POST /_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
+ {
+ "user_id": "@foo:example.com"
+ }
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.state_handler = hs.get_state_handler()
+ self.is_mine_id = hs.is_mine_id
+
+ async def on_POST(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ content = parse_json_object_from_request(request, allow_empty_body=True)
+
+ room_id, _ = await self.resolve_room_id(room_identifier)
+
+ # Which user to grant room admin rights to.
+ user_to_add = content.get("user_id", requester.user.to_string())
+
+ # Figure out which local users currently have power in the room, if any.
+ room_state = await self.state_handler.get_current_state(room_id)
+ if not room_state:
+ raise SynapseError(400, "Server not in room")
+
+ create_event = room_state[(EventTypes.Create, "")]
+ power_levels = room_state.get((EventTypes.PowerLevels, ""))
+
+ if power_levels is not None:
+ # We pick the local user with the highest power.
+ user_power = power_levels.content.get("users", {})
+ admin_users = [
+ user_id for user_id in user_power if self.is_mine_id(user_id)
+ ]
+ admin_users.sort(key=lambda user: user_power[user])
+
+ if not admin_users:
+ raise SynapseError(400, "No local admin user in room")
+
+ admin_user_id = None
+
+ for admin_user in reversed(admin_users):
+ if room_state.get((EventTypes.Member, admin_user)):
+ admin_user_id = admin_user
+ break
+
+ if not admin_user_id:
+ raise SynapseError(
+ 400,
+ "No local admin user in room",
+ )
+
+ pl_content = power_levels.content
+ else:
+ # If there are no power level events then the creator has rights.
+ pl_content = {}
+ admin_user_id = create_event.sender
+ if not self.is_mine_id(admin_user_id):
+ raise SynapseError(
+ 400,
+ "No local admin user in room",
+ )
+
+ # Grant the user power equal to the room admin by attempting to send an
+ # updated power level event.
+ new_pl_content = dict(pl_content)
+ new_pl_content["users"] = dict(pl_content.get("users", {}))
+ new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
+
+ fake_requester = create_requester(
+ admin_user_id,
+ authenticated_entity=requester.authenticated_entity,
+ )
+
+ try:
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ fake_requester,
+ event_dict={
+ "content": new_pl_content,
+ "sender": admin_user_id,
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "room_id": room_id,
+ },
+ )
+ except AuthError:
+ # The admin user we found turned out not to have enough power.
+ raise SynapseError(
+ 400, "No local admin user in room with power to update power levels."
+ )
+
+ # Now we check if the user we're granting admin rights to is already in
+ # the room. If not and it's not a public room we invite them.
+ member_event = room_state.get((EventTypes.Member, user_to_add))
+ is_joined = False
+ if member_event:
+ is_joined = member_event.content["membership"] in (
+ Membership.JOIN,
+ Membership.INVITE,
+ )
+
+ if is_joined:
+ return 200, {}
+
+ join_rules = room_state.get((EventTypes.JoinRules, ""))
+ is_public = False
+ if join_rules:
+ is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
+
+ if is_public:
+ return 200, {}
+
+ await self.room_member_handler.update_membership(
+ fake_requester,
+ target=UserID.from_string(user_to_add),
+ room_id=room_id,
+ action=Membership.INVITE,
+ )
+
+ return 200, {}
+
+
+class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
+ """Allows a server admin to get or clear forward extremities.
+
+ Clearing does not require restarting the server.
+
+ Clear forward extremities:
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+ Get forward_extremities:
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_DELETE(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id, _ = await self.resolve_room_id(room_identifier)
+
+ deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+ return 200, {"deleted": deleted_count}
+
+ async def on_GET(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id, _ = await self.resolve_room_id(room_identifier)
+
+ extremities = await self.store.get_forward_extremities_for_room(room_id)
+ return 200, {"count": len(extremities), "results": extremities}
+
+
+class RoomEventContextServlet(RestServlet):
+ """
+ Provide the context for an event.
+ This API is designed to be used when system administrators wish to look at
+ an abuse report and understand what happened during and immediately prior
+ to this event.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.clock = hs.get_clock()
+ self.room_context_handler = hs.get_room_context_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.auth = hs.get_auth()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ limit = parse_integer(request, "limit", default=10)
+
+ # picking the API shape for symmetry with /messages
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
+ else:
+ event_filter = None
+
+ results = await self.room_context_handler.get_event_context(
+ requester,
+ room_id,
+ event_id,
+ limit,
+ event_filter,
+ use_admin_priviledge=True,
+ )
+
+ if not results:
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+ time_now = self.clock.time_msec()
+ results["events_before"] = await self._event_serializer.serialize_events(
+ results["events_before"], time_now
+ )
+ results["event"] = await self._event_serializer.serialize_event(
+ results["event"], time_now
+ )
+ results["events_after"] = await self._event_serializer.serialize_events(
+ results["events_after"], time_now
+ )
+ results["state"] = await self._event_serializer.serialize_events(
+ results["state"],
+ time_now,
+ # No need to bundle aggregations for state events
+ bundle_aggregations=False,
+ )
+
+ return 200, results
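
For reference, a minimal sketch (not part of the diff) of how an admin client might drive the room endpoints introduced above: full room state, make_room_admin, and forward extremities. The homeserver URL, admin access token and room ID are placeholders, and the third-party `requests` library is assumed; real clients should URL-encode the room ID.

import requests

BASE = "https://synapse.example.com/_synapse/admin/v1"
HEADERS = {"Authorization": "Bearer <admin_access_token>"}
room = "!abc123:example.com"

# Full current state of the room (RoomStateRestServlet).
state = requests.get(f"{BASE}/rooms/{room}/state", headers=HEADERS).json()

# Grant a user admin power in the room (MakeRoomAdminRestServlet); with an
# empty body the requesting admin is granted power instead.
requests.post(
    f"{BASE}/rooms/{room}/make_room_admin",
    headers=HEADERS,
    json={"user_id": "@foo:example.com"},
)

# Inspect, then clear, the room's forward extremities
# (ForwardExtremitiesRestServlet).
extremities = requests.get(
    f"{BASE}/rooms/{room}/forward_extremities", headers=HEADERS
).json()
requests.delete(f"{BASE}/rooms/{room}/forward_extremities", headers=HEADERS)
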
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 375d055445..f495666f4a 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,17 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Optional, Tuple
+
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SendServerNoticeServlet(RestServlet):
@@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
}
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
self.snm = hs.get_server_notices_manager()
- def register(self, json_resource):
+ def register(self, json_resource: HttpServer):
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
@@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
- async def on_POST(self, request, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
@@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
- def on_PUT(self, request, txn_id):
+ def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
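
A hedged usage sketch for the server-notice servlet typed above: the POST form sends a notice directly, while the PUT form carries a transaction ID so retries are deduplicated by HttpTransactionCache. The URL, token and user ID are placeholders; the `requests` library and the documented payload shape are assumptions.

import requests

BASE = "https://synapse.example.com/_synapse/admin/v1"
HEADERS = {"Authorization": "Bearer <admin_access_token>"}
notice = {
    "user_id": "@target:example.com",
    "content": {"msgtype": "m.text", "body": "Scheduled maintenance at 02:00 UTC"},
}

# One-shot send.
requests.post(f"{BASE}/send_server_notice", headers=HEADERS, json=notice)

# Idempotent send: repeating the PUT with the same txn_id returns the
# cached result rather than creating a second notice event.
requests.put(f"{BASE}/send_server_notice/txn-42", headers=HEADERS, json=notice)
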
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..309bd2771b 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,7 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.rest.client.v2_alpha._base import client_patterns
+from synapse.storage.databases.main.media_repository import MediaSortOrder
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
@@ -42,28 +43,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_GET_PUSHERS_ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
target_user = UserID.from_string(user_id)
await assert_requester_is_admin(self.auth, request)
@@ -94,17 +86,32 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
@@ -114,7 +121,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
- if len(users) >= limit:
+ if (start + limit) < total:
ret["next_token"] = str(start + len(users))
return 200, ret
@@ -148,7 +155,7 @@ class UserRestServletV2(RestServlet):
otherwise an error.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
@@ -160,7 +167,9 @@ class UserRestServletV2(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -174,7 +183,9 @@ class UserRestServletV2(RestServlet):
return 200, ret
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -255,10 +266,13 @@ class UserRestServletV2(RestServlet):
if deactivate and not user["deactivated"]:
await self.deactivate_account_handler.deactivate_account(
- target_user.to_string(), False
+ target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
- if "password" not in body:
+ if (
+ "password" not in body
+ and self.auth_handler.can_change_password()
+ ):
raise SynapseError(
400, "Must provide a password to re-activate an account."
)
@@ -268,6 +282,8 @@ class UserRestServletV2(RestServlet):
)
user = await self.admin_handler.get_user(target_user)
+ assert user is not None
+
return 200, user
else: # create user
@@ -320,14 +336,15 @@ class UserRestServletV2(RestServlet):
data={},
)
- if "avatar_url" in body and type(body["avatar_url"]) == str:
+ if "avatar_url" in body and isinstance(body["avatar_url"], str):
await self.profile_handler.set_avatar_url(
- user_id, requester, body["avatar_url"], True
+ target_user, requester, body["avatar_url"], True
)
- ret = await self.admin_handler.get_user(target_user)
+ user = await self.admin_handler.get_user(target_user)
+ assert user is not None
- return 201, ret
+ return 201, user
class UserRegisterServlet(RestServlet):
@@ -341,10 +358,10 @@ class UserRegisterServlet(RestServlet):
PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
- self.nonces = {}
+ self.nonces = {} # type: Dict[str, int]
self.hs = hs
def _clear_old_nonces(self):
@@ -357,7 +374,7 @@ class UserRegisterServlet(RestServlet):
if now - v > self.NONCE_TIMEOUT:
del self.nonces[k]
- def on_GET(self, request):
+ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Generate a new nonce.
"""
@@ -367,7 +384,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces()
if not self.hs.config.registration_shared_secret:
@@ -420,6 +437,9 @@ class UserRegisterServlet(RestServlet):
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
+ if "mac" not in body:
+ raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
got_mac = body["mac"]
want_mac_builder = hmac.new(
@@ -470,12 +490,14 @@ class WhoisRestServlet(RestServlet):
client_patterns("/admin" + path_regex, v1=True)
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
auth_user = requester.user
@@ -494,12 +516,24 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
self.auth = hs.get_auth()
+ self.is_mine = hs.is_mine
+ self.store = hs.get_datastore()
+
+ async def on_POST(
+ self, request: SynapseRequest, target_user_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ if not self.is_mine(UserID.from_string(target_user_id)):
+ raise SynapseError(400, "Can only deactivate local users")
+
+ if not await self.store.get_user_by_id(target_user_id):
+ raise NotFoundError("User not found")
- async def on_POST(self, request, target_user_id):
- await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -509,10 +543,8 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON,
)
- UserID.from_string(target_user_id)
-
result = await self._deactivate_account_handler.deactivate_account(
- target_user_id, erase
+ target_user_id, erase, requester, by_admin=True
)
if result:
id_server_unbind_result = "success"
@@ -534,7 +566,7 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
@@ -564,18 +596,20 @@ class ResetPasswordRestServlet(RestServlet):
}
Returns:
200 OK with empty object if success otherwise an error.
- """
+ """
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
- async def on_POST(self, request, target_user_id):
+ async def on_POST(
+ self, request: SynapseRequest, target_user_id: str
+ ) -> Tuple[int, JsonDict]:
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
"""
@@ -610,12 +644,14 @@ class SearchUsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, target_user_id):
+ async def on_GET(
+ self, request: SynapseRequest, target_user_id: str
+ ) -> Tuple[int, Optional[List[JsonDict]]]:
"""Get request to search user table for specific users according to
search term.
This needs the user to have administrator access in Synapse.
@@ -666,12 +702,14 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -683,7 +721,9 @@ class UserAdminServlet(RestServlet):
return 200, {"admin": is_admin}
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
@@ -714,21 +754,16 @@ class UserMembershipRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.is_mine(UserID.from_string(user_id)):
- raise SynapseError(400, "Can only lookup local users")
-
- user = await self.store.get_user_by_id(user_id)
- if user is None:
- raise NotFoundError("Unknown user")
-
room_ids = await self.store.get_rooms_for_user(user_id)
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
return 200, ret
@@ -744,12 +779,12 @@ class PushersRestServlet(RestServlet):
Returns:
pushers: Dictionary containing pushers information.
- total: Number of pushers in dictonary `pushers`.
+ total: Number of pushers in dictionary `pushers`.
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -767,10 +802,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.store.get_pushers_by_user_id(user_id)
- filtered_pushers = [
- {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
- for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
@@ -793,7 +825,7 @@ class UserMediaRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -801,6 +833,9 @@ class UserMediaRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
@@ -827,8 +862,33 @@ class UserMediaRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
+ # If neither `order_by` nor `dir` is set, default the ordering so that the
+ # newest media comes first, for backward compatibility.
+ if b"order_by" not in request.args and b"dir" not in request.args:
+ order_by = MediaSortOrder.CREATED_TS.value
+ direction = "b"
+ else:
+ order_by = parse_string(
+ request,
+ "order_by",
+ default=MediaSortOrder.CREATED_TS.value,
+ allowed_values=(
+ MediaSortOrder.MEDIA_ID.value,
+ MediaSortOrder.UPLOAD_NAME.value,
+ MediaSortOrder.CREATED_TS.value,
+ MediaSortOrder.LAST_ACCESS_TS.value,
+ MediaSortOrder.MEDIA_LENGTH.value,
+ MediaSortOrder.MEDIA_TYPE.value,
+ MediaSortOrder.QUARANTINED_BY.value,
+ MediaSortOrder.SAFE_FROM_QUARANTINE.value,
+ ),
+ )
+ direction = parse_string(
+ request, "dir", default="f", allowed_values=("f", "b")
+ )
+
media, total = await self.store.get_local_media_by_user_paginate(
- start, limit, user_id
+ start, limit, user_id, order_by, direction
)
ret = {"media": media, "total": total}
@@ -860,7 +920,9 @@ class UserTokenRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
@@ -885,3 +947,41 @@ class UserTokenRestServlet(RestServlet):
)
return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+ """An admin API for shadow-banning a user.
+
+ A shadow-banned user receives successful responses to their client-server
+ API requests, but the events are not propagated into rooms.
+
+ Shadow-banning a user should be used as a tool of last resort and may lead
+ to confusing or broken behaviour for the client.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+ {}
+
+ 200 OK
+ {}
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be shadow-banned")
+
+ await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+ return 200, {}
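
A short sketch showing two of the user-admin additions above, shadow-banning and ordered media listing. Assumptions: placeholder homeserver, admin token and user IDs, the `requests` library, and `order_by` values taken from the MediaSortOrder enum referenced in the diff.

import requests

BASE = "https://synapse.example.com/_synapse/admin/v1"
HEADERS = {"Authorization": "Bearer <admin_access_token>"}

# Shadow-ban a local user (ShadowBanRestServlet).
requests.post(
    f"{BASE}/users/@spammer:example.com/shadow_ban", headers=HEADERS, json={}
)

# List a user's local media. With neither `order_by` nor `dir` given the
# newest media comes first; here we ask explicitly for created_ts, descending.
media = requests.get(
    f"{BASE}/users/@alice:example.com/media",
    headers=HEADERS,
    params={"from": 0, "limit": 10, "order_by": "created_ts", "dir": "b"},
).json()
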
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae148214..e4c352f572 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,12 +14,16 @@
# limitations under the License.
import logging
-from typing import Awaitable, Callable, Dict, Optional
+import re
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http import get_request_uri
+from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -30,6 +34,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -42,7 +49,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -57,11 +64,14 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self.auth = hs.get_auth()
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
+ self._sso_handler = hs.get_sso_handler()
+
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -86,8 +96,27 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- # While its valid for us to advertise this login type generally,
+ sso_flow = {
+ "type": LoginRestServlet.SSO_TYPE,
+ "identity_providers": [
+ _get_auth_flow_dict_for_idp(
+ idp,
+ )
+ for idp in self._sso_handler.get_identity_providers().values()
+ ],
+ } # type: JsonDict
+
+ if self._msc2858_enabled:
+ # backwards-compatibility support for clients which don't
+ # support the stable API yet
+ sso_flow["org.matrix.msc2858.identity_providers"] = [
+ _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
+ for idp in self._sso_handler.get_identity_providers().values()
+ ]
+
+ flows.append(sso_flow)
+
+ # While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
# SSO login flow.
# Generally we don't want to advertise login flows that clients
@@ -105,22 +134,27 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest):
- self._address_ratelimiter.ratelimit(request.getClientIP())
-
login_submission = parse_json_object_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
+
+ if appservice.is_rate_limited():
+ self._address_ratelimiter.ratelimit(request.getClientIP())
+
result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_token_login(login_submission)
else:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +193,9 @@ class LoginRestServlet(RestServlet):
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
- return await self._complete_login(qualified_user_id, login_submission)
+ return await self._complete_login(
+ qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+ )
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -194,6 +230,8 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
+ ratelimit: bool = True,
+ auth_provider_id: Optional[str] = None,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -208,6 +246,9 @@ class LoginRestServlet(RestServlet):
callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
+ ratelimit: Whether to ratelimit the login request.
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
result: Dictionary of account information after successful login.
@@ -216,7 +257,8 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
- self._account_ratelimiter.ratelimit(user_id.lower())
+ if ratelimit:
+ self._account_ratelimiter.ratelimit(user_id.lower())
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
@@ -229,7 +271,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name
+ user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
)
result = {
@@ -256,12 +298,13 @@ class LoginRestServlet(RestServlet):
"""
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
- token
- )
+ res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login(
- user_id, login_submission, self.auth_handler._sso_login_callback
+ res.user_id,
+ login_submission,
+ self.auth_handler._sso_login_callback,
+ auth_provider_id=res.auth_provider_id,
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
@@ -284,7 +327,9 @@ class LoginRestServlet(RestServlet):
except jwt.PyJWTError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
- 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
+ 403,
+ "JWT validation failed: %s" % (str(e),),
+ errcode=Codes.FORBIDDEN,
)
user = payload.get("sub", None)
@@ -298,48 +343,109 @@ class LoginRestServlet(RestServlet):
return result
-class BaseSSORedirectServlet(RestServlet):
- """Common base class for /login/sso/redirect impls"""
+def _get_auth_flow_dict_for_idp(
+ idp: SsoIdentityProvider, use_unstable_brands: bool = False
+) -> JsonDict:
+ """Return an entry for the login flow dict
+
+ Returns an entry suitable for inclusion in "identity_providers" in the
+ response to GET /_matrix/client/r0/login
+
+ Args:
+ idp: the identity provider to describe
+ use_unstable_brands: whether we should use brand identifiers suitable
+ for the unstable API
+ """
+ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
+ if idp.idp_icon:
+ e["icon"] = idp.idp_icon
+ if idp.idp_brand:
+ e["brand"] = idp.idp_brand
+ # use the stable brand identifier if the unstable identifier isn't defined.
+ if use_unstable_brands and idp.unstable_idp_brand:
+ e["brand"] = idp.unstable_idp_brand
+ return e
+
+
+class SsoRedirectServlet(RestServlet):
+ PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
+ re.compile(
+ "^"
+ + CLIENT_API_PREFIX
+ + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+ )
+ ]
+
+ def __init__(self, hs: "HomeServer"):
+ # make sure that the relevant handlers are instantiated, so that they
+ # register themselves with the main SSOHandler.
+ if hs.config.cas_enabled:
+ hs.get_cas_handler()
+ if hs.config.saml2_enabled:
+ hs.get_saml_handler()
+ if hs.config.oidc_enabled:
+ hs.get_oidc_handler()
+ self._sso_handler = hs.get_sso_handler()
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+ self._public_baseurl = hs.config.public_baseurl
+
+ def register(self, http_server: HttpServer) -> None:
+ super().register(http_server)
+ if self._msc2858_enabled:
+ # expose an additional endpoint for MSC2858: backwards compatibility
+ # for clients which don't yet support the stable endpoints.
+ http_server.register_paths(
+ "GET",
+ client_patterns(
+ "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+ releases=(),
+ unstable=True,
+ ),
+ self.on_GET,
+ self.__class__.__name__,
+ )
- PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+ async def on_GET(
+ self, request: SynapseRequest, idp_id: Optional[str] = None
+ ) -> None:
+ if not self._public_baseurl:
+ raise SynapseError(400, "SSO requires a valid public_baseurl")
+
+ # if this isn't the expected hostname, redirect to the right one, so that we
+ # get our cookies back.
+ requested_uri = get_request_uri(request)
+ baseurl_bytes = self._public_baseurl.encode("utf-8")
+ if not requested_uri.startswith(baseurl_bytes):
+ # swap out the incorrect base URL for the right one.
+ #
+ # The idea here is to redirect from
+ # https://foo.bar/whatever/_matrix/...
+ # to
+ # https://public.baseurl/_matrix/...
+ #
+ i = requested_uri.index(b"/_matrix")
+ new_uri = baseurl_bytes[:-1] + requested_uri[i:]
+ logger.info(
+ "Requested URI %s is not canonical: redirecting to %s",
+ requested_uri.decode("utf-8", errors="replace"),
+ new_uri.decode("utf-8", errors="replace"),
+ )
+ request.redirect(new_uri)
+ finish_request(request)
+ return
- async def on_GET(self, request: SynapseRequest):
- args = request.args
- if b"redirectUrl" not in args:
- return 400, "Redirect URL not specified for SSO auth"
- client_redirect_url = args[b"redirectUrl"][0]
- sso_url = await self.get_sso_url(request, client_redirect_url)
+ client_redirect_url = parse_string(
+ request, "redirectUrl", required=True, encoding=None
+ )
+ sso_url = await self._sso_handler.handle_redirect_request(
+ request,
+ client_redirect_url,
+ idp_id,
+ )
+ logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
finish_request(request)
- async def get_sso_url(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> bytes:
- """Get the URL to redirect to, to perform SSO auth
-
- Args:
- request: The client request to redirect.
- client_redirect_url: the URL that we should redirect the
- client to when everything is done
-
- Returns:
- URL to redirect to
- """
- # to be implemented by subclasses
- raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
- def __init__(self, hs):
- self._cas_handler = hs.get_cas_handler()
-
- async def get_sso_url(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> bytes:
- return self._cas_handler.get_redirect_url(
- {"redirectUrl": client_redirect_url}
- ).encode("ascii")
-
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -366,40 +472,8 @@ class CasTicketServlet(RestServlet):
)
-class SAMLRedirectServlet(BaseSSORedirectServlet):
- PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
- def __init__(self, hs):
- self._saml_handler = hs.get_saml_handler()
-
- async def get_sso_url(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> bytes:
- return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
- """Implementation for /login/sso/redirect for the OIDC login flow."""
-
- PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
- def __init__(self, hs):
- self._oidc_handler = hs.get_oidc_handler()
-
- async def get_sso_url(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> bytes:
- return await self._oidc_handler.handle_redirect_request(
- request, client_redirect_url
- )
-
-
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
+ SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
- CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
- elif hs.config.saml2_enabled:
- SAMLRedirectServlet(hs).register(http_server)
- elif hs.config.oidc_enabled:
- OIDCRedirectServlet(hs).register(http_server)
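
As a sketch of how a client consumes the reworked login flow above (placeholder homeserver; the `requests` library is assumed): the SSO flow entry now carries an identity_providers list, and the browser is then sent to the per-IdP redirect endpoint handled by SsoRedirectServlet.

import requests

BASE = "https://synapse.example.com"

# The advertised login flows now include the configured identity providers.
flows = requests.get(f"{BASE}/_matrix/client/r0/login").json()["flows"]
sso_flow = next(f for f in flows if f["type"] == "m.login.sso")
for idp in sso_flow.get("identity_providers", []):
    print(idp["id"], idp["name"], idp.get("brand"))

# A client would then open the per-IdP redirect endpoint in the browser, e.g.
# GET /_matrix/client/r0/login/sso/redirect/<idp_id>?redirectUrl=<client URL>
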
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index b5fa1cc464..d77e20e135 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -62,7 +62,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
new_name = content["displayname"]
except Exception:
raise SynapseError(
- code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON,
+ code=400,
+ msg="Unable to parse name",
+ errcode=Codes.BAD_JSON,
)
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f321a..0c148a213d 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- filtered_pushers = [
- {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers}
@@ -172,7 +159,9 @@ class PushersRemoveRestServlet(RestServlet):
self.notifier.on_new_replication_data()
respond_with_html_bytes(
- request, 200, PushersRemoveRestServlet.SUCCESS_HTML,
+ request,
+ 200,
+ PushersRemoveRestServlet.SUCCESS_HTML,
)
return None
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 6d3be08758..d01b738e32 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -17,7 +17,7 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@@ -34,22 +34,31 @@ from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_boolean,
parse_integer,
parse_json_object_from_request,
parse_list_from_args,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
-from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.types import (
+ JsonDict,
+ RoomAlias,
+ RoomID,
+ StreamToken,
+ ThirdPartyInstanceID,
+ UserID,
+)
from synapse.util import json_decoder
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import parse_and_validate_server_name, random_string
if TYPE_CHECKING:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -345,8 +354,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
# provided.
if server:
raise e
- else:
- pass
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
@@ -357,6 +364,16 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
+ # Ensure the server is valid.
+ try:
+ parse_and_validate_server_name(server)
+ except ValueError:
+ raise SynapseError(
+ 400,
+ "Invalid server name: %s" % (server,),
+ Codes.INVALID_PARAM,
+ )
+
try:
data = await handler.get_remote_public_room_list(
server, limit=limit, since_token=since_token
@@ -400,6 +417,16 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
+ # Ensure the server is valid.
+ try:
+ parse_and_validate_server_name(server)
+ except ValueError:
+ raise SynapseError(
+ 400,
+ "Invalid server name: %s" % (server,),
+ Codes.INVALID_PARAM,
+ )
+
try:
data = await handler.get_remote_public_room_list(
server,
@@ -634,7 +661,7 @@ class RoomEventContextServlet(RestServlet):
event_filter = None
results = await self.room_context_handler.get_event_context(
- requester.user, room_id, event_id, limit, event_filter
+ requester, room_id, event_id, limit, event_filter
)
if not results:
@@ -651,7 +678,10 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
- results["state"], time_now
+ results["state"],
+ time_now,
+ # No need to bundle aggregations for state events
+ bundle_aggregations=False,
)
return 200, results
@@ -824,10 +854,10 @@ class RoomTypingRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
+ self.hs = hs
self.presence_handler = hs.get_presence_handler()
- self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
# If we're not on the typing writer instance we should scream if we get
@@ -852,16 +882,19 @@ class RoomTypingRestServlet(RestServlet):
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
+ # Defer getting the typing handler since it will raise on workers.
+ typing_handler = self.hs.get_typing_writer_handler()
+
try:
if content["typing"]:
- await self.typing_handler.started_typing(
+ await typing_handler.started_typing(
target_user=target_user,
requester=requester,
room_id=room_id,
timeout=timeout,
)
else:
- await self.typing_handler.stopped_typing(
+ await typing_handler.stopped_typing(
target_user=target_user, requester=requester, room_id=room_id
)
except ShadowBanError:
@@ -879,7 +912,7 @@ class RoomAliasListServlet(RestServlet):
),
]
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.directory_handler = hs.get_directory_handler()
@@ -962,25 +995,82 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
)
-def register_servlets(hs, http_server):
+class RoomSpaceSummaryRestServlet(RestServlet):
+ PATTERNS = (
+ re.compile(
+ "^/_matrix/client/unstable/org.matrix.msc2946"
+ "/rooms/(?P<room_id>[^/]*)/spaces$"
+ ),
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._auth = hs.get_auth()
+ self._space_summary_handler = hs.get_space_summary_handler()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request, allow_guest=True)
+
+ return 200, await self._space_summary_handler.get_space_summary(
+ requester.user.to_string(),
+ room_id,
+ suggested_only=parse_boolean(request, "suggested_only", default=False),
+ max_rooms_per_space=parse_integer(request, "max_rooms_per_space"),
+ )
+
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request, allow_guest=True)
+ content = parse_json_object_from_request(request)
+
+ suggested_only = content.get("suggested_only", False)
+ if not isinstance(suggested_only, bool):
+ raise SynapseError(
+ 400, "'suggested_only' must be a boolean", Codes.BAD_JSON
+ )
+
+ max_rooms_per_space = content.get("max_rooms_per_space")
+ if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
+ raise SynapseError(
+ 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON
+ )
+
+ return 200, await self._space_summary_handler.get_space_summary(
+ requester.user.to_string(),
+ room_id,
+ suggested_only=suggested_only,
+ max_rooms_per_space=max_rooms_per_space,
+ )
+
+
+def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomStateEventRestServlet(hs).register(http_server)
- RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
- RoomForgetRestServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)
- SearchRestServlet(hs).register(http_server)
- JoinedRoomsRestServlet(hs).register(http_server)
- RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
- RoomAliasListServlet(hs).register(http_server)
+
+ if hs.config.experimental.spaces_enabled:
+ RoomSpaceSummaryRestServlet(hs).register(http_server)
+
+ # Some servlets only get registered for the main process.
+ if not is_worker:
+ RoomCreateRestServlet(hs).register(http_server)
+ RoomForgetRestServlet(hs).register(http_server)
+ SearchRestServlet(hs).register(http_server)
+ JoinedRoomsRestServlet(hs).register(http_server)
+ RoomEventServlet(hs).register(http_server)
+ RoomAliasListServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):
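
A hedged sketch of calling the new space-summary servlet registered above (MSC2946, only when spaces_enabled is set). The homeserver, token and room ID are placeholders and the `requests` library is assumed; the GET form takes query parameters while the POST form takes the same options in the JSON body.

import requests

BASE = "https://synapse.example.com/_matrix/client/unstable/org.matrix.msc2946"
HEADERS = {"Authorization": "Bearer <access_token>"}
space = "!space:example.com"

# GET form: options as query parameters.
summary = requests.get(
    f"{BASE}/rooms/{space}/spaces",
    headers=HEADERS,
    params={"suggested_only": "true", "max_rooms_per_space": 10},
).json()

# POST form: the same options in the body; types are validated server-side.
summary = requests.post(
    f"{BASE}/rooms/{space}/spaces",
    headers=HEADERS,
    json={"suggested_only": True, "max_rooms_per_space": 10},
).json()
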
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 8e4a01a7e8..80ee0d2d8e 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -21,9 +21,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING
from urllib.parse import urlparse
-if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
-
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@@ -32,6 +29,7 @@ from synapse.api.errors import (
ThreepidValidationError,
)
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@@ -48,13 +46,17 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
@@ -103,6 +105,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -202,7 +206,6 @@ class PasswordRestServlet(RestServlet):
requester,
request,
body,
- self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@@ -225,7 +228,6 @@ class PasswordRestServlet(RestServlet):
[[LoginType.EMAIL_IDENTITY]],
request,
body,
- self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@@ -237,7 +239,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ e.session_id,
+ UIAuthSessionDataConstants.PASSWORD_HASH,
+ password_hash,
)
raise
@@ -264,14 +268,18 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- # If we have a password in this request, prefer it. Otherwise, there
- # must be a password hash from an earlier request.
+ # If we have a password in this request, prefer it. Otherwise, use the
+ # password hash from an earlier request.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
- else:
+ elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
- session_id, "password_hash", None
+ session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
+ else:
+ # UI validation was skipped, but the request did not include a new
+ # password.
+ password_hash = None
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
@@ -330,7 +338,7 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase
+ requester.user.to_string(), erase, requester
)
return 200, {}
@@ -338,11 +346,13 @@ class DeactivateAccountRestServlet(RestServlet):
requester,
request,
body,
- self.hs.get_ip_from_request(request),
"deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase, id_server=body.get("id_server")
+ requester.user.to_string(),
+ erase,
+ requester,
+ id_server=body.get("id_server"),
)
if result:
id_server_unbind_result = "success"
@@ -405,6 +415,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -456,7 +468,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
@@ -484,6 +496,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -772,7 +788,6 @@ class ThreepidAddRestServlet(RestServlet):
requester,
request,
body,
- self.hs.get_ip_from_request(request),
"add a third-party identifier to your account",
)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 617ee6d62a..c9f13e4ac5 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -38,14 +38,12 @@ class AccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.handler = hs.get_account_data_handler()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self._profile_handler = hs.get_profile_handler()
async def on_PUT(self, request, user_id, account_data_type):
- if self._is_worker:
- raise Exception("Cannot handle PUT /account_data on worker")
-
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -57,10 +55,9 @@ class AccountDataServlet(RestServlet):
hide_profile = body.get("hide_profile")
await self._profile_handler.set_active([user], not hide_profile, True)
- max_id = await self.store.add_account_data_for_user(
+ max_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, body
)
-
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {}
@@ -96,13 +93,9 @@ class RoomAccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self._is_worker = hs.config.worker_app is not None
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type):
- if self._is_worker:
- raise Exception("Cannot handle PUT /account_data on worker")
-
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -116,12 +109,10 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
- max_id = await self.store.add_account_data_to_room(
+ await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type):
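
These account_data.py hunks replace direct store writes plus a manual notifier poke with a single call to the object returned by `hs.get_account_data_handler()`, which is what lets the "cannot handle PUT on worker" guard be deleted. The sketch below is illustrative only (the real handler also proxies writes over replication when running on a worker); `InMemoryStore` and `PrintNotifier` are invented stand-ins for the datastore and notifier.

```python
import asyncio


class InMemoryStore:
    """Invented stand-in for the datastore: records account data, bumps a stream id."""

    def __init__(self):
        self.data = {}
        self.stream_id = 0

    async def add_account_data_for_user(self, user_id, account_data_type, content):
        self.data[(user_id, account_data_type)] = content
        self.stream_id += 1
        return self.stream_id


class PrintNotifier:
    """Invented stand-in for the notifier."""

    def on_new_event(self, stream_key, max_id, users):
        print("wake up %s stream at %s for %s" % (stream_key, max_id, users))


class AccountDataHandlerSketch:
    """One object owns both the store write and the notifier poke, so servlets
    just call the handler. In real Synapse the handler additionally forwards
    writes over replication when running on a worker."""

    def __init__(self, store, notifier):
        self._store = store
        self._notifier = notifier

    async def add_account_data_for_user(self, user_id, account_data_type, content):
        max_id = await self._store.add_account_data_for_user(
            user_id, account_data_type, content
        )
        self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
        return max_id


async def main():
    handler = AccountDataHandlerSketch(InMemoryStore(), PrintNotifier())
    await handler.add_account_data_for_user("@alice:example.com", "m.push_rules", {})


if __name__ == "__main__":
    asyncio.run(main())
```
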
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
-
- # SSO configuration.
- self._cas_enabled = hs.config.cas_enabled
- if self._cas_enabled:
- self._cas_handler = hs.get_cas_handler()
- self._cas_server_url = hs.config.cas_server_url
- self._cas_service_url = hs.config.cas_service_url
- self._saml_enabled = hs.config.saml2_enabled
- if self._saml_enabled:
- self._saml_handler = hs.get_saml_handler()
- self._oidc_enabled = hs.config.oidc_enabled
- if self._oidc_enabled:
- self._oidc_handler = hs.get_oidc_handler()
- self._cas_server_url = hs.config.cas_server_url
- self._cas_service_url = hs.config.cas_service_url
-
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
# re-authenticate with their SSO provider.
- if self._cas_enabled:
- # Generate a request to CAS that redirects back to an endpoint
- # to verify the successful authentication.
- sso_redirect_url = self._cas_handler.get_redirect_url(
- {"session": session},
- )
-
- elif self._saml_enabled:
- # Some SAML identity providers (e.g. Google) require a
- # RelayState parameter on requests. It is not necessary here, so
- # pass in a dummy redirect URL (which will never get used).
- client_redirect_url = b"unused"
- sso_redirect_url = self._saml_handler.handle_redirect_request(
- client_redirect_url, session
- )
-
- elif self._oidc_enabled:
- client_redirect_url = b""
- sso_redirect_url = await self._oidc_handler.handle_redirect_request(
- request, client_redirect_url, session
- )
-
- else:
- raise SynapseError(400, "Homeserver not configured for SSO.")
-
- html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+ html = await self.auth_handler.start_sso_ui_auth(request, session)
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
success = await self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+ LoginType.RECAPTCHA, authdict, request.getClientIP()
)
if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
authdict = {"session": session}
success = await self.auth_handler.add_oob_auth(
- LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+ LoginType.TERMS, authdict, request.getClientIP()
)
if success:
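
The auth fallback servlet no longer builds CAS/SAML/OIDC redirect URLs itself; it hands the request straight to `auth_handler.start_sso_ui_auth`. Below is a minimal sketch of that shape, assuming the handler simply owns a map of configured providers; the provider callables and `RuntimeError` are placeholders, not Synapse's actual API.

```python
class SsoUiAuthSketch:
    """Illustrative only: a single entry point picks a configured SSO provider
    and builds the redirect URL, so the REST servlet needs no per-protocol
    branches."""

    def __init__(self, providers):
        # name -> callable(session) returning a redirect URL
        self._providers = providers

    def start_sso_ui_auth(self, session: str) -> str:
        if not self._providers:
            raise RuntimeError("Homeserver not configured for SSO.")
        # Use whichever provider is configured; real Synapse can offer a choice
        # when several identity providers are enabled.
        _name, build_redirect = next(iter(self._providers.items()))
        return build_redirect(session)


if __name__ == "__main__":
    sketch = SsoUiAuthSketch(
        {"cas": lambda session: "https://cas.example.com/login?session=" + session}
    )
    print(sketch.start_sso_ui_auth("abc123"))
```
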
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index 76879ac559..44ccf10ed4 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -13,12 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -27,21 +33,16 @@ class CapabilitiesRestServlet(RestServlet):
PATTERNS = client_patterns("/capabilities$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.config = hs.config
self.auth = hs.get_auth()
- self.store = hs.get_datastore()
+ self.auth_handler = hs.get_auth_handler()
- async def on_GET(self, request):
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- user = await self.store.get_user_by_id(requester.user.to_string())
- change_password = bool(user["password_hash"])
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.auth.get_user_by_req(request, allow_guest=True)
+ change_password = self.auth_handler.can_change_password()
response = {
"capabilities": {
@@ -58,5 +59,5 @@ class CapabilitiesRestServlet(RestServlet):
return 200, response
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
CapabilitiesRestServlet(hs).register(http_server)
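
The capabilities servlet now answers `m.change_password` via `auth_handler.can_change_password()` instead of loading the user's `password_hash` row. A hedged sketch, assuming the answer is derived purely from configuration flags (the flag names here are illustrative, not the real config attributes):

```python
from dataclasses import dataclass


@dataclass
class PasswordConfigSketch:
    # Flag names are illustrative, not Synapse's actual config attributes.
    password_enabled: bool = True
    password_localdb_enabled: bool = True


def can_change_password(config: PasswordConfigSketch) -> bool:
    """Answer the m.change_password capability from configuration alone: the
    result is the same for every user, so no per-user database read is needed."""
    return config.password_enabled and config.password_localdb_enabled


if __name__ == "__main__":
    print(can_change_password(PasswordConfigSketch()))  # True
    print(can_change_password(PasswordConfigSketch(password_enabled=False)))  # False
```
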
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..3d07aadd39 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -86,7 +86,6 @@ class DeleteDevicesRestServlet(RestServlet):
requester,
request,
body,
- self.hs.get_ip_from_request(request),
"remove device(s) from your account",
)
@@ -136,7 +135,6 @@ class DeviceRestServlet(RestServlet):
requester,
request,
body,
- self.hs.get_ip_from_request(request),
"remove a device from your account",
)
@@ -214,7 +212,9 @@ class DehydratedDeviceServlet(RestServlet):
if "device_data" not in submission:
raise errors.SynapseError(
- 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
+ 400,
+ "device_data missing",
+ errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_data"], dict):
raise errors.SynapseError(
@@ -267,11 +267,15 @@ class ClaimDehydratedDeviceServlet(RestServlet):
if "device_id" not in submission:
raise errors.SynapseError(
- 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
+ 400,
+ "device_id missing",
+ errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_id"], str):
raise errors.SynapseError(
- 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
+ 400,
+ "device_id must be a string",
+ errcode=errors.Codes.INVALID_PARAM,
)
result = await self.device_handler.rehydrate_device(
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index a3bb095c2d..08fb6b2b06 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,29 +15,65 @@
# limitations under the License.
import logging
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import GroupID
+from functools import wraps
+from typing import TYPE_CHECKING, Optional, Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.constants import (
+ MAX_GROUP_CATEGORYID_LENGTH,
+ MAX_GROUP_ROLEID_LENGTH,
+ MAX_GROUPID_LENGTH,
+)
+from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.http.site import SynapseRequest
+from synapse.types import GroupID, JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-class GroupServlet(RestServlet):
- """Get the group profile
+def _validate_group_id(f):
+ """Wrapper to validate the form of the group ID.
+
+ Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
"""
+ @wraps(f)
+ def wrapper(self, request: Request, group_id: str, *args, **kwargs):
+ if not GroupID.is_valid(group_id):
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+ return f(self, request, group_id, *args, **kwargs)
+
+ return wrapper
+
+
+class GroupServlet(RestServlet):
+ """Get the group profile"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
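
For reference, here is the same decorator pattern as a self-contained (and synchronous, for brevity) example that runs outside Synapse; the `"+localpart:server"` check and `ValueError` are simplified stand-ins for `GroupID.is_valid` and `SynapseError(400, ...)`.

```python
from functools import wraps


def validate_group_id(f):
    """Check the group_id URL parameter before the handler body runs.

    Simplified stand-in for the _validate_group_id decorator above.
    """

    @wraps(f)
    def wrapper(self, request, group_id, *args, **kwargs):
        if not (group_id.startswith("+") and ":" in group_id):
            raise ValueError("%s is not a legal group ID" % (group_id,))
        return f(self, request, group_id, *args, **kwargs)

    return wrapper


class DemoServlet:
    @validate_group_id
    def on_GET(self, request, group_id):
        return 200, {"group_id": group_id}


if __name__ == "__main__":
    servlet = DemoServlet()
    print(servlet.on_GET(None, "+friends:example.com"))
    try:
        servlet.on_GET(None, "not-a-group-id")
    except ValueError as e:
        print(e)
```
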
@@ -47,11 +83,20 @@ class GroupServlet(RestServlet):
return 200, group_description
- async def on_POST(self, request, group_id):
+ @_validate_group_id
+ async def on_POST(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert_params_in_dict(
+ content, ("name", "avatar_url", "short_description", "long_description")
+ )
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot create group profiles."
await self.groups_handler.update_group_profile(
group_id, requester_user_id, content
)
@@ -60,18 +105,20 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet):
- """Get the full group summary
- """
+ """Get the full group summary"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -96,17 +143,38 @@ class GroupSummaryRoomsCatServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id, category_id, room_id):
+ @_validate_group_id
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ group_id: str,
+ category_id: Optional[str],
+ room_id: str,
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+ if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+ raise SynapseError(
+ 400,
+ "category_id may not be longer than %s characters"
+ % (MAX_GROUP_CATEGORYID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.update_group_summary_room(
group_id,
requester_user_id,
@@ -117,10 +185,16 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
- async def on_DELETE(self, request, group_id, category_id, room_id):
+ @_validate_group_id
+ async def on_DELETE(
+ self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group profiles."
resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
@@ -129,20 +203,22 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet):
- """Get/add/update/delete a group category
- """
+ """Get/add/update/delete a group category"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id, category_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -152,21 +228,44 @@ class GroupCategoryServlet(RestServlet):
return 200, category
- async def on_PUT(self, request, group_id, category_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if not category_id:
+ raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+ if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+ raise SynapseError(
+ 400,
+ "category_id may not be longer than %s characters"
+ % (MAX_GROUP_CATEGORYID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content
)
return 200, resp
- async def on_DELETE(self, request, group_id, category_id):
+ @_validate_group_id
+ async def on_DELETE(
+ self, request: SynapseRequest, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id
)
@@ -175,18 +274,20 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet):
- """Get all group categories
- """
+ """Get all group categories"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -198,18 +299,20 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet):
- """Get/add/update/delete a group role
- """
+ """Get/add/update/delete a group role"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id, role_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,21 +322,44 @@ class GroupRoleServlet(RestServlet):
return 200, category
- async def on_PUT(self, request, group_id, role_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if not role_id:
+ raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+ if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+ raise SynapseError(
+ 400,
+ "role_id may not be longer than %s characters"
+ % (MAX_GROUP_ROLEID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group roles."
resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content
)
return 200, resp
- async def on_DELETE(self, request, group_id, role_id):
+ @_validate_group_id
+ async def on_DELETE(
+ self, request: SynapseRequest, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group roles."
resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id
)
@@ -242,18 +368,20 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet):
- """Get all group roles
- """
+ """Get all group roles"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -278,17 +406,38 @@ class GroupSummaryUsersRoleServlet(RestServlet):
"/users/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id, role_id, user_id):
+ @_validate_group_id
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ group_id: str,
+ role_id: Optional[str],
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+ if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+ raise SynapseError(
+ 400,
+ "role_id may not be longer than %s characters"
+ % (MAX_GROUP_ROLEID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.update_group_summary_user(
group_id,
requester_user_id,
@@ -299,10 +448,16 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
- async def on_DELETE(self, request, group_id, role_id, user_id):
+ @_validate_group_id
+ async def on_DELETE(
+ self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
@@ -311,24 +466,23 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet):
- """Get all rooms in a group
- """
+ """Get all rooms in a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
@@ -337,18 +491,20 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet):
- """Get all users in a group
- """
+ """Get all users in a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -360,18 +516,20 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet):
- """Get users invited to a group
- """
+ """Get users invited to a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, group_id):
+ @_validate_group_id
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -383,22 +541,27 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet):
- """Set group join policy
- """
+ """Set group join policy"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group join policy."
result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content
)
@@ -407,19 +570,18 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet):
- """Create a group
- """
+ """Create a group"""
PATTERNS = client_patterns("/create_group$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -428,6 +590,19 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
+ if not localpart:
+ raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
+
+ if len(group_id) > MAX_GROUPID_LENGTH:
+ raise SynapseError(
+ 400,
+ "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot create groups."
result = await self.groups_handler.create_group(
group_id, requester_user_id, content
)
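
The new checks in GroupCreateServlet reject empty localparts and over-long group IDs before calling the handler. A standalone sketch of that validation, with an assumed `MAX_GROUPID_LENGTH` value and `ValueError` in place of `SynapseError(400, ..., Codes.INVALID_PARAM)`:

```python
MAX_GROUPID_LENGTH = 255  # assumed value; the real constant lives in synapse.api.constants


def validate_new_group_id(localpart: str, server_name: str) -> str:
    """Reject empty localparts and over-long group IDs up front, before the
    groups handler is involved."""
    if not localpart:
        raise ValueError("Group ID cannot be empty")
    group_id = "+%s:%s" % (localpart, server_name)
    if len(group_id) > MAX_GROUPID_LENGTH:
        raise ValueError(
            "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,)
        )
    return group_id


if __name__ == "__main__":
    print(validate_new_group_id("book-club", "example.com"))  # +book-club:example.com
    try:
        validate_new_group_id("x" * 300, "example.com")
    except ValueError as e:
        print(e)
```
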
@@ -436,34 +611,45 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet):
- """Add a room to the group
- """
+ """Add a room to the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id, room_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify rooms in a group."
result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
return 200, result
- async def on_DELETE(self, request, group_id, room_id):
+ @_validate_group_id
+ async def on_DELETE(
+ self, request: SynapseRequest, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id
)
@@ -472,25 +658,30 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet):
- """Update the config of a room in a group
- """
+ """Update the config of a room in a group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id, room_id, config_key):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content
)
@@ -499,14 +690,13 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet):
- """Invite a user to the group
- """
+ """Invite a user to the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@@ -514,12 +704,18 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
- async def on_PUT(self, request, group_id, user_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id, user_id
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot invite users to a group."
result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config
)
@@ -528,24 +724,29 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet):
- """Kick a user from the group
- """
+ """Kick a user from the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id, user_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id, user_id
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot kick users from a group."
result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
@@ -554,22 +755,27 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet):
- """Leave a joined group
- """
+ """Leave a joined group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot leave a group for a users."
result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content
)
@@ -578,22 +784,27 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet):
- """Attempt to join a group, or knock
- """
+ """Attempt to join a group, or knock"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot join a user to a group."
result = await self.groups_handler.join_group(
group_id, requester_user_id, content
)
@@ -602,22 +813,27 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet):
- """Accept a group invite
- """
+ """Accept a group invite"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_PUT(self, request, group_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot accept an invite to a group."
result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content
)
@@ -626,18 +842,20 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet):
- """Update whether we publicise a users membership of a group
- """
+ """Update whether we publicise a users membership of a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- async def on_PUT(self, request, group_id):
+ @_validate_group_id
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -649,19 +867,20 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet):
- """Get the list of groups a user is advertising
- """
+ """Get the list of groups a user is advertising"""
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -670,19 +889,18 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet):
- """Get the list of groups a user is advertising
- """
+ """Get the list of groups a user is advertising"""
PATTERNS = client_patterns("/publicised_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -694,18 +912,17 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet):
- """Get all groups the logged in user is joined to
- """
+ """Get all groups the logged in user is joined to"""
PATTERNS = client_patterns("/joined_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -714,7 +931,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..f092e5b3a2 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -274,7 +274,6 @@ class SigningKeyUploadServlet(RestServlet):
requester,
request,
body,
- self.hs.get_ip_from_request(request),
"add a device signing key to your account",
)
diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/v2_alpha/knock.py
index 8439da447e..51dac88fc7 100644
--- a/synapse/rest/client/v2_alpha/knock.py
+++ b/synapse/rest/client/v2_alpha/knock.py
@@ -25,6 +25,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_list_from_args,
)
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict, RoomAlias, RoomID
@@ -53,7 +54,10 @@ class KnockRoomAliasServlet(RestServlet):
self.auth = hs.get_auth()
async def on_POST(
- self, request: Request, room_identifier: str, txn_id: Optional[str] = None,
+ self,
+ request: SynapseRequest,
+ room_identifier: str,
+ txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -65,7 +69,7 @@ class KnockRoomAliasServlet(RestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
- remote_room_hosts = parse_list_from_args(request.args, "server_name")
+ remote_room_hosts = parse_list_from_args(request.args, "server_name") # type: ignore
except KeyError:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index f53bb71cb2..a7aea914e9 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -40,6 +40,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@@ -127,6 +128,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -192,6 +195,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
body, ["client_secret", "country", "phone_number", "send_attempt"]
)
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
@@ -208,6 +212,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
@@ -290,6 +298,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid", required=True)
client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
token = parse_string(request, "token", required=True)
# Attempt to validate a 3PID session
@@ -458,7 +467,7 @@ class RegisterRestServlet(RestServlet):
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
- raise SynapseError(403, "Registration has been disabled")
+ raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
# Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
@@ -490,11 +499,11 @@ class RegisterRestServlet(RestServlet):
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = await self.auth_handler.get_session_data(
- session_id, "registered_user_id", None
+ session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
)
# Extract the previously-hashed password from the session.
password_hash = await self.auth_handler.get_session_data(
- session_id, "password_hash", None
+ session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
# Check if the user-interactive authentication flows are complete, if
@@ -504,7 +513,6 @@ class RegisterRestServlet(RestServlet):
self._registration_flows,
request,
body,
- self.hs.get_ip_from_request(request),
"register a new account",
)
except InteractiveAuthIncompleteError as e:
@@ -520,7 +528,9 @@ class RegisterRestServlet(RestServlet):
if not password_hash and password:
password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ e.session_id,
+ UIAuthSessionDataConstants.PASSWORD_HASH,
+ password_hash,
)
raise
@@ -709,7 +719,9 @@ class RegisterRestServlet(RestServlet):
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
- session_id, "registered_user_id", registered_user_id
+ session_id,
+ UIAuthSessionDataConstants.REGISTERED_USER_ID,
+ registered_user_id,
)
registered = True
@@ -742,7 +754,11 @@ class RegisterRestServlet(RestServlet):
user_id = await self.registration_handler.appservice_register(
username, as_token, password, display_name
)
- result = await self._create_registration_details(user_id, body)
+ result = await self._create_registration_details(
+ user_id,
+ body,
+ is_appservice_ghost=True,
+ )
auth_result = body.get("auth_result")
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
@@ -759,7 +775,9 @@ class RegisterRestServlet(RestServlet):
return result
- async def _create_registration_details(self, user_id, params):
+ async def _create_registration_details(
+ self, user_id, params, is_appservice_ghost=False
+ ):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -776,7 +794,11 @@ class RegisterRestServlet(RestServlet):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=False
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest=False,
+ is_appservice_ghost=is_appservice_ghost,
)
result.update({"access_token": access_token, "device_id": device_id})
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 18c75738f8..fe765da23c 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -244,7 +244,9 @@ class RelationAggregationPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True,
+ room_id,
+ requester.user.to_string(),
+ allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
@@ -322,7 +324,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True,
+ room_id,
+ requester.user.to_string(),
+ allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index bf030e0ff4..147920767f 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class RoomUpgradeRestServlet(RestServlet):
- """Handler for room uprade requests.
+ """Handler for room upgrade requests.
Handles requests of the form:
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f43639a..79c1b526ee 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging
from typing import Tuple
from synapse.http import servlet
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
@@ -54,11 +54,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
-
- sender_user_id = requester.user.to_string()
+ assert_params_in_dict(content, ("messages",))
await self.device_message_handler.send_device_message(
- sender_user_id, message_type, content["messages"]
+ requester, message_type, content["messages"]
)
response = (200, {}) # type: Tuple[int, dict]
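
send-to-device now validates the body with `assert_params_in_dict` before indexing `content["messages"]`. A simplified, runnable stand-in for that helper (the real one raises `SynapseError` with `Codes.MISSING_PARAM`):

```python
def assert_params_in_dict_sketch(body: dict, required: tuple) -> None:
    """Fail fast with a client error when a required JSON field is missing,
    instead of letting a KeyError surface as a 500. Simplified stand-in for
    synapse.http.servlet.assert_params_in_dict."""
    missing = [key for key in required if key not in body]
    if missing:
        raise ValueError("Missing params: %s" % (", ".join(missing),))


if __name__ == "__main__":
    assert_params_in_dict_sketch({"messages": {}}, ("messages",))  # passes silently
    try:
        assert_params_in_dict_sketch({}, ("messages",))
    except ValueError as e:
        print(e)  # Missing params: messages
```
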
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 582c999abd..c01ba14cd2 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -14,7 +14,7 @@
# limitations under the License.
import itertools
import logging
-from typing import Any, Callable, Dict, List
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
@@ -26,11 +26,15 @@ from synapse.events.utils import (
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import KnockedSyncResult, SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
-from synapse.types import StreamToken
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
from ._base import client_patterns, set_timeline_upper_limit
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -73,7 +77,7 @@ class SyncRestServlet(RestServlet):
PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -85,7 +89,10 @@ class SyncRestServlet(RestServlet):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
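
The new `assert request.args is not None` exists mainly to satisfy mypy, since Twisted types `request.args` as Optional even though it is populated by the time the resource is rendered. A small illustration of the idiom, using an invented helper name and argument:

```python
from typing import Dict, List, Optional


def read_since_token(args: Optional[Dict[bytes, List[bytes]]]) -> Optional[str]:
    """The assert both documents the runtime expectation and narrows the
    Optional type for mypy. read_since_token is purely illustrative."""
    assert args is not None  # mypy treats `args` as a plain Dict from here on
    since = args.get(b"since")
    return since[0].decode("ascii") if since else None


if __name__ == "__main__":
    print(read_since_token({b"since": [b"s72594_4483_1934"]}))  # s72594_4483_1934
    print(read_since_token({}))  # None
```
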
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.auth = hs.get_auth()
- self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag):
requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
body = parse_json_object_from_request(request)
- max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.add_tag_to_room(user_id, room_id, tag, body)
return 200, {}
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
- max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.remove_tag_from_room(user_id, room_id, tag)
return 200, {}
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index c9b9e7f5ff..54eed3251b 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -84,6 +84,8 @@ class VersionsRestServlet(RestServlet):
"io.element.e2ee_forced.public": self.e2ee_forced_public,
"io.element.e2ee_forced.private": self.e2ee_forced_private,
"io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
+ # Supports the busy presence state described in MSC3026.
+ "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
},
},
)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index b3e4d5612e..8b9ef26cf2 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource):
consent_template_directory = hs.config.user_consent_template_dir
+ # TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(consent_template_directory)
self._jinja_env = jinja2.Environment(
loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
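
The TODO above points at a shared template-environment builder. As a rough, assumption-laden sketch of what such a helper centralises (one loader plus a consistent autoescape policy), not a copy of `synapse.util.templates.build_jinja_env`:

```python
import jinja2


def build_jinja_env_sketch(template_dirs, autoescape=True) -> jinja2.Environment:
    """Build one Jinja environment over the configured template directories,
    with the same autoescaping everywhere. Illustrative sketch only."""
    return jinja2.Environment(
        loader=jinja2.FileSystemLoader(template_dirs),
        autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) if autoescape else False,
    )


if __name__ == "__main__":
    env = build_jinja_env_sketch(["/path/to/consent_templates"])  # hypothetical path
    print(type(env.loader).__name__)  # FileSystemLoader
```
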
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f843f02454..c57ac22e58 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Dict, Set
+from typing import Dict
from signedjson.sign import sign_json
@@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
- cache_misses = {} # type: Dict[str, Set[str]]
+ # Note that the value is unused.
+ cache_misses = {} # type: Dict[str, Dict[str, int]]
for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
if not results and key_id is not None:
- cache_misses.setdefault(server_name, set()).add(key_id)
+ cache_misses.setdefault(server_name, {})[key_id] = 0
continue
if key_id is not None:
@@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
)
if miss:
- cache_misses.setdefault(server_name, set()).add(key_id)
+ cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 67aa993f19..6366947071 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
import logging
import os
import urllib
-from typing import Awaitable
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
+from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@@ -46,20 +47,22 @@ TEXT_CONTENT_TYPES = [
]
-def parse_media_id(request):
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
+ # The type on postpath seems incorrect in Twisted 21.2.0.
+ postpath = request.postpath # type: List[bytes] # type: ignore
+ assert postpath
+
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
- server_name, media_id = request.postpath[:2]
-
- if isinstance(server_name, bytes):
- server_name = server_name.decode("utf-8")
- media_id = media_id.decode("utf8")
+ server_name_bytes, media_id_bytes = postpath[:2]
+ server_name = server_name_bytes.decode("utf-8")
+ media_id = media_id_bytes.decode("utf8")
file_name = None
- if len(request.postpath) > 2:
+ if len(postpath) > 2:
try:
- file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
+ file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
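
The retyped `parse_media_id` decodes the first two `postpath` segments and, optionally, a trailing file name. Here is the same logic lifted out of the Twisted `Request` so it can be exercised directly; `parse_media_postpath` is an invented name for this standalone variant.

```python
import urllib.parse
from typing import List, Optional, Tuple


def parse_media_postpath(postpath: List[bytes]) -> Tuple[str, str, Optional[str]]:
    """Same decoding steps as parse_media_id, but taking the already-split path
    segments instead of a Twisted Request."""
    assert postpath and len(postpath) >= 2
    server_name_bytes, media_id_bytes = postpath[:2]
    server_name = server_name_bytes.decode("utf-8")
    media_id = media_id_bytes.decode("utf8")

    # An optional trailing segment lets clients append e.g. /test.png so the
    # URL itself hints at the content type.
    file_name = None
    if len(postpath) > 2:
        try:
            file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
        except UnicodeDecodeError:
            pass
    return server_name, media_id, file_name


if __name__ == "__main__":
    print(parse_media_postpath([b"example.com", b"abc123", b"test.png"]))
    # ('example.com', 'abc123', 'test.png')
```
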
@@ -69,7 +72,7 @@ def parse_media_id(request):
)
-def respond_404(request):
+def respond_404(request: Request) -> None:
respond_with_json(
request,
404,
@@ -79,8 +82,12 @@ def respond_404(request):
async def respond_with_file(
- request, media_type, file_path, file_size=None, upload_name=None
-):
+ request: Request,
+ media_type: str,
+ file_path: str,
+ file_size: Optional[int] = None,
+ upload_name: Optional[str] = None,
+) -> None:
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@@ -98,15 +105,20 @@ async def respond_with_file(
respond_404(request)
-def add_file_headers(request, media_type, file_size, upload_name):
+def add_file_headers(
+ request: Request,
+ media_type: str,
+ file_size: Optional[int],
+ upload_name: Optional[str],
+) -> None:
"""Adds the correct response headers in preparation for responding with the
media.
Args:
- request (twisted.web.http.Request)
- media_type (str): The media/content type.
- file_size (int): Size in bytes of the media, if known.
- upload_name (str): The name of the requested file, if any.
+ request
+ media_type: The media/content type.
+ file_size: Size in bytes of the media, if known.
+ upload_name: The name of the requested file, if any.
"""
def _quote(x):
@@ -127,7 +139,7 @@ def add_file_headers(request, media_type, file_size, upload_name):
# section 3.6 [2] to be a `token` or a `quoted-string`, where a `token`
# is (essentially) a single US-ASCII word, and a `quoted-string` is a
# US-ASCII string surrounded by double-quotes, using backslash as an
- # escape charater. Note that %-encoding is *not* permitted.
+ # escape character. Note that %-encoding is *not* permitted.
#
# `filename*` is defined to be an `ext-value`, which is defined in
# RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`,
@@ -153,7 +165,13 @@ def add_file_headers(request, media_type, file_size, upload_name):
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
- request.setHeader(b"Content-Length", b"%d" % (file_size,))
+ if file_size is not None:
+ request.setHeader(b"Content-Length", b"%d" % (file_size,))
+
+ # Tell web crawlers to not index, archive, or follow links in media. This
+ # should help to prevent things in the media repo from showing up in web
+ # search results.
+ request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
# separators as defined in RFC2616. SP and HT are handled separately.
@@ -179,7 +197,7 @@ _FILENAME_SEPARATOR_CHARS = {
}
-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
for c in x:
# from RFC2616:
#
@@ -201,17 +219,21 @@ def _can_encode_filename_as_token(x):
async def respond_with_responder(
- request, responder, media_type, file_size, upload_name=None
-):
+ request: Request,
+ responder: "Optional[Responder]",
+ media_type: str,
+ file_size: Optional[int],
+ upload_name: Optional[str] = None,
+) -> None:
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
- request (twisted.web.http.Request)
- responder (Responder|None)
- media_type (str): The media/content type.
- file_size (int|None): Size in bytes of the media. If not known it should be None
- upload_name (str|None): The name of the requested file, if any.
+ request
+ responder
+ media_type: The media/content type.
+ file_size: Size in bytes of the media. If not known it should be None
+ upload_name: The name of the requested file, if any.
"""
if request._disconnected:
logger.warning(
@@ -280,6 +302,7 @@ class FileInfo:
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ thumbnail_length (int): The size of the thumbnail, in bytes.
"""
def __init__(
@@ -292,6 +315,7 @@ class FileInfo:
thumbnail_height=None,
thumbnail_method=None,
thumbnail_type=None,
+ thumbnail_length=None,
):
self.server_name = server_name
self.file_id = file_id
@@ -301,24 +325,25 @@ class FileInfo:
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
+ self.thumbnail_length = thumbnail_length
-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
"""
Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header.
Args:
- headers (dict[bytes, list[bytes]]): The HTTP request headers.
+ headers: The HTTP request headers.
Returns:
- A Unicode string of the filename, or None.
+ The filename, or None.
"""
content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
- return
+ return None
_, params = _parse_header(content_disposition[0])
@@ -351,17 +376,16 @@ def get_filename_from_headers(headers):
return upload_name
-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
- line (bytes): header to be parsed
+ line: header to be parsed
Returns:
- Tuple[bytes, dict[bytes, bytes]]:
- the main content-type, followed by the parameter dictionary
+ The main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b";" + line)
key = next(parts)
@@ -381,16 +405,16 @@ def _parse_header(line):
return key, pdict
-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
- s (bytes): header to be parsed
+ s: header to be parsed
Returns:
- Iterable[bytes]: the split input
+ The split input
"""
while s[:1] == b";":
s = s[1:]
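
As a side note, the token / quoted-string / RFC5987 `filename*` cases described in the comments above can be illustrated with a small standalone sketch. This is not Synapse code: the helper name and the rough token check below are invented for illustration and only mirror the three cases the comments describe.

from urllib.parse import quote

def content_disposition(upload_name: str) -> str:
    # Case 1: a plain ASCII token needs no quoting at all.
    if upload_name.isascii() and all(c.isalnum() or c in "._-" for c in upload_name):
        return "inline; filename=%s" % upload_name
    # Case 2: other ASCII names become a quoted-string, escaping backslash and quote.
    if upload_name.isascii():
        escaped = upload_name.replace("\\", "\\\\").replace('"', '\\"')
        return 'inline; filename="%s"' % escaped
    # Case 3: non-ASCII names use the RFC5987 filename* form with UTF-8 %-encoding.
    return "inline; filename*=utf-8''%s" % quote(upload_name, safe="")

print(content_disposition("cat.png"))      # inline; filename=cat.png
print(content_disposition("my cat.png"))   # inline; filename="my cat.png"
print(content_disposition("日本語.png"))    # inline; filename*=utf-8''%E6%97%A5...
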
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 68dd2a1c8a..c41a7ab412 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,22 +15,30 @@
# limitations under the License.
#
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
from synapse.http.server import DirectServeJsonResource, respond_with_json
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
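
For context, a minimal sketch of calling the endpoint this resource serves (usually mounted at /_matrix/media/r0/config); the hostname and access token are placeholders, and the response shape follows the limits_dict built above.

import json
import urllib.request

req = urllib.request.Request(
    "https://matrix.example.com/_matrix/media/r0/config",
    headers={"Authorization": "Bearer <access_token>"},
)
with urllib.request.urlopen(req) as resp:
    limits = json.load(resp)
print(limits.get("m.upload.size"))  # the server's maximum upload size, in bytes
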
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d3d8457303..5dadaeaf57 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
-import synapse.http.servlet
from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean
from ._base import parse_media_id, respond_404
+if TYPE_CHECKING:
+ from synapse.rest.media.v1.media_repository import MediaRepository
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class DownloadResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
@@ -43,15 +51,14 @@ class DownloadResource(DirectServeJsonResource):
b" object-src 'self';",
)
request.setHeader(
- b"Referrer-Policy", b"no-referrer",
+ b"Referrer-Policy",
+ b"no-referrer",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
else:
- allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True
- )
+ allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 9e079f672f..7792f26e78 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
import functools
import os
import re
+from typing import Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@@ -41,12 +43,18 @@ class MediaFilePaths:
to write to the backup media store (when one is configured)
"""
- def __init__(self, primary_base_path):
+ def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
def default_thumbnail_rel(
- self, default_top_level, default_sub_type, width, height, content_type, method
- ):
+ self,
+ default_top_level: str,
+ default_sub_type: str,
+ width: int,
+ height: int,
+ content_type: str,
+ method: str,
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
- def local_media_filepath_rel(self, media_id):
+ def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
- def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+ def local_media_thumbnail_rel(
+ self, media_id: str, width: int, height: int, content_type: str, method: str
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
media_id[4:],
)
- def remote_media_filepath_rel(self, server_name, file_id):
+ def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
@@ -94,8 +104,14 @@ class MediaFilePaths:
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
def remote_media_thumbnail_rel(
- self, server_name, file_id, width, height, content_type, method
- ):
+ self,
+ server_name: str,
+ file_id: str,
+ width: int,
+ height: int,
+ content_type: str,
+ method: str,
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
def remote_media_thumbnail_rel_legacy(
- self, server_name, file_id, width, height, content_type
+ self, server_name: str, file_id: str, width: int, height: int, content_type: str
):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
file_name,
)
- def remote_media_thumbnail_dir(self, server_name, file_id):
+ def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:],
)
- def url_cache_filepath_rel(self, media_id):
+ def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
- def url_cache_filepath_dirs_to_delete(self, media_id):
+ def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]),
]
- def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
+ def url_cache_thumbnail_rel(
+ self, media_id: str, width: int, height: int, content_type: str, method: str
+ ) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -178,7 +196,7 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
- def url_cache_thumbnail_directory(self, media_id):
+ def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:],
)
- def url_cache_thumbnail_dirs_to_delete(self, media_id):
+ def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
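
A standalone illustration of the layout these path helpers produce (the media ID below is made up): the first four characters of a media ID are fanned out into two directory levels so no single directory accumulates millions of files.

import os

def local_media_filepath_rel(media_id: str) -> str:
    # Same shape as the helper above: local_content/<2 chars>/<2 chars>/<rest>
    return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])

print(local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"))
# local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg
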
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 83beb02b05..0c041b542d 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import errno
import logging
import os
import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
-from twisted.web.http import Request
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.api.errors import (
FederationDeniedError,
@@ -35,6 +35,7 @@ from synapse.api.errors import (
from synapse.config._base import ConfigError
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
@@ -56,6 +57,9 @@ from .thumbnail_resource import ThumbnailResource
from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -63,7 +67,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.client = hs.get_federation_http_client()
@@ -73,16 +77,16 @@ class MediaRepository:
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
- self.primary_base_path = hs.config.media_store_path
- self.filepaths = MediaFilePaths(self.primary_base_path)
+ self.primary_base_path = hs.config.media_store_path # type: str
+ self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
- self.recently_accessed_remotes = set()
- self.recently_accessed_locals = set()
+ self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
+ self.recently_accessed_locals = set() # type: Set[str]
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -113,7 +117,7 @@ class MediaRepository:
"update_recently_accessed_media", self._update_recently_accessed
)
- async def _update_recently_accessed(self):
+ async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
@@ -124,12 +128,12 @@ class MediaRepository:
local_media, remote_media, self.clock.time_msec()
)
- def mark_recently_accessed(self, server_name, media_id):
+ def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
"""Mark the given media as recently accessed.
Args:
- server_name (str|None): Origin server of media, or None if local
- media_id (str): The media ID of the content
+ server_name: Origin server of media, or None if local
+ media_id: The media ID of the content
"""
if server_name:
self.recently_accessed_remotes.add((server_name, media_id))
@@ -142,7 +146,7 @@ class MediaRepository:
upload_name: Optional[str],
content: IO,
content_length: int,
- auth_user: str,
+ auth_user: UserID,
) -> str:
"""Store uploaded content for a local user and return the mxc URL
@@ -181,7 +185,7 @@ class MediaRepository:
async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
) -> None:
- """Responds to reqests for local media, if exists, or returns 404.
+ """Responds to requests for local media, if exists, or returns 404.
Args:
request: The incoming request.
@@ -303,7 +307,7 @@ class MediaRepository:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
- # seen the file then reuse the existing ID, otherwise genereate a new
+ # seen the file then reuse the existing ID, otherwise generate a new
# one.
# If we have an entry in the DB, try and look for it
@@ -322,7 +326,10 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it.
try:
- media_info = await self._download_remote_file(server_name, media_id,)
+ media_info = await self._download_remote_file(
+ server_name,
+ media_id,
+ )
except SynapseError:
raise
except Exception as e:
@@ -348,7 +355,11 @@ class MediaRepository:
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
+ async def _download_remote_file(
+ self,
+ server_name: str,
+ media_id: str,
+ ) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -459,7 +470,14 @@ class MediaRepository:
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
- def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+ def _generate_thumbnail(
+ self,
+ thumbnailer: Thumbnailer,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ ) -> Optional[BytesIO]:
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -470,22 +488,20 @@ class MediaRepository:
m_height,
self.max_image_pixels,
)
- return
+ return None
if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose()
if t_method == "crop":
- t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+ return thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
- t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
- else:
- t_byte_source = None
+ return thumbnailer.scale(t_width, t_height, t_type)
- return t_byte_source
+ return None
async def generate_local_exact_thumbnail(
self,
@@ -494,7 +510,7 @@ class MediaRepository:
t_height: int,
t_method: str,
t_type: str,
- url_cache: str,
+ url_cache: Optional[str],
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
@@ -765,7 +781,11 @@ class MediaRepository:
)
except Exception as e:
thumbnail_exists = await self.store.get_remote_media_thumbnail(
- server_name, media_id, t_width, t_height, t_type,
+ server_name,
+ media_id,
+ t_width,
+ t_height,
+ t_type,
)
if not thumbnail_exists:
raise e
@@ -776,7 +796,7 @@ class MediaRepository:
return {"width": m_width, "height": m_height}
- async def delete_old_remote_media(self, before_ts):
+ async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -824,7 +844,10 @@ class MediaRepository:
return await self._remove_local_media_from_disk([media_id])
async def delete_old_local_media(
- self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True,
+ self,
+ before_ts: int,
+ size_gt: int = 0,
+ keep_profiles: bool = True,
) -> Tuple[List[str], int]:
"""
Delete local or remote media from this server by size and timestamp. Removes
@@ -841,7 +864,9 @@ class MediaRepository:
A tuple of (list of deleted media IDs, total deleted media IDs).
"""
old_media = await self.store.get_local_media_before(
- before_ts, size_gt, keep_profiles,
+ before_ts,
+ size_gt,
+ keep_profiles,
)
return await self._remove_local_media_from_disk(old_media)
@@ -919,16 +944,16 @@ class MediaRepositoryResource(Resource):
<thumbnail>
- The thumbnail methods are "crop" and "scale". "scale" trys to return an
+ The thumbnail methods are "crop" and "scale". "scale" tries to return an
image where either the width or the height is smaller than the requested
size. The client should then scale and letterbox the image if it needs to
- fit within a given rectangle. "crop" trys to return an image where the
+ fit within a given rectangle. "crop" tries to return an image where the
width and height are close to the requested size and the aspect matches
the requested size. The client should scale the image if it needs to fit
within a given rectangle.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
if not hs.config.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")
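
A worked example of the two thumbnail methods documented above, using the same integer arithmetic as Thumbnailer.aspect(); the image and request sizes are made up.

src_w, src_h = 1000, 500      # source image
req_w, req_h = 300, 300       # requested thumbnail size

# "scale": shrink to fit inside the requested box, preserving aspect ratio.
if req_w * src_h < req_h * src_w:
    scale_w, scale_h = req_w, (req_w * src_h) // src_w
else:
    scale_w, scale_h = (req_h * src_w) // src_h, req_h
print("scale ->", scale_w, scale_h)   # 300 150; the client letterboxes the rest

# "crop": cover the requested box, then cut away the overflow so the aspect matches.
print("crop  ->", req_w, req_h)       # 300 300; the sides of the wide image are cut off
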
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 268e0c8f50..b1b1c9e6ec 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +16,17 @@ import contextlib
import logging
import os
import shutil
-from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+import attr
+
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
+from synapse.api.errors import NotFoundError
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
@@ -56,6 +62,8 @@ class MediaStorage:
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
+ self.spam_checker = hs.get_spam_checker()
+ self.clock = hs.get_clock()
async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
@@ -77,8 +85,7 @@ class MediaStorage:
return fname
async def write_to_file(self, source: IO, output: IO):
- """Asynchronously write the `source` to `output`.
- """
+ """Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
@@ -125,18 +132,29 @@ class MediaStorage:
f.flush()
f.close()
+ spam = await self.spam_checker.check_media_file_for_spam(
+ ReadableFileWrapper(self.clock, fname), file_info
+ )
+ if spam:
+ logger.info("Blocking media due to spam checker")
+ # Note that we'll delete the stored media, due to the
+ # try/except below. The media also won't be stored in
+ # the DB.
+ raise SpamMediaException()
+
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
- except Exception:
+ except Exception as e:
try:
os.remove(fname)
except Exception:
pass
- raise
+
+ raise e from None
if not finished_called:
raise Exception("Finished callback not called")
@@ -226,7 +244,7 @@ class MediaStorage:
await consumer.wait()
return local_path
- raise Exception("file could not be found")
+ raise NotFoundError()
def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
@@ -270,7 +288,7 @@ class MediaStorage:
return self.filepaths.local_media_filepath_rel(file_info.file_id)
-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called
from a thread.
@@ -286,17 +304,52 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
- open_file (file): A file like object to be streamed ot the client,
+            open_file: A file-like object to be streamed to the client,
is closed when finished streaming.
"""
- def __init__(self, open_file):
+ def __init__(self, open_file: IO):
self.open_file = open_file
- def write_to_consumer(self, consumer):
+ def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer)
)
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+ """The media was blocked by a spam checker, so we simply 404 the request (in
+ the same way as if it was quarantined).
+ """
+
+
+@attr.s(slots=True)
+class ReadableFileWrapper:
+ """Wrapper that allows reading a file in chunks, yielding to the reactor,
+ and writing to a callback.
+
+    This is a simplified `FileSender` that takes an IO object rather than an
+ `IConsumer`.
+ """
+
+ CHUNK_SIZE = 2 ** 14
+
+ clock = attr.ib(type=Clock)
+ path = attr.ib(type=str)
+
+ async def write_chunks_to(self, callback: Callable[[bytes], None]):
+ """Reads the file in chunks and calls the callback with each chunk."""
+
+ with open(self.path, "rb") as file:
+ while True:
+ chunk = file.read(self.CHUNK_SIZE)
+ if not chunk:
+ break
+
+ callback(chunk)
+
+ # We yield to the reactor by sleeping for 0 seconds.
+ await self.clock.sleep(0)
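
A hedged sketch of a spam-checker module that could sit behind the check_media_file_for_spam call added above: it hashes the file through ReadableFileWrapper.write_chunks_to and rejects known-bad digests. The class name, the banned_hashes config key and the constructor shape are illustrative, not part of Synapse.

import hashlib

class HashBlockingSpamChecker:
    def __init__(self, config, api):
        # Hex SHA-256 digests of files to block (illustrative config key).
        self._banned = set(config.get("banned_hashes", []))

    async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
        sha256 = hashlib.sha256()
        # write_chunks_to feeds the file to the callback in 16 KiB chunks,
        # yielding to the reactor between reads.
        await file_wrapper.write_chunks_to(sha256.update)
        # Returning True causes the upload to be rejected and the file deleted.
        return sha256.hexdigest() in self._banned
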
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 7171c1bfd6..c4ed9dfdb4 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import datetime
import errno
import fnmatch
@@ -23,12 +23,13 @@ import re
import shutil
import sys
import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
+from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -38,9 +39,11 @@ from synapse.http.server import (
respond_with_json_bytes,
)
from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,9 +51,18 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo
+if TYPE_CHECKING:
+ from lxml import etree
+
+ from synapse.rest.media.v1.media_repository import MediaRepository
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
+_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
+_xml_encoding_match = re.compile(
+ br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
+)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
@@ -119,7 +131,12 @@ class OEmbedError(Exception):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo, media_storage):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
super().__init__()
self.auth = hs.get_auth()
@@ -165,11 +182,13 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
@@ -286,24 +305,7 @@ class PreviewUrlResource(DirectServeJsonResource):
with open(media_info["filename"], "rb") as file:
body = file.read()
- encoding = None
-
- # Let's try and figure out if it has an encoding set in a meta tag.
- # Limit it to the first 1kb, since it ought to be in the meta tags
- # at the top.
- match = _charset_match.search(body[:1000])
-
- # If we find a match, it should take precedence over the
- # Content-Type header, so set it here.
- if match:
- encoding = match.group(1).decode("ascii")
-
- # If we don't find a match, we'll look at the HTTP Content-Type, and
- # if that doesn't exist, we'll fall back to UTF-8.
- if not encoding:
- content_match = _content_type_match.match(media_info["media_type"])
- encoding = content_match.group(1) if content_match else "utf-8"
-
+ encoding = get_html_media_encoding(body, media_info["media_type"])
og = decode_and_calc_og(body, media_info["uri"], encoding)
# pre-cache the image for posterity
@@ -372,7 +374,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Check whether the URL should be downloaded as oEmbed content instead.
- Params:
+ Args:
url: The URL to check.
Returns:
@@ -389,7 +391,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Request content from an oEmbed endpoint.
- Params:
+ Args:
endpoint: The oEmbed API endpoint.
url: The URL to pass to the API.
@@ -449,7 +451,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url: str, user):
+ async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -579,9 +581,8 @@ class PreviewUrlResource(DirectServeJsonResource):
"expire_url_cache_data", self._expire_url_cache_data
)
- async def _expire_url_cache_data(self):
- """Clean up expired url cache content, media and thumbnails.
- """
+ async def _expire_url_cache_data(self) -> None:
+ """Clean up expired url cache content, media and thumbnails."""
# TODO: Delete from backup media store
assert self._worker_run_media_background_jobs
@@ -675,24 +676,101 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
-def decode_and_calc_og(body, media_uri, request_encoding=None):
+def get_html_media_encoding(body: bytes, content_type: str) -> str:
+ """
+ Get the encoding of the body based on the (presumably) HTML body or media_type.
+
+ The precedence used for finding a character encoding is:
+
+ 1. meta tag with a charset declared.
+ 2. The XML document's character encoding attribute.
+ 3. The Content-Type header.
+ 4. Fallback to UTF-8.
+
+ Args:
+ body: The HTML document, as bytes.
+ content_type: The Content-Type header.
+
+ Returns:
+ The character encoding of the body, as a string.
+ """
+ # Limit searches to the first 1kb, since it ought to be at the top.
+ body_start = body[:1024]
+
+ # Let's try and figure out if it has an encoding set in a meta tag.
+ match = _charset_match.search(body_start)
+ if match:
+ return match.group(1).decode("ascii")
+
+ # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+    # If we didn't find a match, see if it is an XML document with an encoding.
+ match = _xml_encoding_match.match(body_start)
+ if match:
+ return match.group(1).decode("ascii")
+
+ # If we don't find a match, we'll look at the HTTP Content-Type, and
+ # if that doesn't exist, we'll fall back to UTF-8.
+ content_match = _content_type_match.match(content_type)
+ if content_match:
+ return content_match.group(1)
+
+ return "utf-8"
+
+
+def decode_and_calc_og(
+ body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to parse the HTML document into the OG response. If errors
+ occur during processing of the document, an empty response is returned.
+
+ Args:
+ body: The HTML document, as bytes.
+        media_uri: The URI used to download the body.
+ request_encoding: The character encoding of the body, as a string.
+
+ Returns:
+ The OG response as a dictionary.
+ """
+ # If there's no body, nothing useful is going to be found.
+ if not body:
+ return {}
+
from lxml import etree
+ # Create an HTML parser. If this fails, log and return no metadata.
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body, parser)
- og = _calc_og(tree, media_uri)
+ except LookupError:
+ # blindly consider the encoding as utf-8.
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+ except Exception as e:
+ logger.warning("Unable to create HTML parser: %s" % (e,))
+ return {}
+
+ def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ tree = etree.fromstring(body_attempt, parser)
+
+ # The data was successfully parsed, but no tree was found.
+ if tree is None:
+ return {}
+
+ return _calc_og(tree, media_uri)
+
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ try:
+ return _attempt_calc_og(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
- parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
- og = _calc_og(tree, media_uri)
+ return _attempt_calc_og(body.decode("utf-8", "ignore"))
- return og
-
-def _calc_og(tree, media_uri):
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
@@ -796,7 +874,9 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og["og:description"] = summarize_paragraphs(text_nodes)
- else:
+ elif og["og:description"]:
+ # This must be a non-empty string at this point.
+ assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
@@ -804,7 +884,9 @@ def _calc_og(tree, media_uri):
return og
-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+ tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
@@ -838,32 +920,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
)
-def _rebase_url(url, base):
- base = list(urlparse.urlparse(base))
- url = list(urlparse.urlparse(url))
- if not url[0]: # fix up schema
- url[0] = base[0] or "http"
- if not url[1]: # fix up hostname
- url[1] = base[1]
- if not url[2].startswith("/"):
- url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
- return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
+ base_parts = list(urlparse.urlparse(base))
+ url_parts = list(urlparse.urlparse(url))
+ if not url_parts[0]: # fix up schema
+ url_parts[0] = base_parts[0] or "http"
+ if not url_parts[1]: # fix up hostname
+ url_parts[1] = base_parts[1]
+ if not url_parts[2].startswith("/"):
+ url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+ return urlparse.urlunparse(url_parts)
-def _is_media(content_type):
- if content_type.lower().startswith("image/"):
- return True
+def _is_media(content_type: str) -> bool:
+ return content_type.lower().startswith("image/")
-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
content_type = content_type.lower()
- if content_type.startswith("text/html") or content_type.startswith(
+ return content_type.startswith("text/html") or content_type.startswith(
"application/xhtml"
- ):
- return True
+ )
-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+ text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
# Try to get a summary of between 200 and 500 words, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?
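
A quick standalone illustration of the encoding-detection precedence implemented by get_html_media_encoding earlier in this file; the regexes below mirror the ones defined above, and the sample documents are made up.

import re

_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
_xml_encoding_match = re.compile(br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I)

samples = [
    b'<html><head><meta charset="shift-jis"></head>',        # 1. meta charset wins
    b'<?xml version="1.0" encoding="ISO-8859-1"?><html/>',   # 2. then the XML prolog
    b'<html><body>no declaration at all</body></html>',      # 3./4. Content-Type, else utf-8
]
for body in samples:
    match = _charset_match.search(body[:1024]) or _xml_encoding_match.match(body[:1024])
    print(match.group(1).decode("ascii") if match else "fall back to Content-Type / utf-8")
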
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..031947557d 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,27 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
+import abc
import logging
import os
import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
- async def store_file(self, path: str, file_info: FileInfo):
+ @abc.abstractmethod
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
@@ -42,6 +47,7 @@ class StorageProvider:
file_info: The metadata of the file.
"""
+ @abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous
self.store_remote = store_remote
- def __str__(self):
+ def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,)
- async def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local:
return None
@@ -91,39 +97,34 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store():
try:
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(
+ self.backend.store_file(path, file_info)
+ )
except Exception:
logger.exception("Error storing file")
run_in_background(store)
- return None
- async def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.fetch(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(self.backend.fetch(path, file_info))
class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
- hs (HomeServer)
+ hs
config: The config returned by `parse_config`.
"""
- def __init__(self, hs, config):
+ def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
@@ -131,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
- async def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -141,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return await defer_to_thread(
+ await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
- async def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
+ return None
+
@staticmethod
- def parse_config(config):
+ def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.
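
To show what the abstract interface above expects from third-party backends, here is a hedged sketch of a provider that mirrors files into a flat archive directory. The class, its "directory" config key and the bare open() in fetch are illustrative; a real backend would return a Responder (as FileStorageProviderBackend does) and defer the blocking copy to a thread.

import os
import shutil

class ArchiveStorageProviderBackend:
    """Illustrative provider: copies every stored file into one flat directory."""

    def __init__(self, hs, config: str):
        self.cache_directory = hs.config.media_store_path
        self.archive_directory = config

    async def store_file(self, path: str, file_info) -> None:
        src = os.path.join(self.cache_directory, path)
        dst = os.path.join(self.archive_directory, path.replace(os.sep, "_"))
        os.makedirs(self.archive_directory, exist_ok=True)
        shutil.copyfile(src, dst)

    async def fetch(self, path: str, file_info):
        dst = os.path.join(self.archive_directory, path.replace(os.sep, "_"))
        if os.path.isfile(dst):
            return open(dst, "rb")  # a real provider would wrap this in a FileResponder
        return None

    @staticmethod
    def parse_config(config: dict) -> str:
        # Raises KeyError if the (illustrative) "directory" key is missing.
        return config["directory"]
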
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 30421b663a..af802bc0b1 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
respond_with_responder,
)
+if TYPE_CHECKING:
+ from synapse.rest.media.v1.media_repository import MediaRepository
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo, media_storage):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
super().__init__()
self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
- self, request, media_id, width, height, method, m_type
- ):
+ self,
+ request: Request,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ ) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@@ -86,41 +106,28 @@ class ThumbnailResource(DirectServeJsonResource):
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
- if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
- )
-
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
-
- responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
- else:
- logger.info("Couldn't find any generated thumbnails")
- respond_404(request)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ media_id,
+ url_cache=media_info["url_cache"],
+ server_name=None,
+ )
async def _select_or_generate_local_thumbnail(
self,
- request,
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- ):
+ request: Request,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ ) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@@ -178,14 +185,14 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
- request,
- server_name,
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- ):
+ request: Request,
+ server_name: str,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ ) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +246,15 @@ class ThumbnailResource(DirectServeJsonResource):
raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
- self, request, server_name, media_id, width, height, method, m_type
- ):
+ self,
+ request: Request,
+ server_name: str,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ ) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
@@ -249,97 +263,238 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ media_info["filesystem_id"],
+ url_cache=None,
+ server_name=server_name,
+ )
+ async def _select_and_respond_with_thumbnail(
+ self,
+ request: Request,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[Dict[str, Any]],
+ media_id: str,
+ file_id: str,
+ url_cache: Optional[str] = None,
+ server_name: Optional[str] = None,
+ ) -> None:
+ """
+ Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ request: The incoming request.
+            desired_width: The desired width; the returned thumbnail may be larger than this.
+            desired_height: The desired height; the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+ """
if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
+ file_info = self._select_thumbnail(
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ file_id,
+ url_cache,
+ server_name,
)
- file_info = FileInfo(
- server_name=server_name,
- file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
+ if not file_info:
+ logger.info("Couldn't find a thumbnail matching the desired inputs")
+ respond_404(request)
+ return
+
+ responder = await self.media_storage.fetch_media(file_info)
+ if responder:
+ await respond_with_responder(
+ request,
+ responder,
+ file_info.thumbnail_type,
+ file_info.thumbnail_length,
+ )
+ return
+
+ # If we can't find the thumbnail we regenerate it. This can happen
+ # if e.g. we've deleted the thumbnails but still have the original
+ # image somewhere.
+ #
+ # Since we have an entry for the thumbnail in the DB we a) know we
+            # have successfully generated the thumbnail in the past (so we
+ # don't need to worry about repeatedly failing to generate
+            # thumbnails), and b) have already calculated the appropriate
+ # width/height/method so we can just call the "generate exact"
+ # methods.
+
+ # First let's check that we do actually have the original image
+ # still. This will throw a 404 if we don't.
+ # TODO: We should refetch the thumbnails for remote media.
+ await self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(server_name, file_id, url_cache=url_cache)
)
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
+ if server_name:
+ await self.media_repo.generate_remote_exact_thumbnail(
+ server_name,
+ file_id=file_id,
+ media_id=media_id,
+ t_width=file_info.thumbnail_width,
+ t_height=file_info.thumbnail_height,
+ t_method=file_info.thumbnail_method,
+ t_type=file_info.thumbnail_type,
+ )
+ else:
+ await self.media_repo.generate_local_exact_thumbnail(
+ media_id=media_id,
+ t_width=file_info.thumbnail_width,
+ t_height=file_info.thumbnail_height,
+ t_method=file_info.thumbnail_method,
+ t_type=file_info.thumbnail_type,
+ url_cache=url_cache,
+ )
responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(
+ request,
+ responder,
+ file_info.thumbnail_type,
+ file_info.thumbnail_length,
+ )
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
def _select_thumbnail(
self,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- thumbnail_infos,
- ):
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str],
+ server_name: Optional[str],
+ ) -> Optional[FileInfo]:
+ """
+ Choose an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+            desired_width: The desired width; the returned thumbnail may be larger than this.
+            desired_height: The desired height; the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+
+ Returns:
+ The thumbnail which best matches the desired parameters.
+ """
+ desired_method = desired_method.lower()
+
+ # The chosen thumbnail.
+ thumbnail_info = None
+
d_w = desired_width
d_h = desired_height
- if desired_method.lower() == "crop":
+ if desired_method == "crop":
+ # Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = []
+ # Other thumbnails.
crop_info_list2 = []
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "crop":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
- if t_method == "crop":
- aspect_quality = abs(d_w * t_h - d_h * t_w)
- min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info["thumbnail_type"]
- length_quality = info["thumbnail_length"]
- if t_w >= d_w or t_h >= d_h:
- crop_info_list.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info["thumbnail_type"]
+ length_quality = info["thumbnail_length"]
+ if t_w >= d_w or t_h >= d_h:
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
- else:
- crop_info_list2.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ )
+ else:
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
+ )
if crop_info_list:
- return min(crop_info_list)[-1]
- else:
- return min(crop_info_list2)[-1]
- else:
+ thumbnail_info = min(crop_info_list)[-1]
+ elif crop_info_list2:
+ thumbnail_info = min(crop_info_list2)[-1]
+ elif desired_method == "scale":
+ # Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
+ # Other thumbnails.
info_list2 = []
+
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "scale":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
- if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+ if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
- elif t_method == "scale":
+ else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
if info_list:
- return min(info_list)[-1]
- else:
- return min(info_list2)[-1]
+ thumbnail_info = min(info_list)[-1]
+ elif info_list2:
+ thumbnail_info = min(info_list2)[-1]
+
+ if thumbnail_info:
+ return FileInfo(
+ file_id=file_id,
+ url_cache=url_cache,
+ server_name=server_name,
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ thumbnail_length=thumbnail_info["thumbnail_length"],
+ )
+
+ # No matching thumbnail was found.
+ return None
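
The ordering used by _select_thumbnail above relies on Python comparing tuples left to right, so min() prefers the closest aspect ratio, then candidates at least as large as requested, then the nearest size, then a matching content-type, and finally the smallest file. A toy example with made-up candidates:

desired_w, desired_h, desired_type = 320, 240, "image/png"

candidates = [
    # (width, height, content-type, byte length)
    (320, 240, "image/jpeg", 12_000),
    (320, 240, "image/png", 15_000),
    (640, 480, "image/png", 40_000),
]

def quality(c):
    t_w, t_h, t_type, t_len = c
    aspect_quality = abs(desired_w * t_h - desired_h * t_w)
    min_quality = 0 if desired_w <= t_w and desired_h <= t_h else 1
    size_quality = abs((desired_w - t_w) * (desired_h - t_h))
    type_quality = desired_type != t_type
    return (aspect_quality, min_quality, size_quality, type_quality, t_len)

print(min(candidates, key=quality))   # (320, 240, 'image/png', 15000)
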
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 32a8e4f960..988f52c78f 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
# limitations under the License.
import logging
from io import BytesIO
+from typing import Tuple
from PIL import Image
@@ -39,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
- def __init__(self, input_path):
+ def __init__(self, input_path: str):
try:
self.image = Image.open(input_path)
except OSError as e:
@@ -59,11 +61,11 @@ class Thumbnailer:
# A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e)
- def transpose(self):
+ def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag
Returns:
- Tuple[int, int]: (width, height) containing the new image size in pixels.
+ A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method)
@@ -73,7 +75,7 @@ class Thumbnailer:
self.image.info["exif"] = None
return self.image.size
- def aspect(self, max_width, max_height):
+ def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::
@@ -91,15 +93,20 @@ class Thumbnailer:
else:
return (max_height * self.width) // self.height, max_height
- def _resize(self, width, height):
+ def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
- # looks awful
- if self.image.mode in ["1", "P"]:
- self.image = self.image.convert("RGB")
+ # looks awful.
+ #
+ # If the image has transparency, use RGBA instead.
+ if self.image.mode in ["1", "L", "P"]:
+ mode = "RGB"
+ if self.image.info.get("transparency", None) is not None:
+ mode = "RGBA"
+ self.image = self.image.convert(mode)
return self.image.resize((width, height), Image.ANTIALIAS)
- def scale(self, width, height, output_type):
+ def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions.
Returns:
@@ -108,7 +115,7 @@ class Thumbnailer:
scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
- def crop(self, width, height, output_type):
+ def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@@ -136,7 +143,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)
- def _encode_image(self, output_image, output_type):
+ def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":
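The change to _resize converts bilevel ("1"), greyscale ("L") and palette ("P") images before scaling, keeping an alpha channel when the source declares transparency, because Pillow would otherwise resample them with nearest neighbour. A self-contained sketch of that convert-then-resize step, assuming Pillow is installed and input_path points at a local image; it is an illustration, not the Synapse Thumbnailer class:

    from io import BytesIO
    from PIL import Image

    def scale_preserving_transparency(input_path: str, width: int, height: int) -> BytesIO:
        image = Image.open(input_path)
        # Palette/greyscale/bilevel images scale badly unless converted first;
        # keep an alpha channel if the source declares transparency.
        if image.mode in ("1", "L", "P"):
            mode = "RGBA" if image.info.get("transparency") is not None else "RGB"
            image = image.convert(mode)
        # Image.ANTIALIAS is spelled Image.LANCZOS in newer Pillow releases.
        resized = image.resize((width, height), Image.ANTIALIAS)
        out = BytesIO()
        # Simplified output choice: JPEG cannot store alpha, so fall back to
        # PNG whenever transparency was kept.
        resized.save(out, "PNG" if resized.mode == "RGBA" else "JPEG")
        out.seek(0)
        return out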
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index d76f7389e1..0138b2e2d1 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,10 +15,19 @@
# limitations under the License.
import logging
+from typing import IO, TYPE_CHECKING
+
+from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+from synapse.rest.media.v1.media_storage import SpamMediaException
+
+if TYPE_CHECKING:
+ from synapse.rest.media.v1.media_repository import MediaRepository
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -25,7 +35,7 @@ logger = logging.getLogger(__name__)
class UploadResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
@@ -37,14 +47,14 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
- content_length = request.getHeader(b"Content-Length").decode("ascii")
+ content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
@@ -70,7 +80,9 @@ class UploadResource(DirectServeJsonResource):
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
- media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
+ content_type_headers = headers.getRawHeaders(b"Content-Type")
+ assert content_type_headers # for mypy
+ media_type = content_type_headers[0].decode("ascii")
else:
raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
@@ -78,9 +90,15 @@ class UploadResource(DirectServeJsonResource):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-disposition
- content_uri = await self.media_repo.create_content(
- media_type, upload_name, request.content, content_length, requester.user
- )
+ try:
+ content = request.content # type: IO # type: ignore
+ content_uri = await self.media_repo.create_content(
+ media_type, upload_name, content, content_length, requester.user
+ )
+ except SpamMediaException:
+ # For media uploads we want to respond with a 400 instead of the
+ # default 404, as the latter would just be confusing.
+ raise SynapseError(400, "Bad content")
logger.info("Uploaded content with URI %r", content_uri)
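The header checks above run in a fixed order: Content-Length first (now read as a str via request.getHeader), then the size limit, then Content-Type. A small standalone summary of that ordering; validate_upload_headers is a hypothetical helper and raises ValueError where the real resource raises SynapseError:

    from typing import List, Optional, Tuple

    def validate_upload_headers(
        content_length: Optional[str],
        content_type_headers: Optional[List[bytes]],
        max_upload_size: int,
    ) -> Tuple[str, int]:
        # A missing Content-Length is rejected before the body is considered.
        if content_length is None:
            raise ValueError("Request must specify a Content-Length")
        length = int(content_length)
        if length > max_upload_size:
            raise ValueError("Upload request body is too large")
        # Content-Type is mandatory too; use the first raw header value.
        if not content_type_headers:
            raise ValueError("Upload request missing 'Content-Type'")
        media_type = content_type_headers[0].decode("ascii")
        return media_type, length

    # e.g. validate_upload_headers("1024", [b"image/png"], 50 * 1024 * 1024)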
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index c0b733488b..9eeb970580 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from typing import TYPE_CHECKING, Mapping
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client.sso_register import SsoRegisterResource
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]:
+ """Builds a resource tree to include synapse-specific client resources
+
+ These are resources which should be loaded on all workers that expose a C-S API:
+ i.e. the main process, and any generic workers so configured.
+
+ Returns:
+ map from path to Resource.
+ """
+ resources = {
+ # SSO bits. These are always loaded, whether or not SSO login is actually
+ # enabled (they just won't work very well if it's not)
+ "/_synapse/client/pick_idp": PickIdpResource(hs),
+ "/_synapse/client/pick_username": pick_username_resource(hs),
+ "/_synapse/client/new_user_consent": NewUserConsentResource(hs),
+ "/_synapse/client/sso_register": SsoRegisterResource(hs),
+ }
+
+ # provider-specific SSO bits. Only load these if they are enabled, since they
+ # rely on optional dependencies.
+ if hs.config.oidc_enabled:
+ from synapse.rest.synapse.client.oidc import OIDCResource
+
+ resources["/_synapse/client/oidc"] = OIDCResource(hs)
+
+ if hs.config.saml2_enabled:
+ from synapse.rest.synapse.client.saml2 import SAML2Resource
+
+ res = SAML2Resource(hs)
+ resources["/_synapse/client/saml2"] = res
+
+ # This is also mounted under '/_matrix' for backwards-compatibility.
+ # To be removed in Synapse v1.32.0.
+ resources["/_matrix/saml2"] = res
+
+ return resources
+
+
+__all__ = ["build_synapse_client_resource_tree"]
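build_synapse_client_resource_tree returns a flat map from path to Resource; the caller still has to hang each entry off the site's resource tree. A hedged sketch of how such a map could be mounted with Twisted's putChild, creating bare Resource nodes for the intermediate path segments (mount_resource_tree is an illustrative name, not Synapse's actual mounting helper, which may differ):

    from typing import Mapping

    from twisted.web.resource import Resource

    def mount_resource_tree(root: Resource, tree: Mapping[str, Resource]) -> None:
        for path, leaf in tree.items():
            segments = [seg.encode("ascii") for seg in path.strip("/").split("/")]
            node = root
            # Reuse or create bare Resource nodes for the intermediate segments...
            for segment in segments[:-1]:
                child = node.children.get(segment)
                if child is None:
                    child = Resource()
                    node.putChild(segment, child)
                node = child
            # ...and attach the leaf resource under the final segment.
            node.putChild(segments[-1], leaf)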
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
new file mode 100644
index 0000000000..78ee0b5e88
--- /dev/null
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class NewUserConsentResource(DirectServeHtmlResource):
+ """A resource which collects consent to the server's terms from a new user
+
+ This resource gets mounted at /_synapse/client/new_user_consent, and is shown
+ when we are automatically creating a new user due to an SSO login.
+
+ It shows a template which prompts the user to go and read the Ts and Cs, and click
+ a clickybox if they have done so.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+ self._server_name = hs.hostname
+ self._consent_version = hs.config.consent.user_consent_version
+
+ def template_search_dirs():
+ if hs.config.sso.sso_template_dir:
+ yield hs.config.sso.sso_template_dir
+ yield hs.config.sso.default_template_dir
+
+ self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ session = self._sso_handler.get_mapping_session(session_id)
+ except SynapseError as e:
+ logger.warning("Error fetching session: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ user_id = UserID(session.chosen_localpart, self._server_name)
+ user_profile = {
+ "display_name": session.display_name,
+ }
+
+ template_params = {
+ "user_id": user_id.to_string(),
+ "user_profile": user_profile,
+ "consent_version": self._consent_version,
+ "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,),
+ }
+
+ template = self._jinja_env.get_template("sso_new_user_consent.html")
+ html = template.render(template_params)
+ respond_with_html(request, 200, html)
+
+ async def _async_render_POST(self, request: Request):
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ try:
+ accepted_version = parse_string(request, "accepted_version", required=True)
+ except SynapseError as e:
+ self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+ return
+
+ await self._sso_handler.handle_terms_accepted(
+ request, session_id, accepted_version
+ )
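This resource (and the username picker below) builds its Jinja environment from an ordered list of template directories, so a configured sso_template_dir shadows the bundled defaults. A minimal sketch of the same idea using plain jinja2; Synapse's build_jinja_env helper may install extra filters and globals beyond this:

    from typing import Optional

    import jinja2

    def make_template_env(custom_dir: Optional[str], default_dir: str) -> jinja2.Environment:
        search_dirs = []
        if custom_dir:
            search_dirs.append(custom_dir)
        search_dirs.append(default_dir)
        # FileSystemLoader tries each directory in order, so a template in
        # custom_dir shadows the bundled one of the same name.
        return jinja2.Environment(
            loader=jinja2.FileSystemLoader(search_dirs),
            autoescape=jinja2.select_autoescape(),
        )

    # env.get_template("sso_new_user_consent.html") then resolves against
    # custom_dir first and falls back to default_dir.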
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index d958dd65bb..64c0deb75d 100644
--- a/synapse/rest/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
from twisted.web.resource import Resource
-from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
logger = logging.getLogger(__name__)
@@ -25,3 +26,6 @@ class OIDCResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))
+
+
+__all__ = ["OIDCResource"]
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index f7a0bc4bdb..1af33f0a45 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -12,19 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
+from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class OIDCCallbackResource(DirectServeHtmlResource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
async def _async_render_GET(self, request):
await self._oidc_handler.handle_oidc_callback(request)
+
+ async def _async_render_POST(self, request):
+ # the auth response can be returned via an x-www-form-urlencoded form instead
+ # of GET params, as per
+ # https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
+ await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index 9e4fbc0cbd..d26ce46efc 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
new file mode 100644
index 0000000000..9550b82998
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ finish_request,
+ respond_with_html,
+)
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PickIdpResource(DirectServeHtmlResource):
+ """IdP picker resource.
+
+ This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page
+ which prompts the user to choose an Identity Provider from the list.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+ self._sso_login_idp_picker_template = (
+ hs.config.sso.sso_login_idp_picker_template
+ )
+ self._server_name = hs.hostname
+
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
+ client_redirect_url = parse_string(
+ request, "redirectUrl", required=True, encoding="utf-8"
+ )
+ idp = parse_string(request, "idp", required=False)
+
+ # if we need to pick an IdP, do so
+ if not idp:
+ return await self._serve_id_picker(request, client_redirect_url)
+
+ # otherwise, redirect to the IdP's redirect URI
+ providers = self._sso_handler.get_identity_providers()
+ auth_provider = providers.get(idp)
+ if not auth_provider:
+ logger.info("Unknown idp %r", idp)
+ self._sso_handler.render_error(
+ request, "unknown_idp", "Unknown identity provider ID"
+ )
+ return
+
+ sso_url = await auth_provider.handle_redirect_request(
+ request, client_redirect_url.encode("utf8")
+ )
+ logger.info("Redirecting to %s", sso_url)
+ request.redirect(sso_url)
+ finish_request(request)
+
+ async def _serve_id_picker(
+ self, request: SynapseRequest, client_redirect_url: str
+ ) -> None:
+ # otherwise, serve up the IdP picker
+ providers = self._sso_handler.get_identity_providers()
+ html = self._sso_login_idp_picker_template.render(
+ redirect_url=client_redirect_url,
+ server_name=self._server_name,
+ providers=providers.values(),
+ )
+ respond_with_html(request, 200, html)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
new file mode 100644
index 0000000000..d9ffe84489
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, List
+
+from twisted.web.resource import Resource
+from twisted.web.server import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ DirectServeJsonResource,
+ respond_with_html,
+)
+from synapse.http.servlet import parse_boolean, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+def pick_username_resource(hs: "HomeServer") -> Resource:
+ """Factory method to generate the username picker resource.
+
+ This resource gets mounted under /_synapse/client/pick_username and has two
+ children:
+
+ * "account_details": renders the form and handles the POSTed response
+ * "check": a JSON endpoint which checks if a userid is free.
+ """
+
+ res = Resource()
+ res.putChild(b"account_details", AccountDetailsResource(hs))
+ res.putChild(b"check", AvailabilityCheckResource(hs))
+
+ return res
+
+
+class AvailabilityCheckResource(DirectServeJsonResource):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_GET(self, request: Request):
+ localpart = parse_string(request, "username", required=True)
+
+ session_id = get_username_mapping_session_cookie_from_request(request)
+
+ is_available = await self._sso_handler.check_username_availability(
+ localpart, session_id
+ )
+ return 200, {"available": is_available}
+
+
+class AccountDetailsResource(DirectServeHtmlResource):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ def template_search_dirs():
+ if hs.config.sso.sso_template_dir:
+ yield hs.config.sso.sso_template_dir
+ yield hs.config.sso.default_template_dir
+
+ self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ session = self._sso_handler.get_mapping_session(session_id)
+ except SynapseError as e:
+ logger.warning("Error fetching session: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ idp_id = session.auth_provider_id
+ template_params = {
+ "idp": self._sso_handler.get_identity_providers()[idp_id],
+ "user_attributes": {
+ "display_name": session.display_name,
+ "emails": session.emails,
+ },
+ }
+
+ template = self._jinja_env.get_template("sso_auth_account_details.html")
+ html = template.render(template_params)
+ respond_with_html(request, 200, html)
+
+ async def _async_render_POST(self, request: SynapseRequest):
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ try:
+ localpart = parse_string(request, "username", required=True)
+ use_display_name = parse_boolean(request, "use_display_name", default=False)
+
+ try:
+ emails_to_use = [
+ val.decode("utf-8") for val in request.args.get(b"use_email", [])
+ ] # type: List[str]
+ except ValueError:
+ raise SynapseError(400, "Query parameter use_email must be utf-8")
+ except SynapseError as e:
+ logger.warning("[session %s] bad param: %s", session_id, e)
+ self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+ return
+
+ await self._sso_handler.handle_submit_username_request(
+ request, session_id, localpart, use_display_name, emails_to_use
+ )
diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 68da37ca6a..3e8235ee1e 100644
--- a/synapse/rest/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -12,12 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
from twisted.web.resource import Resource
-from synapse.rest.saml2.metadata_resource import SAML2MetadataResource
-from synapse.rest.saml2.response_resource import SAML2ResponseResource
+from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
+from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
logger = logging.getLogger(__name__)
@@ -27,3 +28,6 @@ class SAML2Resource(Resource):
Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs))
+
+
+__all__ = ["SAML2Resource"]
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index 1e8526e22e..1e8526e22e 100644
--- a/synapse/rest/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index f6668fb5e3..4dfadf1bfb 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
@@ -14,24 +14,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
from synapse.http.server import DirectServeHtmlResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._saml_handler = hs.get_saml_handler()
+ self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
- self._saml_handler._render_error(
+ self._sso_handler.render_error(
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
new file mode 100644
index 0000000000..f2acce2437
--- /dev/null
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SsoRegisterResource(DirectServeHtmlResource):
+ """A resource which completes SSO registration
+
+ This resource gets mounted at /_synapse/client/sso_register, and is shown
+ after we collect a username and/or consent for a new SSO user. It (finally) registers
+ the user, and confirms the redirect to the client.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+ await self._sso_handler.register_sso_user(request, session_id)
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
index 6f2a1931c5..11e7cb59da 100644
--- a/synapse/rulecheck/domain_rule_checker.py
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -77,8 +77,7 @@ class DomainRuleChecker(object):
)
def check_event_for_spam(self, event):
- """Implements synapse.events.SpamChecker.check_event_for_spam
- """
+ """Implements synapse.events.SpamChecker.check_event_for_spam"""
return False
def user_may_invite(
@@ -90,8 +89,7 @@ class DomainRuleChecker(object):
new_room,
published_room=False,
):
- """Implements synapse.events.SpamChecker.user_may_invite
- """
+ """Implements synapse.events.SpamChecker.user_may_invite"""
if self.can_only_invite_during_room_creation and not new_room:
return False
@@ -121,8 +119,7 @@ class DomainRuleChecker(object):
def user_may_create_room(
self, userid, invite_list, third_party_invite_list, cloning
):
- """Implements synapse.events.SpamChecker.user_may_create_room
- """
+ """Implements synapse.events.SpamChecker.user_may_create_room"""
if cloning:
return True
@@ -138,18 +135,15 @@ class DomainRuleChecker(object):
return True
def user_may_create_room_alias(self, userid, room_alias):
- """Implements synapse.events.SpamChecker.user_may_create_room_alias
- """
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias"""
return True
def user_may_publish_room(self, userid, room_id):
- """Implements synapse.events.SpamChecker.user_may_publish_room
- """
+ """Implements synapse.events.SpamChecker.user_may_publish_room"""
return True
def user_may_join_room(self, userid, room_id, is_invited):
- """Implements synapse.events.SpamChecker.user_may_join_room
- """
+ """Implements synapse.events.SpamChecker.user_may_join_room"""
if self.can_only_join_rooms_with_invite and not is_invited:
return False
@@ -157,8 +151,7 @@ class DomainRuleChecker(object):
@staticmethod
def parse_config(config):
- """Implements synapse.events.SpamChecker.parse_config
- """
+ """Implements synapse.events.SpamChecker.parse_config"""
if "default" in config:
return config
else:
diff --git a/synapse/secrets.py b/synapse/secrets.py
index fb6d90a3b7..7939db75e7 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -26,10 +26,10 @@ if sys.version_info[0:2] >= (3, 6):
import secrets
class Secrets:
- def token_bytes(self, nbytes=32):
+ def token_bytes(self, nbytes: int = 32) -> bytes:
return secrets.token_bytes(nbytes)
- def token_hex(self, nbytes=32):
+ def token_hex(self, nbytes: int = 32) -> str:
return secrets.token_hex(nbytes)
@@ -38,8 +38,8 @@ else:
import os
class Secrets:
- def token_bytes(self, nbytes=32):
+ def token_bytes(self, nbytes: int = 32) -> bytes:
return os.urandom(nbytes)
- def token_hex(self, nbytes=32):
+ def token_hex(self, nbytes: int = 32) -> str:
return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii")
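The change to synapse/secrets.py only adds annotations, but they pin down the invariant that both branches (the stdlib secrets module on Python 3.6+, os.urandom otherwise) return the same shapes. A quick standalone check of that invariant:

    import binascii
    import os
    import secrets

    # On Python >= 3.6 the stdlib secrets module is used directly:
    assert isinstance(secrets.token_bytes(32), bytes)
    assert len(secrets.token_hex(32)) == 64  # hex encoding doubles the length

    # The fallback path produces the same types and lengths from os.urandom:
    raw = os.urandom(32)
    assert isinstance(raw, bytes)
    assert len(binascii.hexlify(raw).decode("ascii")) == 64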
diff --git a/synapse/server.py b/synapse/server.py
index 78f665a51c..e85b9391fa 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -24,10 +24,20 @@
import abc
import functools
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
-import twisted.internet.base
import twisted.internet.tcp
+from twisted.internet import defer
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS
@@ -50,10 +60,11 @@ from synapse.federation.federation_server import (
FederationServer,
)
from synapse.federation.send_queue import FederationRemoteSendQueue
-from synapse.federation.sender import FederationSender
+from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
+from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler
from synapse.handlers.admin import AdminHandler
@@ -85,10 +96,11 @@ from synapse.handlers.room import (
RoomShutdownHandler,
)
from synapse.handlers.room_list import RoomListHandler
-from synapse.handlers.room_member import RoomMemberMasterHandler
+from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.space_summary import SpaceSummaryHandler
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
@@ -101,6 +113,7 @@ from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -117,7 +130,7 @@ from synapse.server_notices.worker_server_notices_sender import (
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources
-from synapse.types import DomainSpecificString
+from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -126,6 +139,8 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
+ from txredisapi import RedisProtocol
+
from synapse.handlers.oidc_handler import OidcHandler
from synapse.handlers.saml_handler import SamlHandler
@@ -233,7 +248,7 @@ class HomeServer(metaclass=abc.ABCMeta):
self.start_time = None # type: Optional[int]
self._instance_id = random_string(5)
- self._instance_name = config.worker_name or "master"
+ self._instance_name = config.worker.instance_name
self.version_string = version_string
@@ -276,16 +291,12 @@ class HomeServer(metaclass=abc.ABCMeta):
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
- def get_reactor(self) -> twisted.internet.base.ReactorBase:
+ def get_reactor(self) -> ISynapseReactor:
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
return self._reactor
- def get_ip_from_request(self, request) -> str:
- # X-Forwarded-For is handled by our custom request type.
- return request.getClientIP()
-
def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
return domain_specific_string.domain == self.hostname
@@ -341,11 +352,9 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
- return (
- InsecureInterceptableContextFactory()
- if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
- else RegularPolicyForHTTPS()
- )
+ if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+ return InsecureInterceptableContextFactory()
+ return RegularPolicyForHTTPS()
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:
@@ -365,10 +374,13 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
"""
An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
- based on the IP range blacklist.
+ based on the IP range blacklist/whitelist.
"""
return SimpleHttpClient(
- self, ip_blacklist=self.config.ip_range_blacklist, use_proxy=True,
+ self,
+ ip_whitelist=self.config.ip_range_whitelist,
+ ip_blacklist=self.config.ip_range_blacklist,
+ use_proxy=True,
)
@cache_in_self
@@ -390,7 +402,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return RoomShutdownHandler(self)
@cache_in_self
- def get_sendmail(self) -> sendmail:
+ def get_sendmail(self) -> Callable[..., defer.Deferred]:
return sendmail
@cache_in_self
@@ -406,10 +418,19 @@ class HomeServer(metaclass=abc.ABCMeta):
return PresenceHandler(self)
@cache_in_self
- def get_typing_handler(self):
+ def get_typing_writer_handler(self) -> TypingWriterHandler:
if self.config.worker.writers.typing == self.get_instance_name():
return TypingWriterHandler(self)
else:
+ raise Exception("Workers cannot write typing")
+
+ @cache_in_self
+ def get_typing_handler(self) -> FollowerTypingHandler:
+ if self.config.worker.writers.typing == self.get_instance_name():
+ # Use get_typing_writer_handler to ensure that we use the same
+ # cached version.
+ return self.get_typing_writer_handler()
+ else:
return FollowerTypingHandler(self)
@cache_in_self
@@ -496,7 +517,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return InitialSyncHandler(self)
@cache_in_self
- def get_profile_handler(self):
+ def get_profile_handler(self) -> ProfileHandler:
return ProfileHandler(self)
@cache_in_self
@@ -550,7 +571,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return TransportLayerClient(self)
@cache_in_self
- def get_federation_sender(self):
+ def get_federation_sender(self) -> AbstractFederationSender:
if self.should_send_federation():
return FederationSender(self)
elif not self.config.worker_app:
@@ -579,7 +600,9 @@ class HomeServer(metaclass=abc.ABCMeta):
return UserDirectoryHandler(self)
@cache_in_self
- def get_groups_local_handler(self):
+ def get_groups_local_handler(
+ self,
+ ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
if self.config.worker_app:
return GroupsLocalWorkerHandler(self)
else:
@@ -609,7 +632,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return StatsHandler(self)
@cache_in_self
- def get_spam_checker(self):
+ def get_spam_checker(self) -> SpamChecker:
return SpamChecker(self)
@cache_in_self
@@ -617,7 +640,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return ThirdPartyEventRules(self)
@cache_in_self
- def get_room_member_handler(self):
+ def get_room_member_handler(self) -> RoomMemberHandler:
if self.config.worker_app:
return RoomMemberWorkerHandler(self)
return RoomMemberMasterHandler(self)
@@ -627,13 +650,13 @@ class HomeServer(metaclass=abc.ABCMeta):
return FederationHandlerRegistry(self)
@cache_in_self
- def get_server_notices_manager(self):
+ def get_server_notices_manager(self) -> ServerNoticesManager:
if self.config.worker_app:
raise Exception("Workers cannot send server notices")
return ServerNoticesManager(self)
@cache_in_self
- def get_server_notices_sender(self):
+ def get_server_notices_sender(self) -> WorkerServerNoticesSender:
if self.config.worker_app:
return WorkerServerNoticesSender(self)
return ServerNoticesSender(self)
@@ -706,12 +729,41 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_module_api(self) -> ModuleApi:
return ModuleApi(self, self.get_auth_handler())
- async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
- return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
+ @cache_in_self
+ def get_account_data_handler(self) -> AccountDataHandler:
+ return AccountDataHandler(self)
+
+ @cache_in_self
+ def get_space_summary_handler(self) -> SpaceSummaryHandler:
+ return SpaceSummaryHandler(self)
+
+ @cache_in_self
+ def get_external_cache(self) -> ExternalCache:
+ return ExternalCache(self)
+
+ @cache_in_self
+ def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+ if not self.config.redis.redis_enabled:
+ return None
+
+ # We only want to import the redis module if we're using it, as we have
+ # `txredisapi` as an optional dependency.
+ from synapse.replication.tcp.redis import lazyConnection
+
+ logger.info(
+ "Connecting to redis (host=%r port=%r) for external cache",
+ self.config.redis_host,
+ self.config.redis_port,
+ )
+
+ return lazyConnection(
+ hs=self,
+ host=self.config.redis_host,
+ port=self.config.redis_port,
+ password=self.config.redis.redis_password,
+ reconnect=True,
+ )
def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
- return self.config.send_federation and (
- not self.config.worker_app
- or self.config.worker_app == "synapse.app.federation_sender"
- )
+ return self.config.send_federation
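Most of the getters touched above follow the same @cache_in_self pattern: a zero-argument method whose result is memoised on the HomeServer instance, which is why get_typing_handler can delegate to get_typing_writer_handler and receive the same cached object. A simplified stand-in for that decorator, for illustration only (the real decorator also guards against re-entrant construction):

    import functools
    from typing import Any, Callable, TypeVar, cast

    F = TypeVar("F", bound=Callable[..., Any])

    def memoise_on_self(builder: F) -> F:
        attr = "_cached_" + builder.__name__

        @functools.wraps(builder)
        def wrapped(self):
            existing = getattr(self, attr, None)
            if existing is None:
                existing = builder(self)
                setattr(self, attr, existing)
            return existing

        return cast(F, wrapped)

    class FakeHomeServer:
        @memoise_on_self
        def get_clock(self):
            return object()  # stands in for an expensive dependency

    hs = FakeHomeServer()
    assert hs.get_clock() is hs.get_clock()  # same object every call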
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 9137c4edb1..a9349bf9a1 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any, Set
from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError
from synapse.types import get_localpart_from_id
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -28,16 +31,11 @@ class ConsentServerNotices:
privacy policy consent, and sends one if we do.
"""
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastore()
- self._users_in_progress = set()
+ self._users_in_progress = set() # type: Set[str]
self._current_consent_version = hs.config.user_consent_version
self._server_notice_content = hs.config.user_consent_server_notice_content
@@ -73,6 +71,10 @@ class ConsentServerNotices:
try:
u = await self._store.get_user_by_id(user_id)
+ # The user doesn't exist.
+ if u is None:
+ return
+
if u["is_guest"] and not self._send_to_guests:
# don't send to guests
return
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 2258d306d9..a18a2e76c9 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Tuple
from synapse.api.constants import (
EventTypes,
@@ -24,24 +24,24 @@ from synapse.api.constants import (
from synapse.api.errors import AuthError, ResourceLimitError, SynapseError
from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ResourceLimitsServerNotices:
- """ Keeps track of whether the server has reached it's resource limit and
+ """Keeps track of whether the server has reached its resource limit and
ensures that the client is kept up to date.
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastore()
self._auth = hs.get_auth()
self._config = hs.config
self._resouce_limited = False
+ self._account_data_handler = hs.get_account_data_handler()
self._message_handler = hs.get_message_handler()
self._state = hs.get_state_handler()
@@ -177,7 +177,7 @@ class ResourceLimitsServerNotices:
# tag already present, nothing to do here
need_to_set_tag = False
if need_to_set_tag:
- max_id = await self._store.add_tag_to_room(
+ max_id = await self._account_data_handler.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 100dbd5e2c..144e1da78e 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -35,6 +35,7 @@ class ServerNoticesManager:
self._store = hs.get_datastore()
self._config = hs.config
+ self._account_data_handler = hs.get_account_data_handler()
self._room_creation_handler = hs.get_room_creation_handler()
self._room_member_handler = hs.get_room_member_handler()
self._event_creation_handler = hs.get_event_creation_handler()
@@ -57,7 +58,7 @@ class ServerNoticesManager:
user_id: str,
event_content: dict,
type: str = EventTypes.Message,
- state_key: Optional[bool] = None,
+ state_key: Optional[str] = None,
) -> EventBase:
"""Send a notice to the given user
@@ -163,7 +164,7 @@ class ServerNoticesManager:
)
room_id = info["room_id"]
- max_id = await self._store.add_tag_to_room(
+ max_id = await self._account_data_handler.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index 6870b67ca0..965c645889 100644
--- a/synapse/server_notices/server_notices_sender.py
+++ b/synapse/server_notices/server_notices_sender.py
@@ -12,25 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, Union
+from typing import TYPE_CHECKING, Iterable, Union
from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.server_notices.worker_server_notices_sender import (
+ WorkerServerNoticesSender,
+)
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
-class ServerNoticesSender:
+class ServerNoticesSender(WorkerServerNoticesSender):
"""A centralised place which sends server notices automatically when
Certain Events take place
"""
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self._server_notices = (
ConsentServerNotices(hs),
ResourceLimitsServerNotices(hs),
diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
index 9273e61895..c76bd57460 100644
--- a/synapse/server_notices/worker_server_notices_sender.py
+++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -12,16 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class WorkerServerNoticesSender:
"""Stub impl of ServerNoticesSender which does nothing"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
+ pass
async def on_user_syncing(self, user_id: str) -> None:
"""Called when the user performs a sync operation.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1fa3b280b4..c3d6e80c49 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
+ entry = None
else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
- # XXX: can we update the state cache entry for the new state group? or
- # could we set a flag on resolve_state_groups_for_events to tell it to
- # always make a state group?
+ # Assign the new state group to the cached state entry.
+ #
+ # Note that this can race in that we could generate multiple state
+ # groups for the same state entry, but that is just inefficient
+ # rather than dangerous.
+ if entry and entry.state_group is None:
+ entry.state_group = state_group_before_event
#
# now if it's not a state event, we're done
@@ -393,7 +398,7 @@ class StateHandler:
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Iterable[str]
) -> _StateCacheEntry:
- """ Given a list of event_ids this method fetches the state at each
+ """Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Args:
@@ -565,7 +570,9 @@ class StateResolutionHandler:
return cache
logger.info(
- "Resolving state for %s with groups %s", room_id, list(group_names),
+ "Resolving state for %s with groups %s",
+ room_id,
+ list(group_names),
)
state_groups_histogram.observe(len(state_groups_ids))
@@ -610,7 +617,7 @@ class StateResolutionHandler:
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
- used as a starting point fof finding the state we need; any missing
+ used as a starting point for finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_res_store.
@@ -651,11 +658,15 @@ class StateResolutionHandler:
return
self._report_biggest(
- lambda i: i.cpu_time, "CPU time", _biggest_room_by_cpu_counter,
+ lambda i: i.cpu_time,
+ "CPU time",
+ _biggest_room_by_cpu_counter,
)
self._report_biggest(
- lambda i: i.db_time, "DB time", _biggest_room_by_db_counter,
+ lambda i: i.db_time,
+ "DB time",
+ _biggest_room_by_db_counter,
)
self._state_res_metrics.clear()
@@ -783,7 +794,7 @@ class StateResolutionStore:
)
def get_auth_chain_difference(
- self, state_sets: List[Set[str]]
+ self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -796,4 +807,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs.
"""
- return self.store.get_auth_chain_difference(state_sets)
+ return self.store.get_auth_chain_difference(room_id, state_sets)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 85edae053d..ce255da6fd 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -95,7 +95,11 @@ async def resolve_events_with_store(
if event.room_id != room_id:
raise Exception(
"Attempting to state-resolve for room %s with event %s which is in %s"
- % (room_id, event.event_id, event.room_id,)
+ % (
+ room_id,
+ event.event_id,
+ event.room_id,
+ )
)
# get the ids of the auth events which allow us to authenticate the
@@ -119,7 +123,11 @@ async def resolve_events_with_store(
if event.room_id != room_id:
raise Exception(
"Attempting to state-resolve for room %s with event %s which is in %s"
- % (room_id, event.event_id, event.room_id,)
+ % (
+ room_id,
+ event.event_id,
+ event.room_id,
+ )
)
state_map.update(state_map_new)
@@ -243,7 +251,7 @@ def _resolve_with_state(
def _resolve_state_events(
conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
) -> StateMap[EventBase]:
- """ This is where we actually decide which of the conflicted state to
+ """This is where we actually decide which of the conflicted state to
use.
We resolve conflicts in the following order:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..e73a548ee4 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
+ auth_diff = await _get_auth_chain_difference(
+ room_id, state_sets, event_map, state_res_store
+ )
full_conflicted_set = set(
itertools.chain(
@@ -116,7 +118,11 @@ async def resolve_events_with_store(
if event.room_id != room_id:
raise Exception(
"Attempting to state-resolve for room %s with event %s which is in %s"
- % (room_id, event.event_id, event.room_id,)
+ % (
+ room_id,
+ event.event_id,
+ event.room_id,
+ )
)
full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
@@ -236,6 +242,7 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference(
+ room_id: str,
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
@@ -252,9 +259,90 @@ async def _get_auth_chain_difference(
Set of event IDs
"""
+ # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+ # all events passed to it (and their auth chains) have been persisted
+ # previously. This is not the case for any events in the `event_map`, and so
+ # we need to manually handle those events.
+ #
+ # We do this by:
+ # 1. calculating the auth chain difference for the state sets based on the
+ # events in `event_map` alone
+ # 2. replacing any events in the state_sets that are also in `event_map`
+ # with their auth events (recursively), and then calling
+ # `store.get_auth_chain_difference` as normal
+ # 3. adding the results of 1 and 2 together.
+
+ # Map from each event ID in `event_map` to the part of its auth chain we can
+ # follow locally: the event itself, its auth event IDs, and, recursively, the
+ # auth event IDs of any of those events that also appear in `event_map` (i.e.
+ # the intersection of the event's auth chain with `event_map`, *plus* their
+ # auth event IDs).
+ events_to_auth_chain = {} # type: Dict[str, Set[str]]
+ for event in event_map.values():
+ chain = {event.event_id}
+ events_to_auth_chain[event.event_id] = chain
+
+ to_search = [event]
+ while to_search:
+ for auth_id in to_search.pop().auth_event_ids():
+ chain.add(auth_id)
+ auth_event = event_map.get(auth_id)
+ if auth_event:
+ to_search.append(auth_event)
+
+ # We now a) calculate the auth chain difference for the unpersisted events
+ # and b) work out the state sets to pass to the store.
+ #
+ # Note: If the `event_map` is empty (which is the common case), we can do a
+ # much simpler calculation.
+ if event_map:
+ # The list of state sets to pass to the store, where each state set is a set
+ # of the event ids making up the state. This is similar to `state_sets`,
+ # except that (a) we only have event ids, not the complete
+ # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+ # unpersisted events and replaced them with the persisted events in
+ # their auth chain.
+ state_sets_ids = [] # type: List[Set[str]]
+
+ # For each state set, the unpersisted event IDs reachable (by their auth
+ # chain) from the events in that set.
+ unpersisted_set_ids = [] # type: List[Set[str]]
+
+ for state_set in state_sets:
+ set_ids = set() # type: Set[str]
+ state_sets_ids.append(set_ids)
+
+ unpersisted_ids = set() # type: Set[str]
+ unpersisted_set_ids.append(unpersisted_ids)
+
+ for event_id in state_set.values():
+ event_chain = events_to_auth_chain.get(event_id)
+ if event_chain is not None:
+ # We have an event in `event_map`. We add all the auth
+ # events that it references (that aren't also in `event_map`).
+ set_ids.update(e for e in event_chain if e not in event_map)
+
+ # We also add the full chain of unpersisted event IDs
+ # referenced by this state set, so that we can work out the
+ # auth chain difference of the unpersisted events.
+ unpersisted_ids.update(e for e in event_chain if e in event_map)
+ else:
+ set_ids.add(event_id)
+
+ # The auth chain difference of the unpersisted events of the state sets
+ # is calculated by taking the difference between the union and the
+ # intersection of those sets.
+ union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+ intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+ difference_from_event_map = union - intersection # type: Collection[str]
+ else:
+ difference_from_event_map = ()
+ state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
difference = await state_res_store.get_auth_chain_difference(
- [set(state_set.values()) for state_set in state_sets]
+ room_id, state_sets_ids
)
+ difference.update(difference_from_event_map)
return difference
@@ -574,7 +662,7 @@ async def _get_mainline_depth_for_event(
# We do an iterative search, replacing `event` with the power level in its
# auth events (if any)
while tmp_event:
- depth = mainline_map.get(event.event_id)
+ depth = mainline_map.get(tmp_event.event_id)
if depth is not None:
return depth
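The unpersisted half of the auth chain difference is, per the comment above, the union of the per-state-set unpersisted chains minus their intersection. A tiny worked example of that set algebra, with made-up event IDs:

    # Per-state-set sets of unpersisted event IDs reachable via auth chains
    # (illustrative IDs only).
    unpersisted_set_ids = [
        {"$a", "$b"},  # reachable from state set 1
        {"$a", "$c"},  # reachable from state set 2
        {"$a"},        # reachable from state set 3
    ]

    union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
    intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

    # Events in some-but-not-all auth chains form the unpersisted difference.
    assert union - intersection == {"$b", "$c"}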
diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css
index 83e4f6abc8..dd76714a92 100644
--- a/synapse/static/client/login/style.css
+++ b/synapse/static/client/login/style.css
@@ -31,6 +31,11 @@ form {
margin: 10px 0 0 0;
}
+ul.radiobuttons {
+ text-align: left;
+ list-style: none;
+}
+
/*
* Add some padding to the viewport.
*/
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index bbff3c8d5b..0b9007e51f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
+from typing import TYPE_CHECKING
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
@@ -34,14 +35,17 @@ from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
-__all__ = ["DataStores", "DataStore"]
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
class Storage:
- """The high level interfaces for talking to various storage layers.
- """
+ """The high level interfaces for talking to various storage layers."""
- def __init__(self, hs, stores: Databases):
+ def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2b196ded1b..240905329f 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,14 +17,18 @@
import logging
import random
from abc import ABCMeta
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
-from synapse.types import Collection, get_domain_from_id
+from synapse.storage.types import Connection
+from synapse.types import Collection, StreamToken, get_domain_from_id
from synapse.util import json_decoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
self.rand = random.SystemRandom()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: StreamToken,
+ rows: Iterable[Any],
+ ) -> None:
pass
- def _invalidate_state_caches(self, room_id, members_changed):
+ def _invalidate_state_caches(
+ self, room_id: str, members_changed: Iterable[str]
+ ) -> None:
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
Args:
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have
- changed
+ room_id: Room where state changed
+ members_changed: The user_ids of members that have changed
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
- ):
+ ) -> None:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
cache.invalidate(tuple(key))
-def db_to_json(db_content):
+def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
Take some data from a database row and return a JSON-decoded object.
Args:
- db_content (memoryview|buffer|bytes|bytearray|unicode)
+ db_content: The JSON-encoded contents from the database.
+
+ Returns:
+ The object decoded from JSON.
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
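A short sketch of the decode path the new docstring describes, using the standard library `json` in place of the module's `json_decoder` (same behaviour in spirit):

    import json

    def db_to_json_sketch(db_content):
        # psycopg2 on Python 3 can return a memoryview; json can't decode that
        # directly, so cast to bytes first. str/bytes/bytearray pass straight through.
        if isinstance(db_content, memoryview):
            db_content = db_content.tobytes()
        return json.loads(db_content)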
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721ebe9..ccb06aab39 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,29 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
from synapse.util import json_encoder
from . import engines
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.database import DatabasePool, LoggingTransaction
+
logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
- def __init__(self, name):
+ def __init__(self, name: str):
self.name = name
self.total_item_count = 0
- self.total_duration_ms = 0
- self.avg_item_count = 0
- self.avg_duration_ms = 0
+ self.total_duration_ms = 0.0
+ self.avg_item_count = 0.0
+ self.avg_duration_ms = 0.0
- def update(self, item_count, duration_ms):
+ def update(self, item_count: int, duration_ms: float) -> None:
"""Update the stats after doing an update"""
self.total_item_count += item_count
self.total_duration_ms += duration_ms
@@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
- def average_items_per_ms(self):
+ def average_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
@@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
# changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms)
- def total_items_per_ms(self):
+ def total_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
@@ -72,7 +77,7 @@ class BackgroundUpdatePerformance:
class BackgroundUpdater:
- """ Background updates are updates to the database that run in the
+ """Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
process and autotuning the batch size.
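The autotuning mentioned in the docstring is essentially "scale the next batch by the measured throughput". A hedged sketch using the moving averages tracked by `BackgroundUpdatePerformance` above (constants and clamping are illustrative):

    def next_batch_size(avg_item_count, avg_duration_ms, desired_duration_ms=100.0, minimum=1):
        # items-per-millisecond from the exponentially weighted averages, then
        # aim for a batch that takes roughly desired_duration_ms to process.
        if avg_duration_ms == 0:
            return minimum
        items_per_ms = avg_item_count / avg_duration_ms
        return max(int(desired_duration_ms * items_per_ms), minimum)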
@@ -83,21 +88,25 @@ class BackgroundUpdater:
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
- def __init__(self, hs, database):
+ def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]
- self._background_update_performance = {}
- self._background_update_handlers = {}
+ self._background_update_performance = (
+ {}
+ ) # type: Dict[str, BackgroundUpdatePerformance]
+ self._background_update_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
self._all_done = False
- def start_doing_background_updates(self):
+ def start_doing_background_updates(self) -> None:
run_as_background_process("background_updates", self.run_background_updates)
- async def run_background_updates(self, sleep=True):
+ async def run_background_updates(self, sleep: bool = True) -> None:
logger.info("Starting background schema updates")
while True:
if sleep:
@@ -148,9 +157,8 @@ class BackgroundUpdater:
return False
- async def has_completed_background_update(self, update_name) -> bool:
- """Check if the given background update has finished running.
- """
+ async def has_completed_background_update(self, update_name: str) -> bool:
+ """Check if the given background update has finished running."""
if self._all_done:
return True
@@ -173,8 +181,7 @@ class BackgroundUpdater:
Returns once some amount of work is done.
Args:
- desired_duration_ms(float): How long we want to spend
- updating.
+ desired_duration_ms: How long we want to spend updating.
Returns:
True if we have finished running all the background updates, otherwise False
"""
@@ -190,7 +197,8 @@ class BackgroundUpdater:
if not self._current_background_update:
all_pending_updates = await self.db_pool.runInteraction(
- "background_updates", get_background_updates_txn,
+ "background_updates",
+ get_background_updates_txn,
)
if not all_pending_updates:
# no work left to do
@@ -220,6 +228,7 @@ class BackgroundUpdater:
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
+ assert self._current_background_update is not None
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
@@ -273,7 +282,11 @@ class BackgroundUpdater:
return len(self._background_update_performance)
- def register_background_update_handler(self, update_name, update_handler):
+ def register_background_update_handler(
+ self,
+ update_name: str,
+ update_handler: Callable[[JsonDict, int], Awaitable[int]],
+ ):
"""Register a handler for doing a background update.
The handler should take two arguments:
@@ -287,12 +300,12 @@ class BackgroundUpdater:
The handler is responsible for updating the progress of the update.
Args:
- update_name(str): The name of the update that this code handles.
- update_handler(function): The function that does the update.
+ update_name: The name of the update that this code handles.
+ update_handler: The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
- def register_noop_background_update(self, update_name):
+ def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
This is useful when we previously did a background update, but no
@@ -302,10 +315,10 @@ class BackgroundUpdater:
also be called to clear the update.
Args:
- update_name (str): Name of update
+ update_name: Name of update
"""
- async def noop_update(progress, batch_size):
+ async def noop_update(progress: JsonDict, batch_size: int) -> int:
await self._end_background_update(update_name)
return 1
@@ -313,14 +326,14 @@ class BackgroundUpdater:
def register_background_index_update(
self,
- update_name,
- index_name,
- table,
- columns,
- where_clause=None,
- unique=False,
- psql_only=False,
- ):
+ update_name: str,
+ index_name: str,
+ table: str,
+ columns: Iterable[str],
+ where_clause: Optional[str] = None,
+ unique: bool = False,
+ psql_only: bool = False,
+ ) -> None:
"""Helper for store classes to do a background index addition
To use:
@@ -332,19 +345,19 @@ class BackgroundUpdater:
2. In the Store constructor, call this method
Args:
- update_name (str): update_name to register for
- index_name (str): name of index to add
- table (str): table to add index to
- columns (list[str]): columns/expressions to include in index
- unique (bool): true to make a UNIQUE index
+ update_name: update_name to register for
+ index_name: name of index to add
+ table: table to add index to
+ columns: columns/expressions to include in index
+ unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
"""
- def create_index_psql(conn):
+ def create_index_psql(conn: Connection) -> None:
conn.rollback()
# postgres insists on autocommit for the index
- conn.set_session(autocommit=True)
+ conn.set_session(autocommit=True) # type: ignore
try:
c = conn.cursor()
@@ -371,9 +384,9 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
finally:
- conn.set_session(autocommit=False)
+ conn.set_session(autocommit=False) # type: ignore
- def create_index_sqlite(conn):
+ def create_index_sqlite(conn: Connection) -> None:
# Sqlite doesn't support concurrent creation of indexes.
#
# We don't use partial indices on SQLite as it wasn't introduced
@@ -399,7 +412,7 @@ class BackgroundUpdater:
c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine):
- runner = create_index_psql
+ runner = create_index_psql # type: Optional[Callable[[Connection], None]]
elif psql_only:
runner = None
else:
@@ -433,7 +446,9 @@ class BackgroundUpdater:
"background_updates", keyvalues={"update_name": update_name}
)
- async def _background_update_progress(self, update_name: str, progress: dict):
+ async def _background_update_progress(
+ self, update_name: str, progress: dict
+ ) -> None:
"""Update the progress of a background update
Args:
@@ -441,20 +456,22 @@ class BackgroundUpdater:
progress: The progress of the update.
"""
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
progress,
)
- def _background_update_progress_txn(self, txn, update_name, progress):
+ def _background_update_progress_txn(
+ self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
+ ) -> None:
"""Update the progress of a background update
Args:
- txn(cursor): The transaction.
- update_name(str): The name of the background update task
- progress(dict): The progress of the update.
+ txn: The transaction.
+ update_name: The name of the background update task
+ progress: The progress of the update.
"""
progress_json = json_encoder.encode(progress)
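For the `create_index_psql` runner earlier in this file: PostgreSQL refuses to run `CREATE INDEX CONCURRENTLY` inside a transaction block, which is why the connection is rolled back and flipped to autocommit around the statement. A minimal sketch with a psycopg2 connection (names illustrative):

    def create_index_concurrently(conn, index_name, table, columns):
        # `conn` is assumed to be a psycopg2 connection.
        conn.rollback()                       # leave any implicit transaction
        conn.set_session(autocommit=True)     # CONCURRENTLY cannot run in a txn
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "CREATE INDEX CONCURRENTLY %s ON %s (%s)"
                    % (index_name, table, ", ".join(columns))
                )
        finally:
            conn.set_session(autocommit=False)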
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d1b5760c2c..94590e7b45 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -42,7 +42,6 @@ from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import (
LoggingContext,
- LoggingContextOrSentinel,
current_context,
make_deferred_yieldable,
)
@@ -85,8 +84,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
- """Get the connection pool for the database.
- """
+ """Get the connection pool for the database."""
# By default enable `cp_reconnect`. We need to fiddle with db_args in case
# someone has explicitly set `cp_reconnect`.
@@ -158,8 +156,8 @@ class LoggingDatabaseConnection:
def commit(self) -> None:
self.conn.commit()
- def rollback(self, *args, **kwargs) -> None:
- self.conn.rollback(*args, **kwargs)
+ def rollback(self) -> None:
+ self.conn.rollback()
def __enter__(self) -> "Connection":
self.conn.__enter__()
@@ -180,6 +178,9 @@ class LoggingDatabaseConnection:
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+R = TypeVar("R")
+
+
class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -241,12 +242,15 @@ class LoggingTransaction:
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
+ def fetchone(self) -> Optional[Tuple]:
+ return self.txn.fetchone()
+
+ def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+ return self.txn.fetchmany(size=size)
+
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()
- def fetchone(self) -> Tuple:
- return self.txn.fetchone()
-
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@@ -259,13 +263,32 @@ class LoggingTransaction:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+ """Similar to `executemany`, except `txn.rowcount` will not be correct
+ afterwards.
+
+ More efficient than `executemany` on PostgreSQL
+ """
+
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
- for val in args:
- self.execute(sql, val)
+ self.executemany(sql, args)
+
+ def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+ """Corresponds to psycopg2.extras.execute_values. Only available when
+ using postgres.
+
+ Always sets fetch=True when calling `execute_values`, so will return the
+ results.
+ """
+ assert isinstance(self.database_engine, PostgresEngine)
+ from psycopg2.extras import execute_values # type: ignore
+
+ return self._do_execute(
+ lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+ )
def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
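A rough illustration of the two helpers added above (`example_table` and its columns are made up): `execute_batch` takes the usual per-row parameter tuples, while `execute_values` expands a single `%s` placeholder into a multi-row VALUES list and, because the wrapper always passes `fetch=True`, returns the resulting rows.

    def store_entries(txn, entries):
        # Batched on PostgreSQL via psycopg2.extras.execute_batch; falls back
        # to executemany elsewhere. txn.rowcount is not reliable afterwards.
        txn.execute_batch(
            "INSERT INTO example_table (id, value) VALUES (?, ?)",
            [(e["id"], e["value"]) for e in entries],
        )

    def store_entries_pg(txn, entries):
        # PostgreSQL only: one statement for all rows, returning inserted ids.
        return txn.execute_values(
            "INSERT INTO example_table (id, value) VALUES %s RETURNING id",
            [(e["id"], e["value"]) for e in entries],
        )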
@@ -277,7 +300,7 @@ class LoggingTransaction:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
- def _do_execute(self, func, sql: str, *args: Any) -> None:
+ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -348,9 +371,6 @@ class PerformanceCounters:
return top_n_counters
-R = TypeVar("R")
-
-
class DatabasePool:
"""Wraps a single physical database and connection pool.
@@ -360,7 +380,10 @@ class DatabasePool:
_TXN_ID = 0
def __init__(
- self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ self,
+ hs,
+ database_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
):
self.hs = hs
self._clock = hs.get_clock()
@@ -400,8 +423,7 @@ class DatabasePool:
)
def is_running(self) -> bool:
- """Is the database pool currently running
- """
+ """Is the database pool currently running"""
return self._db_pool.running
async def _check_safe_to_upsert(self) -> None:
@@ -514,7 +536,11 @@ class DatabasePool:
# This can happen if the database disappears mid
# transaction.
transaction_logger.warning(
- "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name,
+ e,
+ i,
+ N,
)
if i < N:
i += 1
@@ -535,7 +561,9 @@ class DatabasePool:
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning(
- "[TXN EROLL] {%s} %s", name, e1,
+ "[TXN EROLL] {%s} %s",
+ name,
+ e1,
)
continue
raise
@@ -642,7 +670,7 @@ class DatabasePool:
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
- except: # noqa: E722, as we reraise the exception this is fine.
+ except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
@@ -671,12 +699,15 @@ class DatabasePool:
Returns:
The result of func
"""
- parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
- if not parent_context:
+ curr_context = current_context()
+ if not curr_context:
logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
)
parent_context = None
+ else:
+ assert isinstance(curr_context, LoggingContext)
+ parent_context = curr_context
start_time = monotonic_time()
@@ -722,6 +753,7 @@ class DatabasePool:
Returns:
A list of dicts where the key is the column header.
"""
+ assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor]
return results
@@ -861,7 +893,7 @@ class DatabasePool:
", ".join("?" for _ in keys[0]),
)
- txn.executemany(sql, vals)
+ txn.execute_batch(sql, vals)
async def simple_upsert(
self,
@@ -1370,7 +1402,10 @@ class DatabasePool:
@staticmethod
def simple_select_onecol_txn(
- txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
@@ -1680,7 +1715,11 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
- desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
+ desc,
+ self.simple_delete_one_txn,
+ table,
+ keyvalues,
+ db_autocommit=True,
)
@staticmethod
@@ -1867,6 +1906,7 @@ class DatabasePool:
retcols: Iterable[str],
filters: Optional[Dict[str, Any]] = None,
keyvalues: Optional[Dict[str, Any]] = None,
+ exclude_keyvalues: Optional[Dict[str, Any]] = None,
order_direction: str = "ASC",
) -> List[Dict[str, Any]]:
"""
@@ -1890,7 +1930,10 @@ class DatabasePool:
apply a WHERE ? LIKE ? clause.
keyvalues:
column names and values to select the rows with, or None to not
- apply a WHERE clause.
+ apply a WHERE key = value clause.
+ exclude_keyvalues:
+ column names and values to exclude rows with, or None to not
+ apply a WHERE key != value clause.
order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
@@ -1899,7 +1942,7 @@ class DatabasePool:
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
- where_clause = "WHERE " if filters or keyvalues else ""
+ where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
arg_list = [] # type: List[Any]
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
@@ -1908,6 +1951,9 @@ class DatabasePool:
if keyvalues:
where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
arg_list += list(keyvalues.values())
+ if exclude_keyvalues:
+ where_clause += " AND ".join("%s != ?" % (k,) for k in exclude_keyvalues)
+ arg_list += list(exclude_keyvalues.values())
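A hedged illustration of the clause just built for the new `exclude_keyvalues` argument (values made up); the SELECT below then appends the ORDER BY / LIMIT / OFFSET suffix:

    exclude_keyvalues = {"user_type": "support"}
    where_clause = "WHERE "
    arg_list = []
    if exclude_keyvalues:
        where_clause += " AND ".join("%s != ?" % (k,) for k in exclude_keyvalues)
        arg_list += list(exclude_keyvalues.values())
    # where_clause == "WHERE user_type != ?", arg_list == ["support"]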
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
", ".join(retcols),
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0c24325011..379c78bb83 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -56,7 +56,10 @@ class Databases:
database_config.databases,
)
prepare_database(
- db_conn, engine, hs.config, databases=database_config.databases,
+ db_conn,
+ engine,
+ hs.config,
+ databases=database_config.databases,
)
database = DatabasePool(hs, database_config, engine)
@@ -76,7 +79,7 @@ class Databases:
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
if hs.get_instance_name() in hs.config.worker.writers.events:
- persist_events = PersistEventsStore(hs, database, main)
+ persist_events = PersistEventsStore(hs, database, main, db_conn)
if "state" in database_config.databases:
logger.info(
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 43660ec4fb..1d44c3aa2c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
UIAuthStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
+ EventForwardExtremitiesStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
@@ -127,9 +129,6 @@ class DataStore(
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
- self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_inbox", "stream_id"
- )
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
)
@@ -149,9 +148,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
- )
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
@@ -166,9 +162,13 @@ class DataStore(
database,
stream_name="caches",
instance_name=hs.get_instance_name(),
- table="cache_invalidation_stream_by_instance",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[
+ (
+ "cache_invalidation_stream_by_instance",
+ "instance_name",
+ "stream_id",
+ )
+ ],
sequence_name="cache_invalidation_stream_seq",
writers=[],
)
@@ -192,36 +192,6 @@ class DataStore(
prefilled_cache=presence_cache_prefill,
)
- max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
- device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
- db_conn,
- "device_inbox",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=max_device_inbox_id,
- limit=1000,
- )
- self._device_inbox_stream_cache = StreamChangeCache(
- "DeviceInboxStreamChangeCache",
- min_device_inbox_id,
- prefilled_cache=device_inbox_prefill,
- )
- # The federation outbox and the local device inbox uses the same
- # stream_id generator.
- device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
- db_conn,
- "device_federation_outbox",
- entity_column="destination",
- stream_column="stream_id",
- max_value=max_device_inbox_id,
- limit=1000,
- )
- self._device_federation_outbox_stream_cache = StreamChangeCache(
- "DeviceFederationOutboxStreamChangeCache",
- min_device_outbox_id,
- prefilled_cache=device_outbox_prefill,
- )
-
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max
@@ -294,7 +264,7 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- async def get_users(self) -> List[Dict[str, Any]]:
+ async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table.
Returns:
@@ -322,7 +292,7 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
@@ -342,12 +312,13 @@ class DataStore(
filters = []
args = [self.hs.config.server_name]
+ # `name` is already stored in the database in lower case
if name:
- filters.append("(name LIKE ? OR displayname LIKE ?)")
- args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
+ args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
elif user_id:
filters.append("name LIKE ?")
- args.extend(["%" + user_id + "%"])
+ args.extend(["%" + user_id.lower() + "%"])
if not guests:
filters.append("is_guest = 0")
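Since `name` (the Matrix user ID) is stored lower-cased but `displayname` is not, the search above lower-cases the term and wraps only `displayname` in `LOWER()`. A small sketch of the resulting filter:

    def user_search_filter(term):
        clause = "(name LIKE ? OR LOWER(displayname) LIKE ?)"
        args = ["@%" + term.lower() + "%:%", "%" + term.lower() + "%"]
        return clause, args

    # user_search_filter("Bob")
    # -> ("(name LIKE ? OR LOWER(displayname) LIKE ?)", ["@%bob%:%", "%bob%"])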
@@ -369,7 +340,7 @@ class DataStore(
count = txn.fetchone()[0]
sql = (
- "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
+ sql_base
+ " ORDER BY u.name LIMIT ? OFFSET ?"
)
@@ -382,7 +353,7 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn
)
- async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
+ async def search_users(self, term: str) -> Optional[List[JsonDict]]:
"""Function to search users list for one or more users with
the matched term.
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 49ee23470d..a277a1ef13 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,30 +14,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import _CacheContext, cached
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
+class AccountDataWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
+ self._instance_name = hs.get_instance_name()
+
+ if isinstance(database.engine, PostgresEngine):
+ self._can_write_to_account_data = (
+ self._instance_name in hs.config.worker.writers.account_data
+ )
+
+ self._account_data_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="account_data",
+ instance_name=self._instance_name,
+ tables=[
+ ("room_account_data", "instance_name", "stream_id"),
+ ("room_tags_revisions", "instance_name", "stream_id"),
+ ("account_data", "instance_name", "stream_id"),
+ ],
+ sequence_name="account_data_sequence",
+ writers=hs.config.worker.writers.account_data,
+ )
+ else:
+ self._can_write_to_account_data = True
+
+ # We shouldn't be running in worker mode with SQLite, but it's useful
+ # to support it for unit tests.
+ #
+ # If this process is the writer then we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.account_data:
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn,
+ "room_account_data",
+ "stream_id",
+ extra_tables=[("room_tags_revisions", "stream_id")],
+ )
+ else:
+ self._account_data_id_gen = SlavedIdTracker(
+ db_conn,
+ "room_account_data",
+ "stream_id",
+ extra_tables=[("room_tags_revisions", "stream_id")],
+ )
+
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
super().__init__(database, db_conn, hs)
- @abc.abstractmethod
- def get_max_account_data_stream_id(self):
+ def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
Returns:
int
"""
- raise NotImplementedError()
+ return self._account_data_id_gen.get_current_token()
@cached()
async def get_account_data_for_user(
@@ -287,46 +331,46 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
- @cached(num_args=2, cache_context=True, max_entries=5000)
- async def is_ignored_by(
- self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
- ) -> bool:
- ignored_account_data = await self.get_global_account_data_by_type_for_user(
- AccountDataTypes.IGNORED_USER_LIST,
- ignorer_user_id,
- on_invalidate=cache_context.invalidate,
- )
- if not ignored_account_data:
- return False
-
- try:
- return ignored_user_id in ignored_account_data.get("ignored_users", {})
- except TypeError:
- # The type of the ignored_users field is invalid.
- return False
+ @cached(max_entries=5000, iterable=True)
+ async def ignored_by(self, user_id: str) -> Set[str]:
+ """
+ Get users which ignore the given user.
+ Params:
+ user_id: The user ID which might be ignored.
-class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- self._account_data_id_gen = StreamIdGenerator(
- db_conn,
- "account_data_max_stream_id",
- "stream_id",
- extra_tables=[
- ("room_account_data", "stream_id"),
- ("room_tags_revisions", "stream_id"),
- ],
+ Returns:
+ The user IDs which ignore the given user.
+ """
+ return set(
+ await self.db_pool.simple_select_onecol(
+ table="ignored_users",
+ keyvalues={"ignored_user_id": user_id},
+ retcol="ignorer_user_id",
+ desc="ignored_by",
+ )
)
- super().__init__(database, db_conn, hs)
-
- def get_max_account_data_stream_id(self) -> int:
- """Get the current max stream id for the private user data stream
-
- Returns:
- The maximum stream ID.
- """
- return self._account_data_id_gen.get_current_token()
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ elif stream_name == AccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ if not row.room_id:
+ self.get_global_account_data_by_type_for_user.invalidate(
+ (row.data_type, row.user_id)
+ )
+ self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+ self.get_account_data_for_room_and_type.invalidate(
+ (row.user_id, row.room_id, row.data_type)
+ )
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -342,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
+ assert self._can_write_to_account_data
+
content_json = json_encoder.encode(content)
async with self._account_data_id_gen.get_next() as next_id:
@@ -360,14 +406,6 @@ class AccountDataStore(AccountDataWorkerStore):
lock=False,
)
- # it's theoretically possible for the above to succeed and the
- # below to fail - in which case we might reuse a stream id on
- # restart, and the above update might not get propagated. That
- # doesn't sound any worse than the whole update getting lost,
- # which is what would happen if we combined the two into one
- # transaction.
- await self._update_max_stream_id(next_id)
-
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
@@ -390,32 +428,18 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
- content_json = json_encoder.encode(content)
+ assert self._can_write_to_account_data
async with self._account_data_id_gen.get_next() as next_id:
- # no need to lock here as account_data has a unique constraint on
- # (user_id, account_data_type) so simple_upsert will retry if
- # there is a conflict.
- await self.db_pool.simple_upsert(
- desc="add_user_account_data",
- table="account_data",
- keyvalues={"user_id": user_id, "account_data_type": account_data_type},
- values={"stream_id": next_id, "content": content_json},
- lock=False,
+ await self.db_pool.runInteraction(
+ "add_user_account_data",
+ self._add_account_data_for_user,
+ next_id,
+ user_id,
+ account_data_type,
+ content,
)
- # it's theoretically possible for the above to succeed and the
- # below to fail - in which case we might reuse a stream id on
- # restart, and the above update might not get propagated. That
- # doesn't sound any worse than the whole update getting lost,
- # which is what would happen if we combined the two into one
- # transaction.
- #
- # Note: This is only here for backwards compat to allow admins to
- # roll back to a previous Synapse version. Next time we update the
- # database version we can remove this table.
- await self._update_max_stream_id(next_id)
-
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
@@ -424,23 +448,71 @@ class AccountDataStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
- async def _update_max_stream_id(self, next_id: int) -> None:
- """Update the max stream_id
+ def _add_account_data_for_user(
+ self,
+ txn,
+ next_id: int,
+ user_id: str,
+ account_data_type: str,
+ content: JsonDict,
+ ) -> None:
+ content_json = json_encoder.encode(content)
- Args:
- next_id: The the revision to advance to.
- """
+ # no need to lock here as account_data has a unique constraint on
+ # (user_id, account_data_type) so simple_upsert will retry if
+ # there is a conflict.
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="account_data",
+ keyvalues={"user_id": user_id, "account_data_type": account_data_type},
+ values={"stream_id": next_id, "content": content_json},
+ lock=False,
+ )
- # Note: This is only here for backwards compat to allow admins to
- # roll back to a previous Synapse version. Next time we update the
- # database version we can remove this table.
+ # Ignored users get denormalized into a separate table as an optimisation.
+ if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
+ return
- def _update(txn):
- update_max_id_sql = (
- "UPDATE account_data_max_stream_id"
- " SET stream_id = ?"
- " WHERE stream_id < ?"
+ # Insert / delete to sync the list of ignored users.
+ previously_ignored_users = set(
+ self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ retcol="ignored_user_id",
)
- txn.execute(update_max_id_sql, (next_id, next_id))
+ )
+
+ # If the data is invalid, no one is ignored.
+ ignored_users_content = content.get("ignored_users", {})
+ if isinstance(ignored_users_content, dict):
+ currently_ignored_users = set(ignored_users_content)
+ else:
+ currently_ignored_users = set()
+
+ # Delete entries which are no longer ignored.
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="ignored_users",
+ column="ignored_user_id",
+ iterable=previously_ignored_users - currently_ignored_users,
+ keyvalues={"ignorer_user_id": user_id},
+ )
- await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+ # Add entries which are newly ignored.
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="ignored_users",
+ values=[
+ {"ignorer_user_id": user_id, "ignored_user_id": u}
+ for u in currently_ignored_users - previously_ignored_users
+ ],
+ )
+
+ # Invalidate the cache for any ignored users which were added or removed.
+ for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
+ self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ pass
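The `_add_account_data_for_user` hunk above keeps the denormalised `ignored_users` table in step with the account data using plain set arithmetic. A standalone sketch of that step:

    def diff_ignored_users(previously_ignored, currently_ignored):
        to_delete = previously_ignored - currently_ignored       # rows to remove
        to_insert = currently_ignored - previously_ignored       # rows to add
        to_invalidate = previously_ignored ^ currently_ignored   # ignored_by caches to drop
        return to_delete, to_insert, to_invalidate

    # diff_ignored_users({"@a:x", "@b:x"}, {"@b:x", "@c:x"})
    # -> ({"@a:x"}, {"@c:x"}, {"@a:x", "@c:x"})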
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e550cbc866..85bb853d33 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -32,7 +32,7 @@ from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -73,8 +73,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
- """Check if the user is one associated with an app service (exclusively)
- """
+ """Check if the user is one associated with an app service (exclusively)"""
if self.exclusive_user_regex:
return bool(self.exclusive_user_regex.match(user_id))
else:
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 339bd691a4..6d18e692b0 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,11 +14,12 @@
# limitations under the License.
import logging
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.types import UserID
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -279,8 +280,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _devices_last_seen_update(self, progress, batch_size):
- """Background update to insert last seen info into devices table
- """
+ """Background update to insert last seen info into devices table"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
@@ -362,8 +362,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
- """Removes entries in user IPs older than the configured period.
- """
+ """Removes entries in user IPs older than the configured period."""
if self.user_ips_max_age is None:
# Nothing to do
@@ -406,6 +405,34 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn
)
+ async def get_last_client_ip_by_device(
+ self, user_id: str, device_id: Optional[str]
+ ) -> Dict[Tuple[str, str], dict]:
+ """For each device_id listed, give the user_ip it was last seen on.
+
+ The result might be slightly out of date as client IPs are inserted in batches.
+
+ Args:
+ user_id: The user to fetch devices for.
+ device_id: If None fetches all devices for the user
+
+ Returns:
+ A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+ keys giving the column names from the devices table.
+ """
+
+ keyvalues = {"user_id": user_id}
+ if device_id is not None:
+ keyvalues["device_id"] = device_id
+
+ res = await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ )
+
+ return {(d["user_id"], d["device_id"]): d for d in res}
+
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -469,43 +496,35 @@ class ClientIpStore(ClientIpWorkerStore):
for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
- try:
- self.db_pool.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="user_ips",
+ keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
+ values={
+ "user_agent": user_agent,
+ "device_id": device_id,
+ "last_seen": last_seen,
+ },
+ lock=False,
+ )
+
+ # Technically an access token might not be associated with
+ # a device so we need to check.
+ if device_id:
+ # this is always an update rather than an upsert: the row should
+ # already exist, and if it doesn't, that may be because it has been
+ # deleted, and we don't want to re-create it.
+ self.db_pool.simple_update_txn(
txn,
- table="user_ips",
- keyvalues={
- "user_id": user_id,
- "access_token": access_token,
- "ip": ip,
- },
- values={
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ updatevalues={
"user_agent": user_agent,
- "device_id": device_id,
"last_seen": last_seen,
+ "ip": ip,
},
- lock=False,
)
- # Technically an access token might not be associated with
- # a device so we need to check.
- if device_id:
- # this is always an update rather than an upsert: the row should
- # already exist, and if it doesn't, that may be because it has been
- # deleted, and we don't want to re-create it.
- self.db_pool.simple_update_txn(
- txn,
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
- updatevalues={
- "user_agent": user_agent,
- "last_seen": last_seen,
- "ip": ip,
- },
- )
- except Exception as e:
- # Failed to upsert, log and continue
- logger.error("Failed to insert client IP %r: %r", entry, e)
-
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]:
@@ -519,18 +538,9 @@ class ClientIpStore(ClientIpWorkerStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
+ ret = await super().get_last_client_ip_by_device(user_id, device_id)
- keyvalues = {"user_id": user_id}
- if device_id is not None:
- keyvalues["device_id"] = device_id
-
- res = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues=keyvalues,
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
- )
-
- ret = {(d["user_id"], d["device_id"]): d for d in res}
+ # Update what is retrieved from the database with data which is pending insertion.
for key in self._batch_row_update:
uid, access_token, ip = key
if uid == user_id:
@@ -546,12 +556,18 @@ class ClientIpStore(ClientIpWorkerStore):
}
return ret
- async def get_user_ip_and_agents(self, user):
+ async def get_user_ip_and_agents(
+ self, user: UserID
+ ) -> List[Dict[str, Union[str, int]]]:
user_id = user.to_string()
results = {}
for key in self._batch_row_update:
- uid, access_token, ip, = key
+ (
+ uid,
+ access_token,
+ ip,
+ ) = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
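`get_last_client_ip_by_device` in the worker store now reads from the database, and the subclass above overlays whatever is still sitting in the in-memory write batch so callers see the freshest data. A simplified sketch of that merge (dict shapes reduced; the real method also honours the optional `device_id` filter):

    def overlay_pending(db_rows, batch_row_update, user_id):
        # db_rows: {(user_id, device_id): row_dict} as returned from the devices table.
        ret = dict(db_rows)
        for (uid, _access_token, ip), (user_agent, device_id, last_seen) in batch_row_update.items():
            if uid != user_id:
                continue
            ret[(uid, device_id)] = {
                "user_id": uid,
                "device_id": device_id,
                "ip": ip,
                "user_agent": user_agent,
                "last_seen": last_seen,
            }
        return ret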
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index d42faa3f1f..691080ce74 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,25 +14,108 @@
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.replication.tcp.streams import ToDeviceStream
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._instance_name = hs.get_instance_name()
+
+ # Map of (user_id, device_id) to the last stream_id that has been
+ # deleted up to. This is so that we can no op deletions.
+ self._last_device_delete_cache = ExpiringCache(
+ cache_name="last_device_delete_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ )
+
+ if isinstance(database.engine, PostgresEngine):
+ self._can_write_to_device = (
+ self._instance_name in hs.config.worker.writers.to_device
+ )
+
+ self._device_inbox_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="to_device",
+ instance_name=self._instance_name,
+ tables=[("device_inbox", "instance_name", "stream_id")],
+ sequence_name="device_inbox_sequence",
+ writers=hs.config.worker.writers.to_device,
+ )
+ else:
+ self._can_write_to_device = True
+ self._device_inbox_id_gen = StreamIdGenerator(
+ db_conn, "device_inbox", "stream_id"
+ )
+
+ max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+ device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_inbox",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=max_device_inbox_id,
+ limit=1000,
+ )
+ self._device_inbox_stream_cache = StreamChangeCache(
+ "DeviceInboxStreamChangeCache",
+ min_device_inbox_id,
+ prefilled_cache=device_inbox_prefill,
+ )
+
+ # The federation outbox and the local device inbox uses the same
+ # stream_id generator.
+ device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_federation_outbox",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=max_device_inbox_id,
+ limit=1000,
+ )
+ self._device_federation_outbox_stream_cache = StreamChangeCache(
+ "DeviceFederationOutboxStreamChangeCache",
+ min_device_outbox_id,
+ prefilled_cache=device_outbox_prefill,
+ )
+
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == ToDeviceStream.NAME:
+ self._device_inbox_id_gen.advance(instance_name, token)
+ for row in rows:
+ if row.entity.startswith("@"):
+ self._device_inbox_stream_cache.entity_has_changed(
+ row.entity, token
+ )
+ else:
+ self._device_federation_outbox_stream_cache.entity_has_changed(
+ row.entity, token
+ )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
+
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
async def get_new_messages_for_device(
self,
user_id: str,
- device_id: str,
+ device_id: Optional[str],
last_stream_id: int,
current_stream_id: int,
limit: int = 100,
@@ -80,7 +163,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace
async def delete_messages_for_device(
- self, user_id: str, device_id: str, up_to_stream_id: int
+ self, user_id: str, device_id: Optional[str], up_to_stream_id: int
) -> int:
"""
Args:
@@ -278,52 +361,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"get_all_new_device_messages", get_all_new_device_messages_txn
)
-
-class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
- DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self.db_pool.updates.register_background_index_update(
- "device_inbox_stream_index",
- index_name="device_inbox_stream_id_user_id",
- table="device_inbox",
- columns=["stream_id", "user_id"],
- )
-
- self.db_pool.updates.register_background_update_handler(
- self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
- )
-
- async def _background_drop_index_device_inbox(self, progress, batch_size):
- def reindex_txn(conn):
- txn = conn.cursor()
- txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
- txn.close()
-
- await self.db_pool.runWithConnection(reindex_txn)
-
- await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
- return 1
-
-
-class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
- DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- # Map of (user_id, device_id) to the last stream_id that has been
- # deleted up to. This is so that we can no op deletions.
- self._last_device_delete_cache = ExpiringCache(
- cache_name="last_device_delete_cache",
- clock=self._clock,
- max_len=10000,
- expiry_ms=30 * 60 * 1000,
- )
-
@trace
async def add_messages_to_device_inbox(
self,
@@ -342,6 +379,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
The new stream_id.
"""
+ assert self._can_write_to_device
+
def add_messages_txn(txn, now_ms, stream_id):
# Add the local messages directly to the local inbox.
self._add_messages_to_local_device_inbox_txn(
@@ -351,16 +390,20 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add the remote messages to the federation outbox.
# We'll send them to a remote server when we next send a
# federation transaction to that destination.
- sql = (
- "INSERT INTO device_federation_outbox"
- " (destination, stream_id, queued_ts, messages_json)"
- " VALUES (?,?,?,?)"
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_federation_outbox",
+ values=[
+ {
+ "destination": destination,
+ "stream_id": stream_id,
+ "queued_ts": now_ms,
+ "messages_json": json_encoder.encode(edu),
+ "instance_name": self._instance_name,
+ }
+ for destination, edu in remote_messages_by_destination.items()
+ ],
)
- rows = []
- for destination, edu in remote_messages_by_destination.items():
- edu_json = json_encoder.encode(edu)
- rows.append((destination, stream_id, now_ms, edu_json))
- txn.executemany(sql, rows)
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
@@ -379,6 +422,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
async def add_messages_from_remote_to_device_inbox(
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
) -> int:
+ assert self._can_write_to_device
+
def add_messages_txn(txn, now_ms, stream_id):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
@@ -405,7 +450,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
},
)
- # Add the messages to the approriate local device inboxes so that
+ # Add the messages to the appropriate local device inboxes so that
# they'll be sent to the devices when they next sync.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
@@ -427,38 +472,45 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
+ assert self._can_write_to_device
+
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
- sql = "SELECT device_id FROM devices WHERE user_id = ?"
- txn.execute(sql, (user_id,))
+ devices = self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id},
+ retcol="device_id",
+ )
+
message_json = json_encoder.encode(messages_by_device["*"])
- for row in txn:
+ for device_id in devices:
# Add the message for all devices for this user on this
# server.
- device = row[0]
- messages_json_for_user[device] = message_json
+ messages_json_for_user[device_id] = message_json
else:
if not devices:
continue
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "device_id", devices
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id},
+ column="device_id",
+ iterable=devices,
+ retcols=("device_id",),
)
- sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
- # TODO: Maybe this needs to be done in batches if there are
- # too many local devices for a given user.
- txn.execute(sql, [user_id] + list(args))
- for row in txn:
+ for row in rows:
# Only insert into the local inbox if the device exists on
# this server
- device = row[0]
- message_json = json_encoder.encode(messages_by_device[device])
- messages_json_for_user[device] = message_json
+ device_id = row["device_id"]
+ message_json = json_encoder.encode(messages_by_device[device_id])
+ messages_json_for_user[device_id] = message_json
if messages_json_for_user:
local_by_user_then_device[user_id] = messages_json_for_user
@@ -466,14 +518,52 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
if not local_by_user_then_device:
return
- sql = (
- "INSERT INTO device_inbox"
- " (user_id, device_id, stream_id, message_json)"
- " VALUES (?,?,?,?)"
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_inbox",
+ values=[
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "message_json": message_json,
+ "instance_name": self._instance_name,
+ }
+ for user_id, messages_by_device in local_by_user_then_device.items()
+ for device_id, message_json in messages_by_device.items()
+ ],
)
- rows = []
- for user_id, messages_by_device in local_by_user_then_device.items():
- for device_id, message_json in messages_by_device.items():
- rows.append((user_id, device_id, stream_id, message_json))
- txn.executemany(sql, rows)
+
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
+ DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ "device_inbox_stream_index",
+ index_name="device_inbox_stream_id_user_id",
+ table="device_inbox",
+ columns=["stream_id", "user_id"],
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
+ )
+
+ async def _background_drop_index_device_inbox(self, progress, batch_size):
+ def reindex_txn(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+ txn.close()
+
+ await self.db_pool.runWithConnection(reindex_txn)
+
+ await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+ return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+ pass
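The new `process_replication_rows` above routes each to-device row by its `entity`: local users (starting with `@`) feed the device inbox cache, anything else is treated as a remote destination for the federation outbox. A tiny sketch:

    def route_to_device_row(entity, token, inbox_cache, outbox_cache):
        if entity.startswith("@"):
            inbox_cache.entity_has_changed(entity, token)    # local device inbox
        else:
            outbox_cache.entity_has_changed(entity, token)   # federation outbox to a server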
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index dfb4f87b8f..d327e9aa0b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
+ async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+ """Retrieve the number of devices owned by the given users.
+ Only counts devices that are not marked as hidden.
+
+ Args:
+ user_ids: The IDs of the users who own the devices
+ Returns:
+ The number of devices owned by these users.
+ """
+
+ def count_devices_by_users_txn(txn, user_ids):
+ sql = """
+ SELECT count(*)
+ FROM devices
+ WHERE
+ hidden = '0' AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", user_ids
+ )
+
+ txn.execute(sql + clause, args)
+ return txn.fetchone()[0]
+
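+        # Short-circuit when no users were given, so we don't run a pointless query.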
+ if not user_ids:
+ return 0
+
+ return await self.db_pool.runInteraction(
+ "count_devices_by_users", count_devices_by_users_txn, user_ids
+ )
+
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -283,7 +315,8 @@ class DeviceWorkerStore(SQLBaseStore):
# make sure we go through the devices in stream order
device_ids = sorted(
- user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+ user_devices.keys(),
+ key=lambda i: query_map[(user_id, i)][0],
)
for device_id in device_ids:
@@ -334,8 +367,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def mark_as_sent_devices_by_remote(
self, destination: str, stream_id: int
) -> None:
- """Mark that updates have successfully been sent to the destination.
- """
+ """Mark that updates have successfully been sent to the destination."""
await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
@@ -649,7 +681,8 @@ class DeviceWorkerStore(SQLBaseStore):
return results
async def get_user_ids_requiring_device_list_resync(
- self, user_ids: Optional[Collection[str]] = None,
+ self,
+ user_ids: Optional[Collection[str]] = None,
) -> Set[str]:
"""Given a list of remote users return the list of users that we
should resync the device lists for. If None is given instead of a list,
@@ -689,8 +722,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
- """Mark that we no longer track device lists for remote user.
- """
+ """Mark that we no longer track device lists for remote user."""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
self.db_pool.simple_delete_txn(
@@ -865,12 +897,13 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
- txn.executemany(sql, ((row[0], row[1]) for row in rows))
+ txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", count)
await self.db_pool.runInteraction(
- "_prune_old_outbound_device_pokes", _prune_txn,
+ "_prune_old_outbound_device_pokes",
+ _prune_txn,
)
@@ -911,7 +944,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
# clear out duplicate device list outbound pokes
self.db_pool.updates.register_background_update_handler(
- BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
+ self._remove_duplicate_outbound_pokes,
)
# a pair of background updates that were added during the 1.14 release cycle,
@@ -972,17 +1006,23 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
row = None
for row in rows:
self.db_pool.simple_delete_txn(
- txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+ txn,
+ "device_lists_outbound_pokes",
+ {x: row[x] for x in KEY_COLS},
)
row["sent"] = False
self.db_pool.simple_insert_txn(
- txn, "device_lists_outbound_pokes", row,
+ txn,
+ "device_lists_outbound_pokes",
+ row,
)
if row:
self.db_pool.updates._background_update_progress_txn(
- txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+ txn,
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
+ {"last_row": row},
)
return len(rows)
@@ -1254,7 +1294,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# we've done a full resync, so we remove the entry that says we need
# to resync
self.db_pool.simple_delete_txn(
- txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
+ txn,
+ table="device_lists_remote_resync",
+ keyvalues={"user_id": user_id},
)
async def add_device_change_to_streams(
@@ -1304,14 +1346,16 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_ids: List[str],
):
txn.call_after(
- self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
+ self._device_list_stream_cache.entity_has_changed,
+ user_id,
+ stream_ids[-1],
)
min_stream_id = stream_ids[0]
# Delete older entries in the table, as we really only care about
# when the latest change happened.
- txn.executemany(
+ txn.execute_batch(
"""
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index e5060d4c46..267b948397 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -85,7 +85,7 @@ class DirectoryStore(DirectoryWorkerStore):
servers: Iterable[str],
creator: Optional[str] = None,
) -> None:
- """ Creates an association between a room alias and room_id/servers
+ """Creates an association between a room alias and room_id/servers
Args:
room_alias: The alias to create.
@@ -160,7 +160,10 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
async def update_aliases_for_room(
- self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+ self,
+ old_room_id: str,
+ new_room_id: str,
+ creator: Optional[str] = None,
) -> None:
"""Repoint all of the aliases for a given room, to a different room.
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4d1b92d1aa..f1e7859d26 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -360,7 +361,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
) -> Dict[str, int]:
- """ Count the number of one time keys the server has for a device
+ """Count the number of one time keys the server has for a device
Returns:
A mapping from algorithm to number of keys for that algorithm.
"""
@@ -493,7 +494,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
- self, txn: Connection, user_ids: List[str],
+ self,
+ txn: Connection,
+ user_ids: List[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -513,21 +516,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
- txn.database_engine, "k.user_id", user_chunk
- )
- sql = (
- """
- SELECT k.user_id, k.keytype, k.keydata, k.stream_id
- FROM e2e_cross_signing_keys k
- INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
- FROM e2e_cross_signing_keys
- GROUP BY user_id, keytype) s
- USING (user_id, stream_id, keytype)
- WHERE
- """
- + clause
+ txn.database_engine, "user_id", user_chunk
)
+ # Fetch the latest key for each type per user.
+ if isinstance(self.database_engine, PostgresEngine):
+ # The `DISTINCT ON` clause will pick the *first* row it
+ # encounters, so ordering by stream ID desc will ensure we get
+ # the latest key.
+ sql = """
+ SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
+ FROM e2e_cross_signing_keys
+ WHERE %(clause)s
+ ORDER BY user_id, keytype, stream_id DESC
+ """ % {
+ "clause": clause
+ }
+ else:
+ # SQLite has special handling for bare columns when using
+ # MIN/MAX with a `GROUP BY` clause where it picks the value from
+ # a row that matches the MIN/MAX.
+ sql = """
+ SELECT user_id, keytype, keydata, MAX(stream_id)
+ FROM e2e_cross_signing_keys
+ WHERE %(clause)s
+ GROUP BY user_id, keytype
+ """ % {
+ "clause": clause
+ }
+
txn.execute(sql, params)
rows = self.db_pool.cursor_to_dict(txn)
@@ -541,7 +558,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_cross_signing_signatures_txn(
- self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+ self,
+ txn: Connection,
+ keys: Dict[str, Dict[str, dict]],
+ from_user_id: str,
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing signatures made by a user on a set of keys.
@@ -619,7 +639,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, dict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -707,53 +727,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
"""Get the current stream id from the _device_list_id_gen"""
...
-
-class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
- async def set_e2e_device_keys(
- self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
- ) -> bool:
- """Stores device keys for a device. Returns whether there was a change
- or the keys were already in the database.
- """
-
- def _set_e2e_device_keys_txn(txn):
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("time_now", time_now)
- set_tag("device_keys", device_keys)
-
- old_key_json = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcol="key_json",
- allow_none=True,
- )
-
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
- if old_key_json == new_key_json:
- log_kv({"Message": "Device key already stored."})
- return False
-
- self.db_pool.simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- values={"ts_added_ms": time_now, "key_json": new_key_json},
- )
- log_kv({"message": "Device keys stored."})
- return True
-
- return await self.db_pool.runInteraction(
- "set_e2e_device_keys", _set_e2e_device_keys_txn
- )
-
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ ) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Take a list of one time keys out of the database.
Args:
@@ -840,6 +816,50 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+ async def set_e2e_device_keys(
+ self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+ ) -> bool:
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
+ """
+
+ def _set_e2e_device_keys_txn(txn):
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("time_now", time_now)
+ set_tag("device_keys", device_keys)
+
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="key_json",
+ allow_none=True,
+ )
+
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+ if old_key_json == new_key_json:
+ log_kv({"Message": "Device key already stored."})
+ return False
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"ts_added_ms": time_now, "key_json": new_key_json},
+ )
+ log_kv({"message": "Device keys stored."})
+ return True
+
+ return await self.db_pool.runInteraction(
+ "set_e2e_device_keys", _set_e2e_device_keys_txn
+ )
+
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 2e07c37340..a956be491a 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Cursor
from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
+class _NoChainCoverIndex(Exception):
+ def __init__(self, room_id: str):
+ super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
+
+
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -47,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain(
- self, event_ids: Collection[str], include_given: bool = False
+ self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
@@ -59,22 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
list of events
"""
event_ids = await self.get_auth_chain_ids(
- event_ids, include_given=include_given
+ room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
- self, event_ids: Collection[str], include_given: bool = False,
+ self,
+ room_id: str,
+ event_ids: Collection[str],
+ include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
Returns:
- An awaitable which resolve to a list of event_ids
+ list of event_ids
"""
+
+ # Check if we have indexed the room so we can use the chain cover
+ # algorithm.
+ room = await self.get_room(room_id)
+ if room["has_auth_chain_index"]:
+ try:
+ return await self.db_pool.runInteraction(
+ "get_auth_chain_ids_chains",
+ self._get_auth_chain_ids_using_cover_index_txn,
+ room_id,
+ event_ids,
+ include_given,
+ )
+ except _NoChainCoverIndex:
+ # For whatever reason we don't actually have a chain cover index
+ # for the events in question, so we fall back to the old method.
+ pass
+
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
@@ -82,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given,
)
+ def _get_auth_chain_ids_using_cover_index_txn(
+ self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
+ """Calculates the auth chain IDs using the chain index."""
+
+ # First we look up the chain ID/sequence numbers for the given events.
+
+ initial_events = set(event_ids)
+
+ # All the events that we've found that are reachable from the events.
+ seen_events = set() # type: Set[str]
+
+ # A map from chain ID to max sequence number of the given events.
+ event_chains = {} # type: Dict[int, int]
+
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for event_id, chain_id, sequence_number in txn:
+ seen_events.add(event_id)
+ event_chains[chain_id] = max(
+ sequence_number, event_chains.get(chain_id, 0)
+ )
+
+ # Check that we actually have a chain ID for all the events.
+ events_missing_chain_info = initial_events.difference(seen_events)
+ if events_missing_chain_info:
+ # This can happen due to e.g. downgrade/upgrade of the server. We
+ # raise an exception and fall back to the previous algorithm.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info,
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Now we look up all links for the chains we have, adding chains that
+ # are reachable from any event.
+ sql = """
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM event_auth_chain_links
+ WHERE %s
+ """
+
+ # A map from chain ID to max sequence number *reachable* from any event ID.
+ chains = {} # type: Dict[int, int]
+
+ # Add all linked chains reachable from initial set of chains.
+ for batch in batch_iter(event_chains, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+                # chains are only reachable if the origin sequence number of
+                # the link is less than or equal to the max sequence number in
+                # the origin chain.
+ if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
+ chains[target_chain_id] = max(
+ target_sequence_number,
+ chains.get(target_chain_id, 0),
+ )
+
+ # Add the initial set of chains, excluding the sequence corresponding to
+ # initial event.
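+        # (An event is not part of its own auth chain, hence the `- 1` below.)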
+ for chain_id, seq_no in event_chains.items():
+ chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
+
+        # Now for each chain we figure out the maximum sequence number reachable
+        # from *any* event ID. Events with a sequence number less than or equal
+        # to that are in the auth chain.
+ if include_given:
+ results = initial_events
+ else:
+ results = set()
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND sequence_number <= max_seq
+ """
+
+ rows = txn.execute_values(sql, chains.items())
+ results.update(r for r, in rows)
+ else:
+ # For SQLite we just fall back to doing a noddy for loop.
+ sql = """
+ SELECT event_id FROM event_auth_chains
+ WHERE chain_id = ? AND sequence_number <= ?
+ """
+ for chain_id, max_no in chains.items():
+ txn.execute(sql, (chain_id, max_no))
+ results.update(r for r, in txn)
+
+ return list(results)
+
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
+ """Calculates the auth chain IDs.
+
+ This is used when we don't have a cover index for the room.
+ """
if include_given:
results = set(event_ids)
else:
@@ -137,7 +288,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
- async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
+ async def get_auth_chain_difference(
+ self, room_id: str, state_sets: List[Set[str]]
+ ) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -149,15 +302,194 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
The set of the difference in auth chains.
"""
+ # Check if we have indexed the room so we can use the chain cover
+ # algorithm.
+ room = await self.get_room(room_id)
+ if room["has_auth_chain_index"]:
+ try:
+ return await self.db_pool.runInteraction(
+ "get_auth_chain_difference_chains",
+ self._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ state_sets,
+ )
+ except _NoChainCoverIndex:
+ # For whatever reason we don't actually have a chain cover index
+ # for the events in question, so we fall back to the old method.
+ pass
+
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
+ def _get_auth_chain_difference_using_cover_index_txn(
+ self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+ ) -> Set[str]:
+ """Calculates the auth chain difference using the chain index.
+
+ See docs/auth_chain_difference_algorithm.md for details
+ """
+
+ # First we look up the chain ID/sequence numbers for all the events, and
+ # work out the chain/sequence numbers reachable from each state set.
+
+ initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+ # Map from event_id -> (chain ID, seq no)
+ chain_info = {} # type: Dict[str, Tuple[int, int]]
+
+ # Map from chain ID -> seq no -> event Id
+ chain_to_event = {} # type: Dict[int, Dict[int, str]]
+
+ # All the chains that we've found that are reachable from the state
+ # sets.
+ seen_chains = set() # type: Set[int]
+
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for event_id, chain_id, sequence_number in txn:
+ chain_info[event_id] = (chain_id, sequence_number)
+ seen_chains.add(chain_id)
+ chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+ # Check that we actually have a chain ID for all the events.
+ events_missing_chain_info = initial_events.difference(chain_info)
+ if events_missing_chain_info:
+ # This can happen due to e.g. downgrade/upgrade of the server. We
+ # raise an exception and fall back to the previous algorithm.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info,
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Corresponds to `state_sets`, except as a map from chain ID to max
+ # sequence number reachable from the state set.
+ set_to_chain = [] # type: List[Dict[int, int]]
+ for state_set in state_sets:
+ chains = {} # type: Dict[int, int]
+ set_to_chain.append(chains)
+
+ for event_id in state_set:
+ chain_id, seq_no = chain_info[event_id]
+
+ chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
+
+ # Now we look up all links for the chains we have, adding chains to
+ # set_to_chain that are reachable from each set.
+ sql = """
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM event_auth_chain_links
+ WHERE %s
+ """
+
+ # (We need to take a copy of `seen_chains` as we want to mutate it in
+ # the loop)
+ for batch in batch_iter(set(seen_chains), 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ for chains in set_to_chain:
+                    # chains are only reachable if the origin sequence number of
+                    # the link is less than or equal to the max sequence number
+                    # in the origin chain.
+ if origin_sequence_number <= chains.get(origin_chain_id, 0):
+ chains[target_chain_id] = max(
+ target_sequence_number,
+ chains.get(target_chain_id, 0),
+ )
+
+ seen_chains.add(target_chain_id)
+
+ # Now for each chain we figure out the maximum sequence number reachable
+ # from *any* state set and the minimum sequence number reachable from
+ # *all* state sets. Events in that range are in the auth chain
+ # difference.
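+        # For example, if chain C is reachable up to sequence number 5 from one
+        # state set but only up to sequence number 2 from the others, then the
+        # events (C, 3), (C, 4) and (C, 5) are part of the difference.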
+ result = set()
+
+ # Mapping from chain ID to the range of sequence numbers that should be
+ # pulled from the database.
+ chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
+
+ for chain_id in seen_chains:
+ min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
+ max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
+
+ if min_seq_no < max_seq_no:
+ # We have a non empty gap, try and fill it from the events that
+ # we have, otherwise add them to the list of gaps to pull out
+ # from the DB.
+ for seq_no in range(min_seq_no + 1, max_seq_no + 1):
+ event_id = chain_to_event.get(chain_id, {}).get(seq_no)
+ if event_id:
+ result.add(event_id)
+ else:
+ chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
+ break
+
+ if not chain_to_gap:
+ # If there are no gaps to fetch, we're done!
+ return result
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND min_seq < sequence_number AND sequence_number <= max_seq
+ """
+
+ args = [
+ (chain_id, min_no, max_no)
+ for chain_id, (min_no, max_no) in chain_to_gap.items()
+ ]
+
+ rows = txn.execute_values(sql, args)
+ result.update(r for r, in rows)
+ else:
+ # For SQLite we just fall back to doing a noddy for loop.
+ sql = """
+ SELECT event_id FROM event_auth_chains
+ WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
+ """
+ for chain_id, (min_no, max_no) in chain_to_gap.items():
+ txn.execute(sql, (chain_id, min_no, max_no))
+ result.update(r for r, in txn)
+
+ return result
+
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
+ """Calculates the auth chain difference using a breadth first search.
+
+ This is used when we don't have a cover index for the room.
+ """
# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~
@@ -184,7 +516,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# and state sets {A} and {B} then walking the auth chains of A and B
# would immediately show that C is reachable by both. However, if we
# stopped at C then we'd only reach E via the auth chain of B and so E
- # would errornously get included in the returned difference.
+ # would erroneously get included in the returned difference.
#
# The other thing that we do is limit the number of auth chains we walk
# at once, due to practical limits (i.e. we can only query the database
@@ -310,7 +642,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
a_ids = new_aids
- # Mark that the auth event is reachable by the approriate sets.
+ # Mark that the auth event is reachable by the appropriate sets.
sets.intersection_update(event_to_missing_sets[event_id])
search.sort()
@@ -445,8 +777,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
async def get_min_depth(self, room_id: str) -> int:
- """For the given room, get the minimum depth we have seen for it.
- """
+ """For the given room, get the minimum depth we have seen for it."""
return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
@@ -462,7 +793,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None
- async def get_forward_extremeties_for_room(
+ async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
@@ -671,12 +1002,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
await self.db_pool.runInteraction(
- "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+ "_delete_old_forward_extrem_cache",
+ _delete_old_forward_extrem_cache_txn,
)
class EventFederationStore(EventFederationWorkerStore):
- """ Responsible for storing and serving up the various graphs associated
+ """Responsible for storing and serving up the various graphs associated
with an event. Including the main event graph and the auth chains for an
event.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 2e56dfaf31..78245ad5bd 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -54,8 +54,7 @@ def _serialize_action(actions, is_highlight):
def _deserialize_action(actions, is_highlight):
- """Custom deserializer for actions. This allows us to "compress" common actions
- """
+ """Custom deserializer for actions. This allows us to "compress" common actions"""
if actions:
return db_to_json(actions)
@@ -91,7 +90,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
- self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+ self,
+ room_id: str,
+ user_id: str,
+ last_read_event_id: Optional[str],
) -> Dict[str, int]:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
@@ -120,13 +122,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _get_unread_counts_by_receipt_txn(
- self, txn, room_id, user_id, last_read_event_id,
+ self,
+ txn,
+ room_id,
+ user_id,
+ last_read_event_id,
):
stream_ordering = None
if last_read_event_id is not None:
stream_ordering = self.get_stream_id_for_event_txn(
- txn, last_read_event_id, allow_none=True,
+ txn,
+ last_read_event_id,
+ allow_none=True,
)
if stream_ordering is None:
@@ -487,7 +495,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?, ?)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
_gen_entry(user_id, actions)
@@ -803,7 +811,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
],
)
- txn.executemany(
+ txn.execute_batch(
"""
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
@@ -835,6 +843,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
(rotate_to_stream_ordering,),
)
+ def _remove_old_push_actions_before_txn(
+ self, txn, room_id, user_id, stream_ordering
+ ):
+ """
+ Purges old push actions for a user and room before a given
+ stream_ordering.
+
+        However, we keep a month's worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transaction
+ room_id: Room ID to delete from
+ user_id: user ID to delete for
+ stream_ordering: The lowest stream ordering which will
+ not be deleted.
+ """
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id, user_id),
+ )
+
+ # We need to join on the events table to get the received_ts for
+ # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id,
+        # so we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_ordering before that.
+ txn.execute(
+ "DELETE FROM event_push_actions "
+ " WHERE user_id = ? AND room_id = ? AND "
+ " stream_ordering <= ?"
+ " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+ )
+
+ txn.execute(
+ """
+ DELETE FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+ """,
+ (room_id, user_id, stream_ordering),
+ )
+
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -894,62 +948,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
- async def get_latest_push_action_stream_ordering(self):
- def f(txn):
- txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
- return txn.fetchone()
-
- result = await self.db_pool.runInteraction(
- "get_latest_push_action_stream_ordering", f
- )
- return result[0] or 0
-
- def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
- """
- Purges old push actions for a user and room before a given
- stream_ordering.
-
- We however keep a months worth of highlighted notifications, so that
- users can still get a list of recent highlights.
-
- Args:
- txn: The transcation
- room_id: Room ID to delete from
- user_id: user ID to delete for
- stream_ordering: The lowest stream ordering which will
- not be deleted.
- """
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id, user_id),
- )
-
- # We need to join on the events table to get the received_ts for
- # event_push_actions and sqlite won't let us use a join in a delete so
- # we can't just delete where received_ts < x. Furthermore we can
- # only identify event_push_actions by a tuple of room_id, event_id
- # we we can't use a subquery.
- # Instead, we look up the stream ordering for the last event in that
- # room received before the threshold time and delete event_push_actions
- # in the room with a stream_odering before that.
- txn.execute(
- "DELETE FROM event_push_actions "
- " WHERE user_id = ? AND room_id = ? AND "
- " stream_ordering <= ?"
- " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
- )
-
- txn.execute(
- """
- DELETE FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
- """,
- (room_id, user_id, stream_ordering),
- )
-
def _action_has_highlight(actions):
for action in actions:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 90fb1a1f00..98dac19a95 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,17 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
import attr
from prometheus_client import Counter
@@ -32,10 +42,12 @@ from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
-from synapse.util.iterutils import batch_iter
+from synapse.util.iterutils import batch_iter, sorted_topologically
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -80,7 +92,11 @@ class PersistEventsStore:
"""
def __init__(
- self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
+ self,
+ hs: "HomeServer",
+ db: DatabasePool,
+ main_data_store: "DataStore",
+ db_conn: Connection,
):
self.hs = hs
self.db_pool = db
@@ -366,6 +382,38 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
+ self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+
+ # _store_rejected_events_txn filters out any events which were
+ # rejected, and returns the filtered list.
+ events_and_contexts = self._store_rejected_events_txn(
+ txn, events_and_contexts=events_and_contexts
+ )
+
+ # From this point onwards the events are only ones that weren't
+ # rejected.
+
+ self._update_metadata_tables_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
+ backfilled=backfilled,
+ )
+
+ # We call this last as it assumes we've inserted the events into
+ # room_memberships, where applicable.
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+ def _persist_event_auth_chain_txn(
+ self,
+ txn: LoggingTransaction,
+ events: List[EventBase],
+ ) -> None:
+
+        # We only care about state events, so return early if there are no state events.
+ if not any(e.is_state() for e in events):
+ return
+
# We want to store event_auth mappings for rejected events, as they're
# used in state res v2.
# This is only necessary if the rejected event appears in an accepted
@@ -381,39 +429,488 @@ class PersistEventsStore:
"room_id": event.room_id,
"auth_id": auth_id,
}
- for event, _ in events_and_contexts
+ for event in events
for auth_id in event.auth_event_ids()
if event.is_state()
],
)
- # _store_rejected_events_txn filters out any events which were
- # rejected, and returns the filtered list.
- events_and_contexts = self._store_rejected_events_txn(
- txn, events_and_contexts=events_and_contexts
+ # We now calculate chain ID/sequence numbers for any state events we're
+ # persisting. We ignore out of band memberships as we're not in the room
+ # and won't have their auth chain (we'll fix it up later if we join the
+ # room).
+ #
+ # See: docs/auth_chain_difference_algorithm.md
+
+ # We ignore legacy rooms that we aren't filling the chain cover index
+ # for.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="rooms",
+ column="room_id",
+ iterable={event.room_id for event in events if event.is_state()},
+ keyvalues={},
+ retcols=("room_id", "has_auth_chain_index"),
)
+ rooms_using_chain_index = {
+ row["room_id"] for row in rows if row["has_auth_chain_index"]
+ }
- # From this point onwards the events are only ones that weren't
- # rejected.
+ state_events = {
+ event.event_id: event
+ for event in events
+ if event.is_state() and event.room_id in rooms_using_chain_index
+ }
- self._update_metadata_tables_txn(
+ if not state_events:
+ return
+
+ # We need to know the type/state_key and auth events of the events we're
+ # calculating chain IDs for. We don't rely on having the full Event
+ # instances as we'll potentially be pulling more events from the DB and
+ # we don't need the overhead of fetching/parsing the full event JSON.
+ event_to_types = {
+ e.event_id: (e.type, e.state_key) for e in state_events.values()
+ }
+ event_to_auth_chain = {
+ e.event_id: e.auth_event_ids() for e in state_events.values()
+ }
+ event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+ self._add_chain_cover_index(
txn,
- events_and_contexts=events_and_contexts,
- all_events_and_contexts=all_events_and_contexts,
- backfilled=backfilled,
+ self.db_pool,
+ self.store.event_chain_id_gen,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
)
- # We call this last as it assumes we've inserted the events into
- # room_memberships, where applicable.
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+ @classmethod
+ def _add_chain_cover_index(
+ cls,
+ txn,
+ db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ ) -> None:
+ """Calculate the chain cover index for the given events.
+
+ Args:
+ event_to_room_id: Event ID to the room ID of the event
+ event_to_types: Event ID to type and state_key of the event
+ event_to_auth_chain: Event ID to list of auth event IDs of the
+ event (events with no auth events can be excluded).
+ """
+
+ # Map from event ID to chain ID/sequence number.
+ chain_map = {} # type: Dict[str, Tuple[int, int]]
+
+ # Set of event IDs to calculate chain ID/seq numbers for.
+ events_to_calc_chain_id_for = set(event_to_room_id)
+
+ # We check if there are any events that need to be handled in the rooms
+ # we're looking at. These should just be out of band memberships, where
+ # we didn't have the auth chain when we first persisted.
+ rows = db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="room_id",
+ iterable=set(event_to_room_id.values()),
+ retcols=("event_id", "type", "state_key"),
+ )
+ for row in rows:
+ event_id = row["event_id"]
+ event_type = row["type"]
+ state_key = row["state_key"]
+
+ # (We could pull out the auth events for all rows at once using
+ # simple_select_many, but this case happens rarely and almost always
+ # with a single row.)
+ auth_events = db_pool.simple_select_onecol_txn(
+ txn,
+ "event_auth",
+ keyvalues={"event_id": event_id},
+ retcol="auth_id",
+ )
+
+ events_to_calc_chain_id_for.add(event_id)
+ event_to_types[event_id] = (event_type, state_key)
+ event_to_auth_chain[event_id] = auth_events
+
+ # First we get the chain ID and sequence numbers for the events'
+ # auth events (that aren't also currently being persisted).
+ #
+        # Note that there is an edge case here where we might not have
+ # calculated chains and sequence numbers for events that were "out
+ # of band". We handle this case by fetching the necessary info and
+ # adding it to the set of events to calculate chain IDs for.
+
+ missing_auth_chains = {
+ a_id
+ for auth_events in event_to_auth_chain.values()
+ for a_id in auth_events
+ if a_id not in events_to_calc_chain_id_for
+ }
+
+ # We loop here in case we find an out of band membership and need to
+ # fetch their auth event info.
+ while missing_auth_chains:
+ sql = """
+ SELECT event_id, events.type, state_key, chain_id, sequence_number
+ FROM events
+ INNER JOIN state_events USING (event_id)
+ LEFT JOIN event_auth_chains USING (event_id)
+ WHERE
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine,
+ "event_id",
+ missing_auth_chains,
+ )
+ txn.execute(sql + clause, args)
+
+ missing_auth_chains.clear()
+
+ for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+ event_to_types[auth_id] = (event_type, state_key)
+
+ if chain_id is None:
+ # No chain ID, so the event was persisted out of band.
+ # We add to list of events to calculate auth chains for.
+
+ events_to_calc_chain_id_for.add(auth_id)
+
+ event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
+ txn,
+ "event_auth",
+ keyvalues={"event_id": auth_id},
+ retcol="auth_id",
+ )
+
+ missing_auth_chains.update(
+ e
+ for e in event_to_auth_chain[auth_id]
+ if e not in event_to_types
+ )
+ else:
+ chain_map[auth_id] = (chain_id, sequence_number)
+
+        # Now we check if there are any events whose auth chain we don't have;
+        # these should only be out of band memberships.
+ for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
+ for auth_id in event_to_auth_chain[event_id]:
+ if (
+ auth_id not in chain_map
+ and auth_id not in events_to_calc_chain_id_for
+ ):
+ events_to_calc_chain_id_for.discard(event_id)
+
+ # If this is an event we're trying to persist we add it to
+ # the list of events to calculate chain IDs for next time
+ # around. (Otherwise we will have already added it to the
+ # table).
+ room_id = event_to_room_id.get(event_id)
+ if room_id:
+ e_type, state_key = event_to_types[event_id]
+ db_pool.simple_insert_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "type": e_type,
+ "state_key": state_key,
+ },
+ )
+
+ # We stop checking the event's auth events since we've
+ # discarded it.
+ break
+
+ if not events_to_calc_chain_id_for:
+ return
+
+ # Allocate chain ID/sequence numbers to each new event.
+ new_chain_tuples = cls._allocate_chain_ids(
+ txn,
+ db_pool,
+ event_chain_id_gen,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ events_to_calc_chain_id_for,
+ chain_map,
+ )
+ chain_map.update(new_chain_tuples)
+
+ db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chains",
+ values=[
+ {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+ for event_id, (c_id, seq) in new_chain_tuples.items()
+ ],
+ )
+
+ db_pool.simple_delete_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="event_id",
+ iterable=new_chain_tuples,
+ )
+
+ # Now we need to calculate any new links between chains caused by
+ # the new events.
+ #
+ # Links are pairs of chain ID/sequence numbers such that for any
+ # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
+ # if and only if there is at least one link (CA, S1) -> (CB, S2)
+ # where SA >= S1 and S2 >= SB.
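+        # For example, a link (C1, 3) -> (C2, 5) means that any event (C1, SA)
+        # with SA >= 3 has every event (C2, SB) with SB <= 5 in its auth chain.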
+ #
+ # We try and avoid adding redundant links to the table, e.g. if we
+ # have two links between two chains which both start/end at the
+ # sequence number event (or cross) then one can be safely dropped.
+ #
+ # To calculate new links we look at every new event and:
+ # 1. Fetch the chain ID/sequence numbers of its auth events,
+ # discarding any that are reachable by other auth events, or
+ # that have the same chain ID as the event.
+ # 2. For each retained auth event we:
+ # a. Add a link from the event's to the auth event's chain
+ # ID/sequence number; and
+ # b. Add a link from the event to every chain reachable by the
+ # auth event.
+
+ # Step 1, fetch all existing links from all the chains we've seen
+ # referenced.
+ chain_links = _LinkMap()
+ rows = db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable={chain_id for chain_id, _ in chain_map.values()},
+ keyvalues={},
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ )
+ for row in rows:
+ chain_links.add_link(
+ (row["origin_chain_id"], row["origin_sequence_number"]),
+ (row["target_chain_id"], row["target_sequence_number"]),
+ new=False,
+ )
+
+        # We do this in topological order to avoid adding redundant links.
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ chain_id, sequence_number = chain_map[event_id]
+
+ # Filter out auth events that are reachable by other auth
+ # events. We do this by looking at every permutation of pairs of
+ # auth events (A, B) to check if B is reachable from A.
+ reduction = {
+ a_id
+ for a_id in event_to_auth_chain.get(event_id, [])
+ if chain_map[a_id][0] != chain_id
+ }
+ for start_auth_id, end_auth_id in itertools.permutations(
+ event_to_auth_chain.get(event_id, []),
+ r=2,
+ ):
+ if chain_links.exists_path_from(
+ chain_map[start_auth_id], chain_map[end_auth_id]
+ ):
+ reduction.discard(end_auth_id)
+
+ # Step 2, figure out what the new links are from the reduced
+ # list of auth events.
+ for auth_id in reduction:
+ auth_chain_id, auth_sequence_number = chain_map[auth_id]
+
+ # Step 2a, add link between the event and auth event
+ chain_links.add_link(
+ (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
+ )
+
+ # Step 2b, add a link to chains reachable from the auth
+ # event.
+ for target_id, target_seq in chain_links.get_links_from(
+ (auth_chain_id, auth_sequence_number)
+ ):
+ if target_id == chain_id:
+ continue
+
+ chain_links.add_link(
+ (chain_id, sequence_number), (target_id, target_seq)
+ )
+
+ db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ values=[
+ {
+ "origin_chain_id": source_id,
+ "origin_sequence_number": source_seq,
+ "target_chain_id": target_id,
+ "target_sequence_number": target_seq,
+ }
+ for (
+ source_id,
+ source_seq,
+ target_id,
+ target_seq,
+ ) in chain_links.get_additions()
+ ],
+ )
+
+ @staticmethod
+ def _allocate_chain_ids(
+ txn,
+ db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ events_to_calc_chain_id_for: Set[str],
+ chain_map: Dict[str, Tuple[int, int]],
+ ) -> Dict[str, Tuple[int, int]]:
+ """Allocates, but does not persist, chain ID/sequence numbers for the
+ events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+ for info on args)
+ """
+
+ # We now calculate the chain IDs/sequence numbers for the events. We do
+ # this by looking at the chain ID and sequence number of any auth event
+ # with the same type/state_key and incrementing the sequence number by
+ # one. If there was no match or the chain ID/sequence number is already
+ # taken we generate a new chain.
+ #
+ # We try to reduce the number of times that we hit the database by
+ # batching up calls, to make this more efficient when persisting large
+ # numbers of state events (e.g. during joins).
+ #
+ # We do this by:
+ # 1. Calculating for each event which auth event will be used to
+ # inherit the chain ID, i.e. converting the auth chain graph to a
+ # tree that we can allocate chains on. We also keep track of which
+ # existing chain IDs have been referenced.
+ # 2. Fetching the max allocated sequence number for each referenced
+ # existing chain ID, generating a map from chain ID to the max
+ # allocated sequence number.
+ # 3. Iterating over the tree and allocating a chain ID/seq no. to the
+ # new event, by incrementing the sequence number from the
+ # referenced event's chain ID/seq no. and checking that the
+ # incremented sequence number hasn't already been allocated (by
+ # looking in the map generated in the previous step). We generate a
+ # new chain if the sequence number has already been allocated.
+ #
+
+ existing_chains = set() # type: Set[int]
+ tree = [] # type: List[Tuple[str, Optional[str]]]
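+        # Each entry in `tree` is (event_id, auth_event_id), meaning the event
+        # will try to extend the auth event's chain, or (event_id, None) if the
+        # event must start a new chain.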
+
+ # We need to do this in a topologically sorted order as we want to
+ # generate chain IDs/sequence numbers of an event's auth events before
+ # the event itself.
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ for auth_id in event_to_auth_chain.get(event_id, []):
+ if event_to_types.get(event_id) == event_to_types.get(auth_id):
+ existing_chain_id = chain_map.get(auth_id)
+ if existing_chain_id:
+ existing_chains.add(existing_chain_id[0])
+
+ tree.append((event_id, auth_id))
+ break
+ else:
+ tree.append((event_id, None))
+
+ # Fetch the current max sequence number for each existing referenced chain.
+ sql = """
+ SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+ WHERE %s
+ GROUP BY chain_id
+ """
+ clause, args = make_in_list_sql_clause(
+ db_pool.engine, "chain_id", existing_chains
+ )
+ txn.execute(sql % (clause,), args)
+
+ chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
+
+ # Allocate the new events chain ID/sequence numbers.
+ #
+ # To reduce the number of calls to the database we don't allocate a
+ # chain ID number in the loop, instead we use a temporary `object()` for
+ # each new chain ID. Once we've done the loop we generate the necessary
+ # number of new chain IDs in one call, replacing all temporary
+ # objects with real allocated chain IDs.
+
+ unallocated_chain_ids = set() # type: Set[object]
+ new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
+ for event_id, auth_event_id in tree:
+ # If we reference an auth_event_id we fetch the allocated chain ID,
+ # either from the existing `chain_map` or the newly generated
+ # `new_chain_tuples` map.
+ existing_chain_id = None
+ if auth_event_id:
+ existing_chain_id = new_chain_tuples.get(auth_event_id)
+ if not existing_chain_id:
+ existing_chain_id = chain_map[auth_event_id]
+
+ new_chain_tuple = None # type: Optional[Tuple[Any, int]]
+ if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check it's
+                # not already taken.
+ proposed_new_id = existing_chain_id[0]
+ proposed_new_seq = existing_chain_id[1] + 1
+
+ if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+ new_chain_tuple = (
+ proposed_new_id,
+ proposed_new_seq,
+ )
+
+ # If we need to start a new chain we allocate a temporary chain ID.
+ if not new_chain_tuple:
+ new_chain_tuple = (object(), 1)
+ unallocated_chain_ids.add(new_chain_tuple[0])
+
+ new_chain_tuples[event_id] = new_chain_tuple
+ chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+ # Generate new chain IDs for all unallocated chain IDs.
+ newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
+ txn, len(unallocated_chain_ids)
+ )
+
+ # Map from potentially temporary chain ID to real chain ID
+ chain_id_to_allocated_map = dict(
+ zip(unallocated_chain_ids, newly_allocated_chain_ids)
+ ) # type: Dict[Any, int]
+ chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+ return {
+ event_id: (chain_id_to_allocated_map[chain_id], seq)
+ for event_id, (chain_id, seq) in new_chain_tuples.items()
+ }
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
):
- """Persist the mapping from transaction IDs to event IDs (if defined).
- """
+ """Persist the mapping from transaction IDs to event IDs (if defined)."""
to_insert = []
for event, _ in events_and_contexts:
@@ -433,7 +930,9 @@ class PersistEventsStore:
if to_insert:
self.db_pool.simple_insert_many_txn(
- txn, table="event_txn_id", values=to_insert,
+ txn,
+ table="event_txn_id",
+ values=to_insert,
)
def _update_current_state_txn(
@@ -465,7 +964,9 @@ class PersistEventsStore:
txn.execute(sql, (stream_id, self._instance_name, room_id))
self.db_pool.simple_delete_txn(
- txn, table="current_state_events", keyvalues={"room_id": room_id},
+ txn,
+ table="current_state_events",
+ keyvalues={"room_id": room_id},
)
else:
# We're still in the room, so we update the current state as normal.
@@ -489,7 +990,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -508,7 +1009,7 @@ class PersistEventsStore:
)
# Now we actually update the current_state_events table
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
@@ -520,7 +1021,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships.
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -540,7 +1041,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
@@ -551,7 +1052,7 @@ class PersistEventsStore:
)
if to_insert:
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -574,7 +1075,7 @@ class PersistEventsStore:
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
- # and which we have added, then we invlidate the caches for all
+ # and which we have added, then we invalidate the caches for all
# those users.
members_changed = {
state_key
@@ -769,8 +1270,10 @@ class PersistEventsStore:
logger.exception("")
raise
+ # update the stored internal_metadata to update the "outlier" flag.
+ # TODO: This is unused as of Synapse 1.31. Remove it once we are happy
+ # to drop backwards-compatibility with 1.30.
metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
-
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -799,7 +1302,8 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _store_event_txn(self, txn, events_and_contexts):
- """Insert new events into the event and event_json tables
+ """Insert new events into the event, event_json, redaction and
+ state_events tables.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
@@ -817,6 +1321,19 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
+ def get_internal_metadata(event):
+ im = event.internal_metadata.get_dict()
+
+ # temporary hack for database compatibility with Synapse 1.30 and earlier:
+ # store the `outlier` flag inside the internal_metadata json as well as in
+ # the `events` table, so that if anyone rolls back to an older Synapse,
+ # things keep working. This can be removed once we are happy to drop support
+ # for that
+ if event.internal_metadata.is_outlier():
+ im["outlier"] = True
+
+ return im
+
self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
@@ -825,7 +1342,7 @@ class PersistEventsStore:
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": json_encoder.encode(
- event.internal_metadata.get_dict()
+ get_internal_metadata(event)
),
"json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
@@ -871,6 +1388,29 @@ class PersistEventsStore:
updatevalues={"have_censored": False},
)
+ state_events_and_contexts = [
+ ec for ec in events_and_contexts if ec[0].is_state()
+ ]
+
+ state_values = []
+ for event, context in state_events_and_contexts:
+ vals = {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ }
+
+ # TODO: How does this work with backfilling?
+ if hasattr(event, "replaces_state"):
+ vals["prev_state"] = event.replaces_state
+
+ state_values.append(vals)
+
+ self.db_pool.simple_insert_many_txn(
+ txn, table="state_events", values=state_values
+ )
+
def _store_rejected_events_txn(self, txn, events_and_contexts):
"""Add rows to the 'rejections' table for received events which were
rejected
@@ -987,29 +1527,6 @@ class PersistEventsStore:
txn, [event for event, _ in events_and_contexts]
)
- state_events_and_contexts = [
- ec for ec in events_and_contexts if ec[0].is_state()
- ]
-
- state_values = []
- for event, context in state_events_and_contexts:
- vals = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
-
- # TODO: How does this work with backfilling?
- if hasattr(event, "replaces_state"):
- vals["prev_state"] = event.replaces_state
-
- state_values.append(vals)
-
- self.db_pool.simple_insert_many_txn(
- txn, table="state_events", values=state_values
- )
-
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1131,8 +1648,7 @@ class PersistEventsStore:
)
def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
+ """Store a room member in the database."""
def str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) else None
@@ -1350,7 +1866,7 @@ class PersistEventsStore:
"""
if events_and_contexts:
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -1379,7 +1895,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being
# persisted.
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts),
)
@@ -1498,7 +2014,7 @@ class PersistEventsStore:
" )"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1512,7 +2028,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(ev.event_id, ev.room_id)
@@ -1520,3 +2036,132 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
+
+
+@attr.s(slots=True)
+class _LinkMap:
+ """A helper type for tracking links between chains."""
+
+ # Stores the set of links as nested maps: source chain ID -> target chain ID
+ # -> source sequence number -> target sequence number.
+ maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+
+ # Stores the links that have been added (with new set to true), as tuples of
+ # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
+ additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+
+ def add_link(
+ self,
+ src_tuple: Tuple[int, int],
+ target_tuple: Tuple[int, int],
+ new: bool = True,
+ ) -> bool:
+ """Add a new link between two chains, ensuring no redundant links are added.
+
+ New links should be added in topological order.
+
+ Args:
+ src_tuple: The chain ID/sequence number of the source of the link.
+ target_tuple: The chain ID/sequence number of the target of the link.
+ new: Whether this is a "new" link, i.e. whether it should be
+ returned by `get_additions`.
+
+ Returns:
+ True if a link was added, False if the given link was dropped as redundant.
+ """
+ src_chain, src_seq = src_tuple
+ target_chain, target_seq = target_tuple
+
+ current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
+
+ assert src_chain != target_chain
+
+ if new:
+ # Check if the new link is redundant
+ for current_seq_src, current_seq_target in current_links.items():
+ # If a link "crosses" another link then its redundant. For example
+ # in the following link 1 (L1) is redundant, as any event reachable
+ # via L1 is *also* reachable via L2.
+ #
+ # Chain A Chain B
+ # | |
+ # L1 |------ |
+ # | | |
+ # L2 |---- | -->|
+ # | | |
+ # | |--->|
+ # | |
+ # | |
+ #
+ # So we only need to keep links which *do not* cross, i.e. links
+ # that both start and end above or below an existing link.
+ #
+ # Note, since we add links in topological ordering we should never
+ # see `src_seq` less than `current_seq_src`.
+
+ if current_seq_src <= src_seq and target_seq <= current_seq_target:
+ # This new link is redundant, nothing to do.
+ return False
+
+ self.additions.add((src_chain, src_seq, target_chain, target_seq))
+
+ current_links[src_seq] = target_seq
+ return True
+
+ def get_links_from(
+ self, src_tuple: Tuple[int, int]
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Gets the chains reachable from the given chain/sequence number.
+
+ Yields:
+ The chain ID and sequence number the link points to.
+ """
+ src_chain, src_seq = src_tuple
+ for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
+ for link_src_seq, target_seq in sequence_numbers.items():
+ if link_src_seq <= src_seq:
+ yield target_id, target_seq
+
+ def get_links_between(
+ self, source_chain: int, target_chain: int
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Gets the links between two chains.
+
+ Yields:
+ The source and target sequence numbers.
+ """
+
+ yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
+
+ def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
+ """Gets any newly added links.
+
+ Yields:
+ The source chain ID/sequence number and target chain ID/sequence number
+ """
+
+ for src_chain, src_seq, target_chain, _ in self.additions:
+ target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
+ if target_seq is not None:
+ yield (src_chain, src_seq, target_chain, target_seq)
+
+ def exists_path_from(
+ self,
+ src_tuple: Tuple[int, int],
+ target_tuple: Tuple[int, int],
+ ) -> bool:
+ """Checks if there is a path between the source chain ID/sequence and
+ target chain ID/sequence.
+ """
+ src_chain, src_seq = src_tuple
+ target_chain, target_seq = target_tuple
+
+ if src_chain == target_chain:
+ return target_seq <= src_seq
+
+ links = self.get_links_between(src_chain, target_chain)
+ for link_start_seq, link_end_seq in links:
+ if link_start_seq <= src_seq and target_seq <= link_end_seq:
+ return True
+
+ return False
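The redundancy rule in `add_link` is easiest to see in use. A small sketch exercising `_LinkMap` exactly as defined above (it only needs `attr` to be importable):

    lm = _LinkMap()

    # Chain 1 at sequence 3 can reach chain 2 at sequence 2.
    assert lm.add_link((1, 3), (2, 2))

    # A link from (1, 5) to (2, 1) crosses the one above (it starts later on
    # chain 1 but lands earlier on chain 2), so it is dropped as redundant.
    assert not lm.add_link((1, 5), (2, 1))

    # Anything at or after sequence 3 on chain 1 can reach (2, 2)...
    assert lm.exists_path_from((1, 4), (2, 2))
    # ...but earlier positions on chain 1 cannot.
    assert not lm.exists_path_from((1, 2), (2, 2))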
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 97b6754846..78367ea58d 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -14,14 +14,40 @@
# limitations under the License.
import logging
+from typing import Dict, List, Optional, Tuple
+
+import attr
from synapse.api.constants import EventContentFields
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.databases.main.events import PersistEventsStore
+from synapse.storage.types import Cursor
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+ """Return value for _calculate_chain_cover_txn."""
+
+ # The last room_id/depth/stream processed.
+ room_id = attr.ib(type=str)
+ depth = attr.ib(type=int)
+ stream = attr.ib(type=int)
+
+ # Number of rows processed
+ processed_count = attr.ib(type=int)
+
+ # Map from room_id to last depth/stream processed for each room that we have
+ # processed all events for (i.e. the rooms we can flip the
+ # `has_auth_chain_index` for)
+ finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -99,13 +125,26 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
columns=["user_id", "created_ts"],
)
+ self.db_pool.updates.register_background_update_handler(
+ "rejected_events_metadata",
+ self._rejected_events_metadata,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ "chain_cover",
+ self._chain_cover_index,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ "purged_chain_cover",
+ self._purged_chain_cover_index,
+ )
+
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
@@ -143,9 +182,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
- for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
- clump = update_rows[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, update_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -175,8 +212,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
@@ -221,9 +256,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
- for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
- clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, rows_to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -435,8 +468,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return num_handled
async def _redactions_received_ts(self, progress, batch_size):
- """Handles filling out the `received_ts` column in redactions.
- """
+ """Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "")
def _redactions_received_ts_txn(txn):
@@ -491,8 +523,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return count
async def _event_fix_redactions_bytes(self, progress, batch_size):
- """Undoes hex encoded censored redacted event JSON.
- """
+ """Undoes hex encoded censored redacted event JSON."""
def _event_fix_redactions_bytes_txn(txn):
# This update is quite fast due to new index.
@@ -582,3 +613,401 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
+
+ async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int:
+ """Adds rejected events to the `state_events` and `event_auth` metadata
+ tables.
+ """
+
+ last_event_id = progress.get("last_event_id", "")
+
+ def get_rejected_events(
+ txn: Cursor,
+ ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
+ # Fetch rejected event json, their room version and whether we have
+ # inserted them into the state_events or auth_events tables.
+ #
+ # Note we can assume that events that don't have a corresponding
+ # room version are V1 rooms.
+ sql = """
+ SELECT DISTINCT
+ event_id,
+ COALESCE(room_version, '1'),
+ json,
+ state_events.event_id IS NOT NULL,
+ event_auth.event_id IS NOT NULL
+ FROM rejections
+ INNER JOIN event_json USING (event_id)
+ LEFT JOIN rooms USING (room_id)
+ LEFT JOIN state_events USING (event_id)
+ LEFT JOIN event_auth USING (event_id)
+ WHERE event_id > ?
+ ORDER BY event_id
+ LIMIT ?
+ """
+
+ txn.execute(
+ sql,
+ (
+ last_event_id,
+ batch_size,
+ ),
+ )
+
+ return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
+
+ results = await self.db_pool.runInteraction(
+ desc="_rejected_events_metadata_get", func=get_rejected_events
+ )
+
+ if not results:
+ await self.db_pool.updates._end_background_update(
+ "rejected_events_metadata"
+ )
+ return 0
+
+ state_events = []
+ auth_events = []
+ for event_id, room_version, event_json, has_state, has_event_auth in results:
+ last_event_id = event_id
+
+ if has_state and has_event_auth:
+ continue
+
+ room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
+ if not room_version_obj:
+ # We no longer support this room version, so we just ignore the
+ # events entirely.
+ logger.info(
+ "Ignoring event with unknown room version %r: %r",
+ room_version,
+ event_id,
+ )
+ continue
+
+ event = make_event_from_dict(event_json, room_version_obj)
+
+ if not event.is_state():
+ continue
+
+ if not has_state:
+ state_events.append(
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ }
+ )
+
+ if not has_event_auth:
+ # Old, dodgy events may have duplicate auth events, which we
+ # need to deduplicate as we have a unique constraint.
+ for auth_id in set(event.auth_event_ids()):
+ auth_events.append(
+ {
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "auth_id": auth_id,
+ }
+ )
+
+ if state_events:
+ await self.db_pool.simple_insert_many(
+ table="state_events",
+ values=state_events,
+ desc="_rejected_events_metadata_state_events",
+ )
+
+ if auth_events:
+ await self.db_pool.simple_insert_many(
+ table="event_auth",
+ values=auth_events,
+ desc="_rejected_events_metadata_event_auth",
+ )
+
+ await self.db_pool.updates._background_update_progress(
+ "rejected_events_metadata", {"last_event_id": last_event_id}
+ )
+
+ if len(results) < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "rejected_events_metadata"
+ )
+
+ return len(results)
+
+ async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
+ """A background updates that iterates over all rooms and generates the
+ chain cover index for them.
+ """
+
+ current_room_id = progress.get("current_room_id", "")
+
+ # Where we've processed up to in the room, defaults to the start of the
+ # room.
+ last_depth = progress.get("last_depth", -1)
+ last_stream = progress.get("last_stream", -1)
+
+ result = await self.db_pool.runInteraction(
+ "_chain_cover_index",
+ self._calculate_chain_cover_txn,
+ current_room_id,
+ last_depth,
+ last_stream,
+ batch_size,
+ single_room=False,
+ )
+
+ finished = result.processed_count == 0
+
+ total_rows_processed = result.processed_count
+ current_room_id = result.room_id
+ last_depth = result.depth
+ last_stream = result.stream
+
+ for room_id, (depth, stream) in result.finished_room_map.items():
+ # If we've done all the events in the room we flip the
+ # `has_auth_chain_index` in the DB. Note that it's possible for
+ # further events to be persisted between the above and setting the
+ # flag without having the chain cover calculated for them. This is
+ # fine as a) the code gracefully handles these cases and b) we'll
+ # calculate them below.
+
+ await self.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": True},
+ desc="_chain_cover_index",
+ )
+
+ # Handle any events that might have raced with us flipping the
+ # bit above.
+ result = await self.db_pool.runInteraction(
+ "_chain_cover_index",
+ self._calculate_chain_cover_txn,
+ room_id,
+ depth,
+ stream,
+ batch_size=None,
+ single_room=True,
+ )
+
+ total_rows_processed += result.processed_count
+
+ if finished:
+ await self.db_pool.updates._end_background_update("chain_cover")
+ return total_rows_processed
+
+ await self.db_pool.updates._background_update_progress(
+ "chain_cover",
+ {
+ "current_room_id": current_room_id,
+ "last_depth": last_depth,
+ "last_stream": last_stream,
+ },
+ )
+
+ return total_rows_processed
+
+ def _calculate_chain_cover_txn(
+ self,
+ txn: Cursor,
+ last_room_id: str,
+ last_depth: int,
+ last_stream: int,
+ batch_size: Optional[int],
+ single_room: bool,
+ ) -> _CalculateChainCover:
+ """Calculate the chain cover for `batch_size` events, ordered by
+ `(room_id, depth, stream)`.
+
+ Args:
+ txn,
+ last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+ tuple to fetch results after.
+ batch_size: The maximum number of events to process. If None then
+ no limit.
+ single_room: Whether to calculate the index for just the given
+ room.
+ """
+
+ # Get the next set of events in the room (that we haven't already
+ # computed chain cover for). We do this in topological order.
+
+ # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+ # comparison, but that is not supported on older SQLite versions
+ tuple_clause, tuple_args = make_tuple_comparison_clause(
+ self.database_engine,
+ [
+ ("events.room_id", last_room_id),
+ ("topological_ordering", last_depth),
+ ("stream_ordering", last_stream),
+ ],
+ )
+
+ extra_clause = ""
+ if single_room:
+ extra_clause = "AND events.room_id = ?"
+ tuple_args.append(last_room_id)
+
+ sql = """
+ SELECT
+ event_id, state_events.type, state_events.state_key,
+ topological_ordering, stream_ordering,
+ events.room_id
+ FROM events
+ INNER JOIN state_events USING (event_id)
+ LEFT JOIN event_auth_chains USING (event_id)
+ LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+ WHERE event_auth_chains.event_id IS NULL
+ AND event_auth_chain_to_calculate.event_id IS NULL
+ AND %(tuple_cmp)s
+ %(extra)s
+ ORDER BY events.room_id, topological_ordering, stream_ordering
+ %(limit)s
+ """ % {
+ "tuple_cmp": tuple_clause,
+ "limit": "LIMIT ?" if batch_size is not None else "",
+ "extra": extra_clause,
+ }
+
+ if batch_size is not None:
+ tuple_args.append(batch_size)
+
+ txn.execute(sql, tuple_args)
+ rows = txn.fetchall()
+
+ # Put the results in the necessary format for
+ # `_add_chain_cover_index`
+ event_to_room_id = {row[0]: row[5] for row in rows}
+ event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+ # Calculate the new last position we've processed up to.
+ new_last_depth = rows[-1][3] if rows else last_depth # type: int
+ new_last_stream = rows[-1][4] if rows else last_stream # type: int
+ new_last_room_id = rows[-1][5] if rows else "" # type: str
+
+ # Map from room_id to last depth/stream_ordering processed for the room,
+ # excluding the last room (which we're likely still processing). We also
+ # need to include the room passed in if it's not included in the result
+ # set (as we then know we've processed all events in said room).
+ #
+ # This is the set of rooms that we can now safely flip the
+ # `has_auth_chain_index` bit for.
+ finished_rooms = {
+ row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+ }
+ if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+ finished_rooms[last_room_id] = (last_depth, last_stream)
+
+ count = len(rows)
+
+ # We also need to fetch the auth events for them.
+ auth_events = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth",
+ column="event_id",
+ iterable=event_to_room_id,
+ keyvalues={},
+ retcols=("event_id", "auth_id"),
+ )
+
+ event_to_auth_chain = {} # type: Dict[str, List[str]]
+ for row in auth_events:
+ event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+ # Calculate and persist the chain cover index for this set of events.
+ #
+ # Annoyingly we need to gut wrench into the persist event store so that
+ # we can reuse the function to calculate the chain cover for rooms.
+ PersistEventsStore._add_chain_cover_index(
+ txn,
+ self.db_pool,
+ self.event_chain_id_gen,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ )
+
+ return _CalculateChainCover(
+ room_id=new_last_room_id,
+ depth=new_last_depth,
+ stream=new_last_stream,
+ processed_count=count,
+ finished_room_map=finished_rooms,
+ )
+
+ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
+ """
+ A background update that iterates over the chain cover and deletes the
+ chain cover for events that have been purged.
+
+ This may be due to fully purging a room or to setting a retention policy.
+ """
+ current_event_id = progress.get("current_event_id", "")
+
+ def purged_chain_cover_txn(txn) -> int:
+ # The event ID from events will be null if the chain ID / sequence
+ # number points to a purged event.
+ sql = """
+ SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
+ FROM event_auth_chains
+ LEFT JOIN events AS e USING (event_id)
+ WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
+ """
+ txn.execute(sql, (current_event_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ # The event IDs and chain IDs / sequence numbers where the event has
+ # been purged.
+ unreferenced_event_ids = []
+ unreferenced_chain_id_tuples = []
+ event_id = ""
+ for event_id, chain_id, sequence_number, has_event in rows:
+ if not has_event:
+ unreferenced_event_ids.append((event_id,))
+ unreferenced_chain_id_tuples.append((chain_id, sequence_number))
+
+ # Delete the unreferenced auth chains from event_auth_chain_links and
+ # event_auth_chains.
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chains WHERE event_id = ?
+ """,
+ unreferenced_event_ids,
+ )
+ # We should also delete matching target_*, but there is no index on
+ # target_chain_id. Hopefully any purged events are due to a room
+ # being fully purged and they will be removed from the origin_*
+ # searches.
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chain_links WHERE
+ origin_chain_id = ? AND origin_sequence_number = ?
+ """,
+ unreferenced_chain_id_tuples,
+ )
+
+ progress = {
+ "current_event_id": event_id,
+ }
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "purged_chain_cover", progress
+ )
+
+ return len(rows)
+
+ result = await self.db_pool.runInteraction(
+ "_purged_chain_cover_index",
+ purged_chain_cover_txn,
+ )
+
+ if not result:
+ await self.db_pool.updates._end_background_update("purged_chain_cover")
+
+ return result
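The three new background updates registered above (`rejected_events_metadata`, `chain_cover` and `purged_chain_cover`) all follow the same shape: fetch a batch keyed on the last-processed position, write the derived rows, record progress, and finish once a batch comes back short or empty. A generic sketch of that loop, using hypothetical callables rather than Synapse's actual `BackgroundUpdater` API:

    from typing import Awaitable, Callable, List, Tuple

    async def run_keyset_update(
        fetch_batch: Callable[[str, int], Awaitable[List[Tuple[str, ...]]]],
        process_batch: Callable[[List[Tuple[str, ...]]], Awaitable[None]],
        save_progress: Callable[[str], Awaitable[None]],
        batch_size: int = 100,
    ) -> None:
        # Drain a table in batches, keyed on an ever-increasing event_id.
        last_event_id = ""
        while True:
            rows = await fetch_batch(last_event_id, batch_size)
            if not rows:
                break  # nothing left: the update is finished
            await process_batch(rows)
            last_event_id = rows[-1][0]          # rows are ordered by event_id
            await save_progress(last_event_id)   # restartable after interruption
            if len(rows) < batch_size:
                break  # a short batch means we have reached the end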
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..b3703ae161
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+ async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+ """Delete any extra forward extremities for a room.
+
+ Invalidates the "get_latest_event_ids_in_room" cache if any forward
+ extremities were deleted.
+
+ Returns the number of extremities deleted.
+ """
+
+ def delete_forward_extremities_for_room_txn(txn):
+ # First we need to get the event_id to not delete
+ sql = """
+ SELECT event_id FROM event_forward_extremities
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id = ?
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (room_id,))
+ rows = txn.fetchall()
+ try:
+ event_id = rows[0][0]
+ logger.debug(
+ "Found event_id %s as the forward extremity to keep for room %s",
+ event_id,
+ room_id,
+ )
+ except IndexError:
+ msg = "No forward extremity event found for room %s" % room_id
+ logger.warning(msg)
+ raise SynapseError(400, msg)
+
+ # Now delete the extra forward extremities
+ sql = """
+ DELETE FROM event_forward_extremities
+ WHERE event_id != ? AND room_id = ?
+ """
+
+ txn.execute(sql, (event_id, room_id))
+ logger.info(
+ "Deleted %s extra forward extremities for room %s",
+ txn.rowcount,
+ room_id,
+ )
+
+ if txn.rowcount > 0:
+ # Invalidate the cache
+ self._invalidate_cache_and_stream(
+ txn,
+ self.get_latest_event_ids_in_room,
+ (room_id,),
+ )
+
+ return txn.rowcount
+
+ return await self.db_pool.runInteraction(
+ "delete_forward_extremities_for_room",
+ delete_forward_extremities_for_room_txn,
+ )
+
+ async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+ """Get list of forward extremities for a room."""
+
+ def get_forward_extremities_for_room_txn(txn):
+ sql = """
+ SELECT event_id, state_group, depth, received_ts
+ FROM event_forward_extremities
+ INNER JOIN event_to_state_groups USING (event_id)
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id = ?
+ """
+
+ txn.execute(sql, (room_id,))
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "get_forward_extremities_for_room",
+ get_forward_extremities_for_room_txn,
+ )
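`delete_forward_extremities_for_room_txn` keeps only the extremity with the highest stream_ordering and deletes the rest. The same two statements can be exercised against an in-memory SQLite database with a heavily simplified stand-in schema (the real tables have many more columns):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE events (room_id TEXT, event_id TEXT, stream_ordering INTEGER);
        CREATE TABLE event_forward_extremities (room_id TEXT, event_id TEXT);
        """
    )
    conn.executemany(
        "INSERT INTO events VALUES (?, ?, ?)",
        [("!r:x", "$a", 1), ("!r:x", "$b", 2), ("!r:x", "$c", 3)],
    )
    conn.executemany(
        "INSERT INTO event_forward_extremities VALUES (?, ?)",
        [("!r:x", "$a"), ("!r:x", "$b"), ("!r:x", "$c")],
    )

    # Find the extremity to keep: the one with the highest stream_ordering.
    (keep,) = conn.execute(
        "SELECT event_id FROM event_forward_extremities "
        "INNER JOIN events USING (room_id, event_id) "
        "WHERE room_id = ? ORDER BY stream_ordering DESC LIMIT 1",
        ("!r:x",),
    ).fetchone()

    # Delete every other extremity for the room, as the txn above does.
    cur = conn.execute(
        "DELETE FROM event_forward_extremities WHERE event_id != ? AND room_id = ?",
        (keep, "!r:x"),
    )
    assert keep == "$c" and cur.rowcount == 2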
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f21baffd10..dfb638ea54 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
+
import logging
import threading
from collections import namedtuple
@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -96,9 +97,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
stream_name="events",
instance_name=hs.get_instance_name(),
- table="events",
- instance_column="instance_name",
- id_column="stream_ordering",
+ tables=[("events", "instance_name", "stream_ordering")],
sequence_name="events_stream_seq",
writers=hs.config.worker.writers.events,
)
@@ -107,9 +106,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
stream_name="backfill",
instance_name=hs.get_instance_name(),
- table="events",
- instance_column="instance_name",
- id_column="stream_ordering",
+ tables=[("events", "instance_name", "stream_ordering")],
sequence_name="events_backfill_stream_seq",
positive=False,
writers=hs.config.worker.writers.events,
@@ -124,7 +121,9 @@ class EventsWorkerStore(SQLBaseStore):
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering",
+ db_conn,
+ "events",
+ "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
@@ -144,7 +143,8 @@ class EventsWorkerStore(SQLBaseStore):
if hs.config.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
- self._cleanup_old_transaction_ids, 5 * 60 * 1000,
+ self._cleanup_old_transaction_ids,
+ 5 * 60 * 1000,
)
self._get_event_cache = LruCache(
@@ -157,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ # We define this sequence here so that it can be referenced from both
+ # the DataStore and PersistEventStore.
+ def get_chain_id_txn(txn):
+ txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+ return txn.fetchone()[0]
+
+ self.event_chain_id_gen = build_sequence_generator(
+ db_conn,
+ database.engine,
+ get_chain_id_txn,
+ "event_auth_chain_id",
+ table="event_auth_chains",
+ id_column="chain_id",
+ )
+
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
@@ -784,6 +799,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected_reason=rejected_reason,
)
original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
+ original_ev.internal_metadata.outlier = row["outlier"]
event_map[event_id] = original_ev
@@ -890,7 +906,8 @@ class EventsWorkerStore(SQLBaseStore):
ej.json,
ej.format_version,
r.room_version,
- rej.reason
+ rej.reason,
+ e.outlier
FROM events AS e
JOIN event_json AS ej USING (event_id)
LEFT JOIN rooms r ON r.room_id = e.room_id
@@ -914,6 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
"room_version_id": row[5],
"rejected_reason": row[6],
"redactions": [],
+ "outlier": row[7],
}
# check for redactions
@@ -1029,7 +1047,8 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
set[str]: The events we have already seen.
"""
- results = set()
+ # if the event cache contains the event, obviously we've seen it.
+ results = {x for x in event_ids if self._get_event_cache.contains(x)}
def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE "
@@ -1037,12 +1056,9 @@ class EventsWorkerStore(SQLBaseStore):
txn.database_engine, "e.event_id", chunk
)
txn.execute(sql + clause, args)
- for (event_id,) in txn:
- results.add(event_id)
+ results.update(row[0] for row in txn)
- # break the input up into chunks of 100
- input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+ for chunk in batch_iter((x for x in event_ids if x not in results), 100):
await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
@@ -1329,8 +1345,7 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
async def is_event_after(self, event_id1, event_id2):
- """Returns True if event_id1 is after event_id2 in the stream
- """
+ """Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@@ -1432,8 +1447,7 @@ class EventsWorkerStore(SQLBaseStore):
@wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self):
- """Cleans out transaction id mappings older than 24hrs.
- """
+ """Cleans out transaction id mappings older than 24hrs."""
def _cleanup_old_transaction_ids_txn(txn):
sql = """
@@ -1444,5 +1458,6 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (one_day_ago,))
return await self.db_pool.runInteraction(
- "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+ "_cleanup_old_transaction_ids",
+ _cleanup_old_transaction_ids_txn,
)
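`have_seen_events` now skips IDs already in the in-memory event cache and chunks the remainder with `batch_iter` instead of hand-rolled `itertools.islice` slicing. A self-contained chunker equivalent to what `synapse.util.iterutils.batch_iter` is assumed to do here:

    from itertools import islice
    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
        # Yield tuples of at most `size` items until the source is exhausted.
        it = iter(iterable)
        return iter(lambda: tuple(islice(it, size)), ())

    # e.g. the remaining (uncached) event IDs get queried 100 at a time:
    assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]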
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 7218191965..ac07e0197b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
+
+from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -26,6 +28,9 @@ from synapse.util import json_encoder
_DEFAULT_CATEGORY_ID = ""
_DEFAULT_ROLE_ID = ""
+# A room in a group.
+_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
+
class GroupServerWorkerStore(SQLBaseStore):
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
@@ -72,7 +77,7 @@ class GroupServerWorkerStore(SQLBaseStore):
async def get_rooms_in_group(
self, group_id: str, include_private: bool = False
- ) -> List[Dict[str, Union[str, bool]]]:
+ ) -> List[_RoomInGroup]:
"""Retrieve the rooms that belong to a given group. Does not return rooms that
lack members.
@@ -123,7 +128,9 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_rooms_for_summary_by_category(
- self, group_id: str, include_private: bool = False,
+ self,
+ group_id: str,
+ include_private: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Get the rooms and categories that should be included in a summary request
@@ -368,8 +375,7 @@ class GroupServerWorkerStore(SQLBaseStore):
async def is_user_invited_to_local_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
- """Has the group server invited a user?
- """
+ """Has the group server invited a user?"""
return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -427,8 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
- """Get all groups a user is publicising
- """
+ """Get all groups a user is publicising"""
return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
@@ -437,8 +442,7 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_attestations_need_renewals(self, valid_until_ms):
- """Get all attestations that need to be renewed until givent time
- """
+ """Get all attestations that need to be renewed until givent time"""
def _get_attestations_need_renewals_txn(txn):
sql = """
@@ -781,8 +785,7 @@ class GroupServerStore(GroupServerWorkerStore):
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
- """Add/update room category for group
- """
+ """Add/update room category for group"""
insertion_values = {}
update_values = {"category_id": category_id} # This cannot be empty
@@ -818,8 +821,7 @@ class GroupServerStore(GroupServerWorkerStore):
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
- """Add/remove user role
- """
+ """Add/remove user role"""
insertion_values = {}
update_values = {"role_id": role_id} # This cannot be empty
@@ -1012,8 +1014,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def add_group_invite(self, group_id: str, user_id: str) -> None:
- """Record that the group server has invited a user
- """
+ """Record that the group server has invited a user"""
await self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
@@ -1156,8 +1157,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_group_publicity(
self, group_id: str, user_id: str, publicise: bool
) -> None:
- """Update whether the user is publicising their membership of the group
- """
+ """Update whether the user is publicising their membership of the group"""
await self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1300,8 +1300,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_attestation_renewal(
self, group_id: str, user_id: str, attestation: dict
) -> None:
- """Update an attestation that we have renewed
- """
+ """Update an attestation that we have renewed"""
await self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1312,8 +1311,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_remote_attestion(
self, group_id: str, user_id: str, attestation: dict
) -> None:
- """Update an attestation that a remote has renewed
- """
+ """Update an attestation that a remote has renewed"""
await self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index f8f4bb9b3f..d504323b03 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
from synapse.storage.keys import FetchKeyResult
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -32,8 +33,7 @@ db_binary_type = memoryview
class KeyStore(SQLBaseStore):
- """Persistence for signature verification keys
- """
+ """Persistence for signature verification keys"""
@cached()
def _get_server_verify_key(self, server_name_and_key_id):
@@ -44,7 +44,7 @@ class KeyStore(SQLBaseStore):
)
async def get_server_verify_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
- ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
+ ) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
@@ -56,7 +56,7 @@ class KeyStore(SQLBaseStore):
"""
keys = {}
- def _get_keys(txn, batch):
+ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
@@ -77,13 +77,12 @@ class KeyStore(SQLBaseStore):
# `ts_valid_until_ms`.
ts_valid_until_ms = 0
- res = FetchKeyResult(
+ keys[(server_name, key_id)] = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
- keys[(server_name, key_id)] = res
- def _txn(txn):
+ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch)
return keys
@@ -155,7 +154,7 @@ class KeyStore(SQLBaseStore):
(server_name, key_id, from_server) triplet if one already existed.
Args:
server_name: The name of the server.
- key_id: The identifer of the key this JSON is for.
+ key_id: The identifier of the key this JSON is for.
from_server: The server this JSON was fetched from.
ts_now_ms: The time now in milliseconds.
ts_valid_until_ms: The time when this json stops being valid.
@@ -182,7 +181,7 @@ class KeyStore(SQLBaseStore):
async def get_server_keys_json(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
- """Retrive the key json for a list of server_keys and key ids.
+ """Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4b2f224718..4f3d192562 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
@@ -22,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
)
+class MediaSortOrder(Enum):
+ """
+ Enum to define the sorting method used when returning media with
+ get_local_media_by_user_paginate
+ """
+
+ MEDIA_ID = "media_id"
+ UPLOAD_NAME = "upload_name"
+ CREATED_TS = "created_ts"
+ LAST_ACCESS_TS = "last_access_ts"
+ MEDIA_LENGTH = "media_length"
+ MEDIA_TYPE = "media_type"
+ QUARANTINED_BY = "quarantined_by"
+ SAFE_FROM_QUARANTINE = "safe_from_quarantine"
+
+
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -117,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_by_user_paginate(
- self, start: int, limit: int, user_id: str
+ self,
+ start: int,
+ limit: int,
+ user_id: str,
+ order_by: str = MediaSortOrder.CREATED_TS.value,
+ direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@@ -126,6 +149,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
start: offset in the list
limit: maximum amount of media_ids to retrieve
user_id: fully-qualified user id
+ order_by: the sort order of the returned list
+ direction: sort ascending or descending
Returns:
A paginated list of all metadata of user's media,
plus the total count of all the user's media
@@ -133,6 +158,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(txn):
+ # Set ordering
+ order_by_column = MediaSortOrder(order_by).value
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
args = [user_id]
sql = """
SELECT COUNT(*) as total_media
@@ -154,9 +187,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"safe_from_quarantine"
FROM local_media_repository
WHERE user_id = ?
- ORDER BY created_ts DESC, media_id DESC
+ ORDER BY {order_by_column} {order}, media_id ASC
LIMIT ? OFFSET ?
- """
+ """.format(
+ order_by_column=order_by_column,
+ order=order,
+ )
args += [limit, start]
txn.execute(sql, args)
@@ -168,8 +204,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_before(
- self, before_ts: int, size_gt: int, keep_profiles: bool,
- ) -> Optional[List[str]]:
+ self,
+ before_ts: int,
+ size_gt: int,
+ keep_profiles: bool,
+ ) -> List[str]:
# to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts`
@@ -340,16 +379,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- await self.db_pool.simple_insert(
- "local_media_repository_thumbnails",
- {
+ await self.db_pool.simple_upsert(
+ table="local_media_repository_thumbnails",
+ keyvalues={
"media_id": media_id,
"thumbnail_width": thumbnail_width,
"thumbnail_height": thumbnail_height,
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
- "thumbnail_length": thumbnail_length,
},
+ values={"thumbnail_length": thumbnail_length},
desc="store_local_thumbnail",
)
@@ -416,7 +455,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
sql,
(
(time_ms, media_origin, media_id)
@@ -429,7 +468,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)
- txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+ txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
@@ -453,10 +492,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_remote_media_thumbnail(
- self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+ self,
+ origin: str,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_type: str,
) -> Optional[Dict[str, Any]]:
- """Fetch the thumbnail info of given width, height and type.
- """
+ """Fetch the thumbnail info of given width, height and type."""
return await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
@@ -490,18 +533,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- await self.db_pool.simple_insert(
- "remote_media_cache_thumbnails",
- {
+ await self.db_pool.simple_upsert(
+ table="remote_media_cache_thumbnails",
+ keyvalues={
"media_origin": origin,
"media_id": media_id,
"thumbnail_width": thumbnail_width,
"thumbnail_height": thumbnail_height,
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
- "thumbnail_length": thumbnail_length,
- "filesystem_id": filesystem_id,
},
+ values={"thumbnail_length": thumbnail_length},
+ insertion_values={"filesystem_id": filesystem_id},
desc="store_remote_media_thumbnail",
)
@@ -556,7 +599,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
@@ -585,11 +628,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..614a418a15 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
+ async def count_daily_e2ee_messages(self):
+ """
+ Returns an estimate of the number of messages sent in the last day.
+
+ If it has been significantly less or more than one day since the last
+ call to this function, it will return None.
+ """
+
+ def _count_messages(txn):
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+ async def count_daily_sent_e2ee_messages(self):
+ def _count_messages(txn):
+ # This is good enough as if you have silly characters in your own
+ # hostname then that's your own fault.
+ like_clause = "%:" + self.hs.hostname
+
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND sender LIKE ?
+ AND stream_ordering > ?
+ """
+
+ txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction(
+ "count_daily_sent_e2ee_messages", _count_messages
+ )
+
+ async def count_daily_active_e2ee_rooms(self):
+ def _count(txn):
+ sql = """
+ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction(
+ "count_daily_active_e2ee_rooms", _count
+ )
+
async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -111,7 +167,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
async def count_daily_sent_messages(self):
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
- # hostname then thats your own fault.
+ # hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
sql = """
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index d788dc0fc6..757da3d55d 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, List
+from typing import Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
@@ -109,7 +109,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
- async def user_last_seen_monthly_active(self, user_id: str) -> int:
+ async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]:
"""
Checks if a given user is part of the monthly active user group
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index dbbb99cb95..0ff693a310 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Tuple
+from typing import Dict, List, Tuple
from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
@@ -130,7 +130,9 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
+ cached_method_name="_get_presence_for_user",
+ list_name="user_ids",
+ num_args=1,
)
async def get_presence_for_users(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
@@ -155,5 +157,63 @@ class PresenceStore(SQLBaseStore):
return {row["user_id"]: UserPresenceState(**row) for row in rows}
+ async def get_presence_for_all_users(
+ self,
+ include_offline: bool = True,
+ ) -> Dict[str, UserPresenceState]:
+ """Retrieve the current presence state for all users.
+
+ Note that the presence_stream table is culled frequently, so it should only
+ contain the latest presence state for each user.
+
+ Args:
+ include_offline: Whether to include offline presence states
+
+ Returns:
+ A dict of user IDs to their current UserPresenceState.
+ """
+ users_to_state = {}
+
+ exclude_keyvalues = None
+ if not include_offline:
+ # Exclude offline presence state
+ exclude_keyvalues = {"state": "offline"}
+
+ # This may be a very heavy database query.
+ # We paginate in order to not block a database connection.
+ limit = 100
+ offset = 0
+ while True:
+ rows = await self.db_pool.runInteraction(
+ "get_presence_for_all_users",
+ self.db_pool.simple_select_list_paginate_txn,
+ "presence_stream",
+ orderby="stream_id",
+ start=offset,
+ limit=limit,
+ exclude_keyvalues=exclude_keyvalues,
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ order_direction="ASC",
+ )
+
+ for row in rows:
+ users_to_state[row["user_id"]] = UserPresenceState(**row)
+
+ # We've run out of updates to query
+ if len(rows) < limit:
+ break
+
+ offset += limit
+
+ return users_to_state
+
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
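`get_presence_for_all_users` pages through presence_stream with LIMIT/OFFSET (rather than the keyset pagination used by the background updates above) and stops once a page comes back short. The same loop in isolation, with `fetch_page` as a hypothetical stand-in for the paginated select:

    from typing import Callable, Dict, List

    def drain_paginated(
        fetch_page: Callable[[int, int], List[Dict]], limit: int = 100
    ) -> List[Dict]:
        # Collect every row from a LIMIT/OFFSET-paginated query.
        rows: List[Dict] = []
        offset = 0
        while True:
            page = fetch_page(offset, limit)
            rows.extend(page)
            if len(page) < limit:
                break  # a short page means there is nothing left to fetch
            offset += limit
        return rows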
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 2570003457..1e65cb8a04 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -145,7 +145,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_avatar_url(
- self, user_localpart: str, new_avatar_url: str, batchnum: int
+ self, user_localpart: str, new_avatar_url: Optional[str], batchnum: int
) -> None:
# Invalidate the read cache for this user
self.get_profile_avatar_url.invalidate((user_localpart,))
@@ -159,7 +159,11 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profiles_active(
- self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ self,
+ users: List[UserID],
+ active: bool,
+ hide: bool,
+ batchnum: int,
) -> None:
"""Given a set of users, set active and hidden flags on them.
@@ -240,8 +244,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def is_subscribed_remote_profile_for_user(self, user_id):
- """Check whether we are interested in a remote user's profile.
- """
+ """Check whether we are interested in a remote user's profile."""
res = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
@@ -267,8 +270,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> List[Dict[str, str]]:
- """Get all users who haven't been checked since `last_checked`
- """
+ """Get all users who haven't been checked since `last_checked`"""
def _get_remote_profile_cache_entries_that_expire_txn(txn):
sql = """
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..41f4fe7f95 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -28,7 +28,10 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> Set[int]:
- """Deletes room history before a certain point
+ """Deletes room history before a certain point.
+
+ Note that only a single purge can occur at once; this is guaranteed at
+ a higher level (in the PaginationHandler).
Args:
room_id:
@@ -52,7 +55,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
delete_local_events,
)
- def _purge_history_txn(self, txn, room_id, token, delete_local_events):
+ def _purge_history_txn(
+ self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+ ) -> Set[int]:
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -103,7 +108,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
- # having any backwards extremeties)
+ # having any backwards extremities)
raise SynapseError(
400, "topological_ordering is greater than forward extremeties"
)
@@ -154,7 +159,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
logger.info("[purge] Finding new backward extremities")
- # We calculate the new entries for the backward extremeties by finding
+ # We calculate the new entries for the backward extremities by finding
# events to be purged that are pointed to by events we're not going to
# purge.
txn.execute(
@@ -172,7 +177,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)
# Update backward extremeties
- txn.executemany(
+ txn.execute_batch(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],
@@ -296,7 +301,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"purge_room", self._purge_room_txn, room_id
)
- def _purge_room_txn(self, txn, room_id):
+ def _purge_room_txn(self, txn, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
@@ -310,6 +315,27 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
state_groups = [row[0] for row in txn]
+ # Get all the auth chains that are referenced by events that are to be
+ # deleted.
+ txn.execute(
+ """
+ SELECT chain_id, sequence_number FROM events
+ LEFT JOIN event_auth_chains USING (event_id)
+ WHERE room_id = ?
+ """,
+ (room_id,),
+ )
+ referenced_chain_id_tuples = list(txn)
+
+ logger.info("[purge] removing events from event_auth_chain_links")
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chain_links WHERE
+ origin_chain_id = ? AND origin_sequence_number = ?
+ """,
+ referenced_chain_id_tuples,
+ )
+
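As an aside on the new purge step just above: it first collects the (chain_id, sequence_number) pairs referenced by the room's events, then batch-deletes the matching event_auth_chain_links rows. A standalone sketch of the same pattern, using stdlib sqlite3 in place of Synapse's transaction wrapper (the helper name and cut-down schema are illustrative only):

    # Illustrative sketch only: stdlib sqlite3, not Synapse's LoggingTransaction.
    import sqlite3

    def purge_room_auth_chain_links(conn: sqlite3.Connection, room_id: str) -> int:
        """Delete event_auth_chain_links rows originating from a room's events."""
        cur = conn.cursor()
        # Find the chains referenced by the events that are about to be purged.
        cur.execute(
            """
            SELECT chain_id, sequence_number FROM events
            LEFT JOIN event_auth_chains USING (event_id)
            WHERE room_id = ?
            """,
            (room_id,),
        )
        referenced = [row for row in cur.fetchall() if row[0] is not None]
        # Batch-delete the links; executemany stands in for execute_batch here.
        cur.executemany(
            """
            DELETE FROM event_auth_chain_links
            WHERE origin_chain_id = ? AND origin_sequence_number = ?
            """,
            referenced,
        )
        return len(referenced)
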
# Now we delete tables which lack an index on room_id but have one on event_id
for table in (
"event_auth",
@@ -319,6 +345,8 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_reference_hashes",
"event_relations",
"event_to_state_groups",
+ "event_auth_chains",
+ "event_auth_chain_to_calculate",
"redactions",
"rejections",
"state_events",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 711d5aa23d..9e58dc0e6a 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -168,7 +168,9 @@ class PushRulesWorkerStore(
)
@cachedList(
- cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
+ cached_method_name="get_push_rules_for_user",
+ list_name="user_ids",
+ num_args=1,
)
async def bulk_get_push_rules(self, user_ids):
if not user_ids:
@@ -195,7 +197,9 @@ class PushRulesWorkerStore(
use_new_defaults = user_id in self._users_new_default_push_rules
results[user_id] = _load_rules(
- rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
+ rules,
+ enabled_map_by_user.get(user_id, {}),
+ use_new_defaults,
)
return results
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7997242d90..c65558c280 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -15,18 +15,41 @@
# limitations under the License.
import logging
-from typing import Iterable, Iterator, List, Tuple
-
-from canonicaljson import encode_canonical_json
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
+from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+ self._pushers_id_gen = StreamIdGenerator(
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ "remove_deactivated_pushers",
+ self._remove_deactivated_pushers,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ "remove_stale_pushers",
+ self._remove_stale_pushers,
+ )
+
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
@@ -44,21 +67,23 @@ class PusherWorkerStore(SQLBaseStore):
)
continue
- yield r
+ yield PusherConfig(**r)
- async def user_has_pusher(self, user_id):
+ async def user_has_pusher(self, user_id: str) -> bool:
ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
- def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
- return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
+ async def get_pushers_by_app_id_and_pushkey(
+ self, app_id: str, pushkey: str
+ ) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
- def get_pushers_by_user_id(self, user_id):
- return self.get_pushers_by({"user_name": user_id})
+ async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"user_name": user_id})
- async def get_pushers_by(self, keyvalues):
+ async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
@@ -83,7 +108,7 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- async def get_all_pushers(self):
+ async def get_all_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
@@ -159,14 +184,18 @@ class PusherWorkerStore(SQLBaseStore):
)
@cached(num_args=1, max_entries=15000)
- async def get_if_user_has_pusher(self, user_id):
+ async def get_if_user_has_pusher(self, user_id: str):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
- cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
+ cached_method_name="get_if_user_has_pusher",
+ list_name="user_ids",
+ num_args=1,
)
- async def get_if_users_have_pushers(self, user_ids):
+ async def get_if_users_have_pushers(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, bool]:
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
@@ -224,7 +253,7 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
async def update_pusher_failing_since(
- self, app_id, pushkey, user_id, failing_since
+ self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
) -> None:
await self.db_pool.simple_update(
table="pushers",
@@ -233,7 +262,9 @@ class PusherWorkerStore(SQLBaseStore):
desc="update_pusher_failing_since",
)
- async def get_throttle_params_by_room(self, pusher_id):
+ async def get_throttle_params_by_room(
+ self, pusher_id: str
+ ) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
@@ -243,43 +274,140 @@ class PusherWorkerStore(SQLBaseStore):
params_by_room = {}
for row in res:
- params_by_room[row["room_id"]] = {
- "last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"],
- }
+ params_by_room[row["room_id"]] = ThrottleParams(
+ row["last_sent_ts"],
+ row["throttle_ms"],
+ )
return params_by_room
- async def set_throttle_params(self, pusher_id, room_id, params) -> None:
+ async def set_throttle_params(
+ self, pusher_id: str, room_id: str, params: ThrottleParams
+ ) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
- params,
+ {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params",
lock=False,
)
+ async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
+ """A background update that deletes all pushers for deactivated users.
+
+ Note that we don't proactively tell the pusherpool that we've deleted
+ these (just because it's a bit of a faff to do from here), but they will
+ get cleaned up at the next restart.
+ """
+
+ last_user = progress.get("last_user", "")
+
+ def _delete_pushers(txn) -> int:
+
+ sql = """
+ SELECT name FROM users
+ WHERE deactivated = ? and name > ?
+ ORDER BY name ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (1, last_user, batch_size))
+ users = [row[0] for row in txn]
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="pushers",
+ column="user_name",
+ iterable=users,
+ keyvalues={},
+ )
+
+ if users:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "remove_deactivated_pushers", {"last_user": users[-1]}
+ )
+
+ return len(users)
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deactivated_pushers", _delete_pushers
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "remove_deactivated_pushers"
+ )
+
+ return number_deleted
+
+ async def _remove_stale_pushers(self, progress: dict, batch_size: int) -> int:
+ """A background update that deletes all pushers for logged out devices.
+
+ Note that we don't proactively tell the pusherpool that we've deleted
+ these (just because it's a bit of a faff to do from here), but they will
+ get cleaned up at the next restart.
+ """
+
+ last_pusher = progress.get("last_pusher", 0)
+
+ def _delete_pushers(txn) -> int:
+
+ sql = """
+ SELECT p.id, access_token FROM pushers AS p
+ LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
+ WHERE p.id > ?
+ ORDER BY p.id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_pusher, batch_size))
+ pushers = [(row[0], row[1]) for row in txn]
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="pushers",
+ column="id",
+ iterable=(pusher_id for pusher_id, token in pushers if token is None),
+ keyvalues={},
+ )
+
+ if pushers:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "remove_stale_pushers", {"last_pusher": pushers[-1][0]}
+ )
+
+ return len(pushers)
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_stale_pushers", _delete_pushers
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update("remove_stale_pushers")
+
+ return number_deleted
+
class PusherStore(PusherWorkerStore):
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- pushkey_ts,
- lang,
- data,
- last_stream_ordering,
- profile_tag="",
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ pushkey_ts: int,
+ lang: Optional[str],
+ data: Optional[JsonDict],
+ last_stream_ordering: int,
+ profile_tag: str = "",
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
@@ -294,7 +422,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
- "data": bytearray(encode_canonical_json(data)),
+ "data": json_encoder.encode(data),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
@@ -311,20 +439,22 @@ class PusherStore(PusherWorkerStore):
# invalidate, since the user might not have had a pusher before
await self.db_pool.runInteraction(
"add_pusher",
- self._invalidate_cache_and_stream,
+ self._invalidate_cache_and_stream, # type: ignore
self.get_if_user_has_pusher,
(user_id,),
)
async def delete_pusher_by_app_id_pushkey_user_id(
- self, app_id, pushkey, user_id
+ self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
- self._invalidate_cache_and_stream(
+ self._invalidate_cache_and_stream( # type: ignore
txn, self.get_if_user_has_pusher, (user_id,)
)
- self.db_pool.simple_delete_one_txn(
+ # It is expected that there is exactly one pusher to delete, but
+ # if it isn't there (or there are multiple), delete them all.
+ self.db_pool.simple_delete_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -348,3 +478,46 @@ class PusherStore(PusherWorkerStore):
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
+
+ async def delete_all_pushers_for_user(self, user_id: str) -> None:
+ """Delete all pushers associated with an account."""
+
+ # We want to generate a row in `deleted_pushers` for each pusher we're
+ # deleting, so we fetch the list now so we can generate the appropriate
+ # number of stream IDs.
+ #
+ # Note: technically there could be a race here between adding/deleting
+ # pushers, but a) the worst case is that we don't stop a pusher until the
+ # next restart, and b) this is only called when we're deactivating an
+ # account.
+ pushers = list(await self.get_pushers_by_user_id(user_id))
+
+ def delete_pushers_txn(txn, stream_ids):
+ self._invalidate_cache_and_stream( # type: ignore
+ txn, self.get_if_user_has_pusher, (user_id,)
+ )
+
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="pushers",
+ keyvalues={"user_name": user_id},
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="deleted_pushers",
+ values=[
+ {
+ "stream_id": stream_id,
+ "app_id": pusher.app_id,
+ "pushkey": pusher.pushkey,
+ "user_id": user_id,
+ }
+ for stream_id, pusher in zip(stream_ids, pushers)
+ ],
+ )
+
+ async with self._pushers_id_gen.get_next_mult(len(pushers)) as stream_ids:
+ await self.db_pool.runInteraction(
+ "delete_all_pushers_for_user", delete_pushers_txn, stream_ids
+ )
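Both new background updates in this file follow the same batching pattern: select a page of candidates ordered by a progress key, act on them, persist the last key as progress, and signal completion once a batch comes back short. A minimal, self-contained sketch of that loop (plain Python, no Synapse APIs; fetch_batch, delete_rows and save_progress are stand-ins for the real store methods):

    from typing import Callable, List

    def run_batched_cleanup(
        fetch_batch: Callable[[str, int], List[str]],  # keys > last_key, ordered, limited
        delete_rows: Callable[[List[str]], None],      # delete the rows for those keys
        save_progress: Callable[[str], None],          # persist the last processed key
        batch_size: int = 100,
    ) -> int:
        """Process rows in pages until a short batch signals we are done.

        Mirrors the shape of _remove_deactivated_pushers/_remove_stale_pushers:
        each iteration handles at most `batch_size` rows and records progress so
        the job can resume after a restart.
        """
        total = 0
        last_key = ""
        while True:
            batch = fetch_batch(last_key, batch_size)
            if batch:
                delete_rows(batch)
                last_key = batch[-1]
                save_progress(last_key)
                total += len(batch)
            if len(batch) < batch_size:
                # Fewer rows than requested: nothing left, the update is finished.
                return total
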
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1e7949a323..43c852c96c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,15 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
- """This is an abstract base class where subclasses must implement
- `get_max_receipt_stream_id` which can be called in the initializer.
- """
-
+class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
+ self._instance_name = hs.get_instance_name()
+
+ if isinstance(database.engine, PostgresEngine):
+ self._can_write_to_receipts = (
+ self._instance_name in hs.config.worker.writers.receipts
+ )
+
+ self._receipts_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="receipts",
+ instance_name=self._instance_name,
+ tables=[("receipts_linearized", "instance_name", "stream_id")],
+ sequence_name="receipts_sequence",
+ writers=hs.config.worker.writers.receipts,
+ )
+ else:
+ self._can_write_to_receipts = True
+
+ # We shouldn't be running in worker mode with SQLite, but it's useful
+ # to support it for unit tests.
+ #
+ # If this process is the writer then we need to use
+ # `StreamIdGenerator`; otherwise we use `SlavedIdTracker`, which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.receipts:
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+ else:
+ self._receipts_id_gen = SlavedIdTracker(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
- @abc.abstractmethod
def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream
Returns:
int
"""
- raise NotImplementedError()
+ return self._receipts_id_gen.get_current_token()
@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
@@ -130,7 +160,7 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
Args:
room_id: List of room_ids.
- to_key: Max stream id to fetch receipts upto.
+ to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
@@ -159,7 +189,7 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
Args:
room_ids: The room id.
- to_key: Max stream id to fetch receipts upto.
+ to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
@@ -178,8 +208,7 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
- """See get_linearized_receipts_for_room
- """
+ """See get_linearized_receipts_for_room"""
def f(txn):
if from_key:
@@ -274,7 +303,9 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
}
return results
- @cached(num_args=2,)
+ @cached(
+ num_args=2,
+ )
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
@@ -282,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
to a limit of the latest 100 read receipts.
Args:
- to_key: Max stream id to fetch receipts upto.
+ to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
@@ -428,19 +459,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
-class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- # We instantiate this first as the ReceiptsWorkerStore constructor
- # needs to be able to call get_max_receipt_stream_id
- self._receipts_id_gen = StreamIdGenerator(
- db_conn, "receipts_linearized", "stream_id"
+ def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+ self.get_receipts_for_user.invalidate((user_id, receipt_type))
+ self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+ self.get_last_receipt_event_id_for_user.invalidate(
+ (user_id, room_id, receipt_type)
)
+ self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+ self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == ReceiptsStream.NAME:
+ self._receipts_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.invalidate_caches_for_receipt(
+ row.room_id, row.receipt_type, row.user_id
+ )
+ self._receipts_stream_cache.entity_has_changed(row.room_id, token)
- super().__init__(database, db_conn, hs)
-
- def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_current_token()
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@@ -452,6 +489,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
+ assert self._can_write_to_receipts
+
res = self.db_pool.simple_select_one_txn(
txn,
table="events",
@@ -483,28 +522,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
return None
- txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
- txn.call_after(
- self._invalidate_get_users_with_receipts_in_room,
- room_id,
- receipt_type,
- user_id,
- )
- txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
- # FIXME: This shouldn't invalidate the whole cache
txn.call_after(
- self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
+ self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
)
txn.call_after(
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)
- txn.call_after(
- self.get_last_receipt_event_id_for_user.invalidate,
- (user_id, room_id, receipt_type),
- )
-
self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
@@ -543,6 +568,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
Automatically does conversion between linearized and graph
representations.
"""
+ assert self._can_write_to_receipts
+
if not event_ids:
return None
@@ -607,6 +634,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data
):
+ assert self._can_write_to_receipts
+
return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
@@ -620,6 +649,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data
):
+ assert self._can_write_to_receipts
+
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
@@ -653,3 +684,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"data": json_encoder.encode(data),
},
)
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ pass
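The reworked constructor above picks its ID generator from the database engine and whether this instance is one of the configured receipts writers. A toy sketch of just that decision (the function name and return values are illustrative, not Synapse APIs):

    from typing import List

    def pick_receipts_id_gen(
        instance_name: str, writers: List[str], is_postgres: bool
    ) -> str:
        """Return which kind of stream-ID tracker this process would use.

        Mirrors the branching in ReceiptsWorkerStore.__init__: Postgres supports
        multiple writers via MultiWriterIdGenerator; on SQLite only a single
        writer is supported, and non-writers merely track the stream position.
        """
        if is_postgres:
            return "MultiWriterIdGenerator"
        if instance_name in writers:
            return "StreamIdGenerator"  # this process allocates new stream IDs
        return "SlavedIdTracker"        # read-only: follows the stream via replication

    # e.g. pick_receipts_id_gen("worker1", ["master"], is_postgres=False)
    # -> "SlavedIdTracker"
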
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 1624738409..c439b58db6 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,14 +16,14 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -79,7 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
- database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+ db_conn,
+ database.engine,
+ find_max_generated_user_id_localpart,
+ "user_id_seq",
+ table=None,
+ id_column=None,
)
self._account_validity_enabled = hs.config.account_validity_enabled
@@ -91,7 +101,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
if hs.config.run_background_tasks:
self._clock.call_later(
- 0.0, self._set_expiration_date_when_missing,
+ 0.0,
+ self._set_expiration_date_when_missing,
)
# Create a background job for culling expired 3PID validity tokens
@@ -116,6 +127,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"creation_ts",
"user_type",
"deactivated",
+ "shadow_banned",
],
allow_none=True,
desc="get_user_by_id",
@@ -379,7 +391,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
async def get_info_for_users(
- self, user_ids: List[str],
+ self,
+ user_ids: List[str],
):
"""Return the user info for a given set of users
@@ -463,6 +476,37 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+ async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+ """Sets whether a user is shadow-banned.
+
+ Args:
+ user: user ID of the user to update
+ shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+ """
+
+ def set_shadow_banned_txn(txn):
+ user_id = user.to_string()
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"shadow_banned": shadow_banned},
+ )
+ # In order for this to apply immediately, clear the cache for this user.
+ tokens = self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="access_tokens",
+ keyvalues={"user_id": user_id},
+ retcol="token",
+ )
+ for token in tokens:
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_access_token, (token,)
+ )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+ await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
@@ -546,6 +590,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+ async def record_user_external_id(
+ self, auth_provider: str, external_id: str, user_id: str
+ ) -> None:
+ """Record a mapping from an external user id to a mxid
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
+ await self.db_pool.simple_insert(
+ table="user_external_ids",
+ values={
+ "auth_provider": auth_provider,
+ "external_id": external_id,
+ "user_id": user_id,
+ },
+ desc="record_user_external_id",
+ )
+
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> Optional[str]:
@@ -566,6 +630,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_external_id",
)
+ async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+ """Look up external ids for the given user
+
+ Args:
+ mxid: the MXID to be looked up
+
+ Returns:
+ Tuples of (auth_provider, external_id)
+ """
+ res = await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ )
+ return [(r["auth_provider"], r["external_id"]) for r in res]
+
async def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -1029,9 +1110,50 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="del_user_pending_deactivation",
)
+ async def get_access_token_last_validated(self, token_id: int) -> int:
+ """Retrieves the time (in milliseconds) of the last validation of an access token.
+
+ Args:
+ token_id: The ID of the access token to look up.
+ Raises:
+ StoreError if the access token was not found.
+
+ Returns:
+ The last validation time.
+ """
+ result = await self.db_pool.simple_select_one_onecol(
+ "access_tokens", {"id": token_id}, "last_validated"
+ )
+
+ # If this token has not been validated (since starting to track this),
+ # return 0 instead of None.
+ return result or 0
+
+ async def update_access_token_last_validated(self, token_id: int) -> None:
+ """Updates the last time an access token was validated.
+
+ Args:
+ token_id: The ID of the access token to update.
+ Raises:
+ StoreError if there was a problem updating this.
+ """
+ now = self._clock.time_msec()
+
+ await self.db_pool.simple_update_one(
+ "access_tokens",
+ {"id": token_id},
+ {"last_validated": now},
+ desc="update_access_token_last_validated",
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
@@ -1066,6 +1188,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
+ self.db_pool.updates.register_background_index_update(
+ "user_external_ids_user_id_idx",
+ index_name="user_external_ids_user_id_idx",
+ table="user_external_ids",
+ columns=["user_id"],
+ unique=False,
+ )
+
async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
@@ -1146,7 +1276,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids
"""
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
+ txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
if id_servers:
await self.db_pool.runInteraction(
@@ -1184,6 +1314,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
@@ -1228,6 +1359,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
The token ID
"""
next_id = self._access_tokens_id_gen.get_next()
+ now = self._clock.time_msec()
await self.db_pool.simple_insert(
"access_tokens",
@@ -1238,6 +1370,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
+ "last_validated": now,
},
desc="add_access_token_to_user",
)
@@ -1411,26 +1544,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- async def record_user_external_id(
- self, auth_provider: str, external_id: str, user_id: str
- ) -> None:
- """Record a mapping from an external user id to a mxid
-
- Args:
- auth_provider: identifier for the remote auth provider
- external_id: id on that system
- user_id: complete mxid that it is mapped to
- """
- await self.db_pool.simple_insert(
- table="user_external_ids",
- values={
- "auth_provider": auth_provider,
- "external_id": external_id,
- "user_id": user_id,
- },
- desc="record_user_external_id",
- )
-
async def user_set_password_hash(
self, user_id: str, password_hash: Optional[str]
) -> None:
@@ -1502,7 +1615,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
async def user_delete_access_tokens(
self,
user_id: str,
- except_token_id: Optional[str] = None,
+ except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
@@ -1525,7 +1638,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
- values = [v for _, v in items]
+ values = [v for _, v in items] # type: List[Union[str, int]]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)
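record_user_external_id has simply moved from RegistrationStore into RegistrationWorkerStore alongside the new get_external_ids_by_user, so workers can both write and read the mapping. A throwaway sqlite3 sketch of the underlying table and the two lookup directions (schema reduced to the columns used here; names and values are examples):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE user_external_ids (
            auth_provider TEXT NOT NULL,
            external_id TEXT NOT NULL,
            user_id TEXT NOT NULL,
            UNIQUE (auth_provider, external_id)
        )
        """
    )

    # record_user_external_id: map a remote (provider, external id) pair to an mxid.
    conn.execute(
        "INSERT INTO user_external_ids (auth_provider, external_id, user_id) VALUES (?, ?, ?)",
        ("oidc", "alice@idp.example.com", "@alice:example.org"),
    )

    # get_user_by_external_id: external identity -> mxid.
    row = conn.execute(
        "SELECT user_id FROM user_external_ids WHERE auth_provider = ? AND external_id = ?",
        ("oidc", "alice@idp.example.com"),
    ).fetchone()
    assert row == ("@alice:example.org",)

    # get_external_ids_by_user: mxid -> all linked external identities
    # (the lookup the new user_external_ids_user_id_idx index speeds up).
    links = conn.execute(
        "SELECT auth_provider, external_id FROM user_external_ids WHERE user_id = ?",
        ("@alice:example.org",),
    ).fetchall()
    assert links == [("oidc", "alice@idp.example.com")]
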
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 2b781a17cc..8db6f1396a 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -15,7 +15,6 @@
# limitations under the License.
import collections
import logging
-import re
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
@@ -29,6 +28,7 @@ from synapse.storage.databases.main.search import SearchStore
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
+from synapse.util.stringutils import MXC_REGEX
logger = logging.getLogger(__name__)
@@ -83,7 +83,7 @@ class RoomWorkerStore(SQLBaseStore):
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator"),
+ retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
)
@@ -194,8 +194,7 @@ class RoomWorkerStore(SQLBaseStore):
)
async def get_room_count(self) -> int:
- """Retrieve the total number of rooms.
- """
+ """Retrieve the total number of rooms."""
def f(txn):
sql = "SELECT count(*) FROM rooms"
@@ -399,14 +398,14 @@ class RoomWorkerStore(SQLBaseStore):
# Filter room names by a string
where_statement = ""
if search_term:
- where_statement = "WHERE state.name LIKE ?"
+ where_statement = "WHERE LOWER(state.name) LIKE ?"
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term + "%"
+ search_term = "%" + search_term.lower() + "%"
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -537,7 +536,8 @@ class RoomWorkerStore(SQLBaseStore):
return rooms, room_count[0]
return await self.db_pool.runInteraction(
- "get_rooms_paginate", _get_rooms_paginate_txn,
+ "get_rooms_paginate",
+ _get_rooms_paginate_txn,
)
@cached(max_entries=10000)
@@ -603,7 +603,8 @@ class RoomWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
ret = await self.db_pool.runInteraction(
- "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+ "get_retention_policy_for_room",
+ get_retention_policy_for_room_txn,
)
# If we don't know this room ID, ret will be None, in this case return the default
@@ -685,8 +686,6 @@ class RoomWorkerStore(SQLBaseStore):
The local and remote media as lists of tuples where the key is
the hostname and the value is the media ID.
"""
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
sql = """
SELECT stream_ordering, json FROM events
JOIN event_json USING (room_id, event_id)
@@ -713,7 +712,7 @@ class RoomWorkerStore(SQLBaseStore):
for url in (content_url, thumbnail_url):
if not url:
continue
- matches = mxc_re.match(url)
+ matches = MXC_REGEX.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
@@ -734,7 +733,10 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
async def quarantine_media_by_id(
- self, server_name: str, media_id: str, quarantined_by: str,
+ self,
+ server_name: str,
+ media_id: str,
+ quarantined_by: str,
) -> int:
"""quarantines a single local or remote media id
@@ -988,7 +990,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self.config = hs.config
self.db_pool.updates.register_background_update_handler(
- "insert_room_retention", self._background_insert_retention,
+ "insert_room_retention",
+ self._background_insert_retention,
)
self.db_pool.updates.register_background_update_handler(
@@ -1060,7 +1063,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return False
end = await self.db_pool.runInteraction(
- "insert_room_retention", _background_insert_retention_txn,
+ "insert_room_retention",
+ _background_insert_retention_txn,
)
if end:
@@ -1071,7 +1075,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
async def _background_add_rooms_room_version_column(
self, progress: dict, batch_size: int
):
- """Background update to go and add room version inforamtion to `rooms`
+ """Background update to go and add room version information to `rooms`
table from `current_state_events` table.
"""
@@ -1191,6 +1195,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
# It's overridden by RoomStore for the synapse master.
raise NotImplementedError()
+ async def has_auth_chain_index(self, room_id: str) -> bool:
+ """Check if the room has (or can have) a chain cover index.
+
+ Defaults to True if we don't have an entry in `rooms` table nor any
+ events for the room.
+ """
+
+ has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcol="has_auth_chain_index",
+ desc="has_auth_chain_index",
+ allow_none=True,
+ )
+
+ if has_auth_chain_index:
+ return True
+
+ # It's possible that we already have events for the room in our DB
+ # without a corresponding room entry. If we do then we don't want to
+ # mark the room as having an auth chain cover index.
+ max_ordering = await self.db_pool.simple_select_one_onecol(
+ table="events",
+ keyvalues={"room_id": room_id},
+ retcol="MAX(stream_ordering)",
+ allow_none=True,
+ desc="has_auth_chain_index",
+ )
+
+ return max_ordering is None
+
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1204,12 +1239,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Called when we join a room over federation, and overwrites any room version
currently in the table.
"""
+ # It's possible that we already have events for the room in our DB
+ # without a corresponding room entry. If we do then we don't want to
+ # mark the room as having an auth chain cover index.
+ has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
await self.db_pool.simple_upsert(
desc="upsert_room_on_join",
table="rooms",
keyvalues={"room_id": room_id},
values={"room_version": room_version.identifier},
- insertion_values={"is_public": False, "creator": ""},
+ insertion_values={
+ "is_public": False,
+ "creator": "",
+ "has_auth_chain_index": has_auth_chain_index,
+ },
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
lock=False,
@@ -1244,6 +1288,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"creator": room_creator_user_id,
"is_public": is_public,
"room_version": room_version.identifier,
+ "has_auth_chain_index": True,
},
)
if is_public:
@@ -1272,6 +1317,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
"""
+ # It's possible that we already have events for the room in our DB
+ # without a corresponding room entry. If we do then we don't want to
+ # mark the room as having an auth chain cover index.
+ has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
await self.db_pool.simple_upsert(
desc="maybe_store_room_on_outlier_membership",
table="rooms",
@@ -1281,6 +1331,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"room_version": room_version.identifier,
"is_public": False,
"creator": "",
+ "has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
@@ -1568,7 +1619,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
LIMIT ?
OFFSET ?
""".format(
- where_clause=where_clause, order=order,
+ where_clause=where_clause,
+ order=order,
)
args += [limit, start]
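The new has_auth_chain_index helper treats an unknown room as indexable only when we also hold no events for it; upsert_room_on_join and maybe_store_room_on_outlier_membership then persist that answer, so the flag never ends up True for a room whose events predate the index. A compact sketch of that decision, with the two store lookups replaced by plain parameters (names are illustrative):

    from typing import Optional

    def room_can_use_chain_index(
        has_auth_chain_index: Optional[bool], max_stream_ordering: Optional[int]
    ) -> bool:
        """Mirror has_auth_chain_index(): True if the rooms row says so, otherwise
        True only when we hold no events for the room yet."""
        if has_auth_chain_index:
            return True
        # Events already stored without an index -> the cover index would be incomplete.
        return max_stream_ordering is None

    # A brand-new room (no rooms row, no events) may use the index:
    assert room_can_use_chain_index(None, None) is True
    # A room we already have events for, but no flag, must not:
    assert room_can_use_chain_index(None, 42) is False
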
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..a9216ca9ae 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -70,10 +70,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
- self._count_known_servers, 60 * 1000,
+ self._count_known_servers,
+ 60 * 1000,
)
self.hs.get_clock().call_later(
- 1000, self._count_known_servers,
+ 1000,
+ self._count_known_servers,
)
LaterGauge(
"synapse_federation_known_servers",
@@ -174,7 +176,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000)
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
- """ Get the details of a room roughly suitable for use by the room
+ """Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
@@ -488,8 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_users_who_share_room_with_user(
self, user_id: str, cache_context: _CacheContext
) -> Set[str]:
- """Returns the set of users who share a room with `user_id`
- """
+ """Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate
)
@@ -618,7 +619,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
+ cached_method_name="_get_joined_profile_from_event_id",
+ list_name="event_ids",
)
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
@@ -802,8 +804,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
- """Get user_id and membership of a set of event IDs.
- """
+ """Get user_id and membership of a set of event IDs."""
return await self.db_pool.simple_select_many_batch(
table="room_memberships",
@@ -873,8 +874,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
- INSERT_CLUMP_SIZE = 1000
-
def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +914,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
"""
- for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
- clump = to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(to_update_sql, clump)
+ txn.execute_batch(to_update_sql, to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index ad875c733a..3907189e29 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,5 +23,6 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute(
- "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
+ "UPDATE remote_media_cache SET last_access_ts = ?",
+ (int(time.time() * 1000),),
)
diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
new file mode 100644
index 0000000000..8f5e65aa71
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5825, 'user_external_ids_user_id_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
new file mode 100644
index 0000000000..1a101cd5eb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The last time this access token was "validated" (i.e. logged in or succeeded
+-- at user-interactive authentication).
+ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/27local_invites.sql b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
new file mode 100644
index 0000000000..44b2a0572f
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is unused since Synapse v1.17.0.
+DROP TABLE local_invites;
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
new file mode 100644
index 0000000000..de57645019
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE access_tokens DROP COLUMN last_used;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
new file mode 100644
index 0000000000..ee0e3521bf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ -- Dropping last_used column from access_tokens table.
+
+CREATE TABLE access_tokens2 (
+ id BIGINT PRIMARY KEY,
+ user_id TEXT NOT NULL,
+ device_id TEXT,
+ token TEXT NOT NULL,
+ valid_until_ms BIGINT,
+ puppets_user_id TEXT,
+ last_validated BIGINT,
+ UNIQUE(token)
+);
+
+INSERT INTO access_tokens2(id, user_id, device_id, token)
+ SELECT id, user_id, device_id, token FROM access_tokens;
+
+DROP TABLE access_tokens;
+ALTER TABLE access_tokens2 RENAME TO access_tokens;
+
+CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id);
+
+
+-- Re-adding foreign key reference in event_txn_id table
+
+CREATE TABLE event_txn_id2 (
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ token_id BIGINT NOT NULL,
+ txn_id TEXT NOT NULL,
+ inserted_ts BIGINT NOT NULL,
+ FOREIGN KEY (event_id)
+ REFERENCES events (event_id) ON DELETE CASCADE,
+ FOREIGN KEY (token_id)
+ REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts)
+ SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id;
+
+DROP TABLE event_txn_id;
+ALTER TABLE event_txn_id2 RENAME TO event_txn_id;
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
\ No newline at end of file
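The SQLite variant of this delta has to recreate the table rather than issue a DROP COLUMN, since SQLite only gained ALTER TABLE ... DROP COLUMN in 3.35 and older versions must still be supported. The same create-copy-drop-rename dance, as a minimal throwaway sqlite3 snippet (toy table, not the real access_tokens schema):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE t (id INTEGER PRIMARY KEY, keep TEXT, drop_me TEXT);
        INSERT INTO t VALUES (1, 'kept', 'obsolete');

        -- 1. create the replacement table without the unwanted column
        CREATE TABLE t2 (id INTEGER PRIMARY KEY, keep TEXT);
        -- 2. copy across the columns we are keeping
        INSERT INTO t2 (id, keep) SELECT id, keep FROM t;
        -- 3. drop the old table and move the new one into place
        DROP TABLE t;
        ALTER TABLE t2 RENAME TO t;
        """
    )
    assert conn.execute("SELECT id, keep FROM t").fetchall() == [(1, "kept")]
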
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
new file mode 100644
index 0000000000..9e8f35c1d2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -0,0 +1,82 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration denormalises the account_data table into an ignored users table.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage._base import db_to_json
+from synapse.storage.engines import BaseDatabaseEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+ pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+ logger.info("Creating ignored_users table")
+ execute_statements_from_stream(cur, StringIO(_create_commands))
+
+ # We now upgrade existing data, if any. We don't do this in `run_upgrade` as
+ # we a) want to run these before adding constraints and b) `run_upgrade` is
+ # not run on empty databases.
+ insert_sql = """
+ INSERT INTO ignored_users (ignorer_user_id, ignored_user_id) VALUES (?, ?)
+ """
+
+ logger.info("Converting existing ignore lists")
+ cur.execute(
+ "SELECT user_id, content FROM account_data WHERE account_data_type = 'm.ignored_user_list'"
+ )
+ for user_id, content_json in cur.fetchall():
+ content = db_to_json(content_json)
+
+ # The content should be in the form of a dictionary with a key
+ # "ignored_users" pointing to a dictionary with keys of ignored users.
+ #
+ # { "ignored_users": { "@someone:example.org": {} } }
+ ignored_users = content.get("ignored_users", {})
+ if isinstance(ignored_users, dict) and ignored_users:
+ cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
+
+ # Add indexes after inserting data for efficiency.
+ logger.info("Adding constraints to ignored_users table")
+ execute_statements_from_stream(cur, StringIO(_constraints_commands))
+
+
+ # there might be duplicates, so the easiest way to achieve this is to create a new
+ # table with the right data, and rename it into place
+
+_create_commands = """
+-- Users which are ignored when calculating push notifications. This data is
+-- denormalized from account data.
+CREATE TABLE IF NOT EXISTS ignored_users(
+ ignorer_user_id TEXT NOT NULL, -- The user ID of the user who is ignoring another user. (This is a local user.)
+ ignored_user_id TEXT NOT NULL -- The user ID of the user who is being ignored. (This is a local or remote user.)
+);
+"""
+
+_constraints_commands = """
+CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users (ignorer_user_id, ignored_user_id);
+
+-- Add an index on ignored_users since look-ups are done to get all ignorers of an ignored user.
+CREATE INDEX ignored_users_ignored_user_id ON ignored_users (ignored_user_id);
+"""
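A worked example of what the conversion loop in this migration does with a single account_data row: the content dict comes from an m.ignored_user_list event, and each key of its "ignored_users" object becomes one (ignorer, ignored) row (user IDs below are examples):

    import json

    # A stored m.ignored_user_list account_data row for @alice:example.org.
    user_id = "@alice:example.org"
    content_json = json.dumps(
        {"ignored_users": {"@spammer:example.com": {}, "@troll:example.net": {}}}
    )

    content = json.loads(content_json)  # db_to_json in the migration
    ignored_users = content.get("ignored_users", {})

    rows = []
    if isinstance(ignored_users, dict) and ignored_users:
        rows = [(user_id, ignored) for ignored in ignored_users]

    # These are the rows execute_batch would insert into ignored_users.
    assert sorted(rows) == [
        ("@alice:example.org", "@spammer:example.com"),
        ("@alice:example.org", "@troll:example.net"),
    ]
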
diff --git a/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
new file mode 100644
index 0000000000..d781a92fec
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE device_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
new file mode 100644
index 0000000000..45a845a3a5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence;
+
+-- We need to take the max across both device_inbox and device_federation_outbox
+-- tables as they share the ID generator
+SELECT setval('device_inbox_sequence', (
+ SELECT GREATEST(
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox)
+ )
+));
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
new file mode 100644
index 0000000000..729196cfd5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
@@ -0,0 +1,52 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- See docs/auth_chain_difference_algorithm.md
+
+CREATE TABLE event_auth_chains (
+ event_id TEXT PRIMARY KEY,
+ chain_id BIGINT NOT NULL,
+ sequence_number BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
+
+
+CREATE TABLE event_auth_chain_links (
+ origin_chain_id BIGINT NOT NULL,
+ origin_sequence_number BIGINT NOT NULL,
+
+ target_chain_id BIGINT NOT NULL,
+ target_sequence_number BIGINT NOT NULL
+);
+
+
+CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
+
+
+-- Events that we have persisted but not calculated auth chains for,
+-- e.g. out of band memberships (where we don't have the auth chain)
+CREATE TABLE event_auth_chain_to_calculate (
+ event_id TEXT PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL
+);
+
+CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
+
+
+-- Whether we've calculated the above index for a room.
+ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;
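The two tables above give every event a (chain_id, sequence_number) position plus links between chains; docs/auth_chain_difference_algorithm.md describes the actual algorithm. Purely as an illustration of how the tables combine, a membership test could be sketched roughly as below. The shape of `links` and the assumption that the links are transitively closed are illustrative, not Synapse's implementation:

    # Rough illustration only: is `other` in the auth chain of `event`?
    from typing import Dict, List, Tuple

    def is_in_auth_chain(
        event: Tuple[int, int],            # (chain_id, sequence_number)
        other: Tuple[int, int],
        links: Dict[int, List[Tuple[int, int, int]]],  # chain_id -> [(origin_seq, target_chain, target_seq)]
    ) -> bool:
        chain_id, seq = event
        other_chain, other_seq = other

        # Same chain: everything earlier in the chain is an ancestor.
        if chain_id == other_chain:
            return other_seq <= seq

        # Otherwise look for a link from a point at or before `event` on its
        # chain to a point at or after `other` on the other chain.
        for origin_seq, target_chain, target_seq in links.get(chain_id, []):
            if origin_seq <= seq and target_chain == other_chain and target_seq >= other_seq:
                return True
        return False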
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
new file mode 100644
index 0000000000..e8a035bbeb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
new file mode 100644
index 0000000000..64ab696cfe
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS account_data_max_stream_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
new file mode 100644
index 0000000000..fb71b360a0
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS cache_invalidation_stream;
diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
new file mode 100644
index 0000000000..fe3dca71dd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5906, 'chain_cover', '{}', 'rejected_events_metadata');
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
new file mode 100644
index 0000000000..46abf8d562
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
+ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
+ALTER TABLE account_data ADD COLUMN instance_name TEXT;
+
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
new file mode 100644
index 0000000000..4a6e6c74f5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
@@ -0,0 +1,32 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
+
+-- We need to take the max across all the account_data tables as they share the
+-- ID generator
+SELECT setval('account_data_sequence', (
+ SELECT GREATEST(
+ (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
+ )
+));
+
+CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
+
+SELECT setval('receipts_sequence', (
+ SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
+));
diff --git a/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
new file mode 100644
index 0000000000..9f2b5ebc5a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We incorrectly populated these, so we delete them and let the
+-- MultiWriterIdGenerator repopulate them.
+DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data';
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_pushers_for_deactivated_accounts.sql b/synapse/storage/databases/main/schema/delta/59/08delete_pushers_for_deactivated_accounts.sql
new file mode 100644
index 0000000000..0ec6764150
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_pushers_for_deactivated_accounts.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We may not have deleted all pushers for deactivated accounts, so we set up a
+-- background job to delete them.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5908, 'remove_deactivated_pushers', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
new file mode 100644
index 0000000000..85196db288
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Delete all pushers associated with deleted devices. This is to clear up after
+-- a bug where they weren't correctly deleted when using workers.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5908, 'remove_stale_pushers', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/59/09rejected_events_metadata.sql b/synapse/storage/databases/main/schema/delta/59/09rejected_events_metadata.sql
new file mode 100644
index 0000000000..cc9b267c7d
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/09rejected_events_metadata.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This originally was in 58/, but landed after 59/ was created, and so some
+-- servers running the develop branch didn't run this delta. Running it again should be
+-- safe.
+--
+-- We first delete any in progress `rejected_events_metadata` background update,
+-- to ensure that we don't conflict when trying to insert the new one. (We could
+-- alternatively do an ON CONFLICT DO NOTHING, but that syntax isn't supported
+-- by older SQLite versions. Plus, this should be a rare case).
+DELETE FROM background_updates WHERE update_name = 'rejected_events_metadata';
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5828, 'rejected_events_metadata', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql
new file mode 100644
index 0000000000..87cb1f3cfd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5910, 'purged_chain_cover', '{}');
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index e28ec3fa45..301c566a70 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -67,11 +67,6 @@ CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT
CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
/* event_search(event_id,room_id,sender,"key",value) */;
-CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
-CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
-CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
-CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
-CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) );
CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) );
CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) );
@@ -149,11 +144,6 @@ CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_las
CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') );
CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value )
/* user_directory_search(user_id,value) */;
-CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value');
-CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
-CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
-CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
-CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL );
CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT );
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
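Switching from `txn.executemany` to `txn.execute_batch` matters most on Postgres, where psycopg2's plain `executemany()` issues one network round trip per parameter tuple while `psycopg2.extras.execute_batch()` packs many statements into each round trip (Synapse's wrapper presumably falls back to `executemany` on SQLite). A standalone sketch of the underlying helper, with a placeholder table and DSN:

    # Minimal sketch of the batching helper outside Synapse.
    import psycopg2
    from psycopg2.extras import execute_batch

    conn = psycopg2.connect("dbname=example")  # hypothetical DSN
    rows = [("$event%d:example.org" % i, "!room:example.org") for i in range(1000)]
    with conn, conn.cursor() as cur:
        execute_batch(
            cur,
            "INSERT INTO event_search_example (event_id, room_id) VALUES (%s, %s)",
            rows,
            page_size=100,  # statements grouped per round trip
        )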
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
async def search_rooms(
self,
- room_ids: List[str],
+ room_ids: Collection[str],
search_term: str,
keys: List[str],
limit,
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 3c1e33819b..a7f371732f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -52,8 +52,7 @@ class _GetStateGroupDelta(
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
- """The parts of StateGroupStore that can be called from workers.
- """
+ """The parts of StateGroupStore that can be called from workers."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -276,8 +275,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
num_args=1,
)
async def _get_state_group_for_events(self, event_ids):
- """Returns mapping event_id -> state_group
- """
+ """Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
@@ -338,7 +336,8 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
columns=["state_group"],
)
self.db_pool.updates.register_background_update_handler(
- self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
+ self.DELETE_CURRENT_STATE_UPDATE_NAME,
+ self._background_remove_left_rooms,
)
async def _background_remove_left_rooms(self, progress, batch_size):
@@ -487,7 +486,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
- """ Keeps track of the state at a given event.
+ """Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 356623fc6e..0dbb501f16 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -64,7 +64,7 @@ class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
# N results.
- # We arbitarily limit to 100 stream_id entries to ensure we don't
+ # We arbitrarily limit to 100 stream_id entries to ensure we don't
# select toooo many.
sql = """
SELECT stream_id, count(*)
@@ -81,7 +81,7 @@ class StateDeltasStore(SQLBaseStore):
for stream_id, count in txn:
total += count
if total > 100:
- # We arbitarily limit to 100 entries to ensure we don't
+ # We arbitrarily limit to 100 entries to ensure we don't
# select toooo many.
logger.debug(
"Clipping current_state_delta_stream rows to stream_id %i",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 62b7551517..38adecc78a 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
# limitations under the License.
import logging
-from collections import Counter
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
+from typing_extensions import Counter
+
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
@@ -320,7 +321,9 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+ async def get_earliest_token_for_stats(
+ self, stats_type: str, id: str
+ ) -> Optional[int]:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -340,7 +343,7 @@ class StatsStore(StateDeltasStore):
)
async def bulk_update_stats_delta(
- self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+ self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
@@ -666,7 +669,7 @@ class StatsStore(StateDeltasStore):
async def get_changes_room_total_events_and_bytes(
self, min_pos: int, max_pos: int
- ) -> Dict[str, Dict[str, int]]:
+ ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
@@ -684,18 +687,19 @@ class StatsStore(StateDeltasStore):
max_pos,
)
- def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+ def get_changes_room_total_events_and_bytes_txn(
+ self, txn, low_pos: int, high_pos: int
+ ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Gets the total_events and total_event_bytes counts for rooms and
senders, in a range of stream_orderings (including backfilled events).
Args:
txn
- low_pos (int): Low stream ordering
- high_pos (int): High stream ordering
+ low_pos: Low stream ordering
+ high_pos: High stream ordering
Returns:
- tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
- room and user deltas for total_events/total_event_bytes in the
+ The room and user deltas for total_events/total_event_bytes in the
format of `stats_id` -> fields
"""
@@ -998,7 +1002,9 @@ class StatsStore(StateDeltasStore):
ORDER BY {order_by_column} {order}
LIMIT ? OFFSET ?
""".format(
- sql_base=sql_base, order_by_column=order_by_column, order=order,
+ sql_base=sql_base,
+ order_by_column=order_by_column,
+ order=order,
)
args += [limit, start]
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e3b9ff5ca6..91f8abb67d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -565,7 +565,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
AND e.stream_ordering > ? AND e.stream_ordering <= ?
ORDER BY e.stream_ordering ASC
"""
- txn.execute(sql, (user_id, min_from_id, max_to_id,))
+ txn.execute(
+ sql,
+ (
+ user_id,
+ min_from_id,
+ max_to_id,
+ ),
+ )
rows = [
_EventDictReturn(event_id, None, stream_ordering)
@@ -695,7 +702,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
return "t%d-%d" % (topo, token)
def get_stream_id_for_event_txn(
- self, txn: LoggingTransaction, event_id: str, allow_none=False,
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
txn=txn,
@@ -706,8 +716,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
)
async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
- """Get the persisted position for an event
- """
+ """Get the persisted position for an event"""
row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
@@ -897,19 +906,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
) -> Tuple[int, List[EventBase]]:
"""Get all new events
- Returns all events with from_id < stream_ordering <= current_id.
+ Returns all events with from_id < stream_ordering <= current_id.
- Args:
- from_id: the stream_ordering of the last event we processed
- current_id: the stream_ordering of the most recently processed event
- limit: the maximum number of events to return
+ Args:
+ from_id: the stream_ordering of the last event we processed
+ current_id: the stream_ordering of the most recently processed event
+ limit: the maximum number of events to return
- Returns:
- A tuple of (next_id, events), where `next_id` is the next value to
- pass as `from_id` (it will either be the stream_ordering of the
- last returned event, or, if fewer than `limit` events were found,
- the `current_id`).
- """
+ Returns:
+ A tuple of (next_id, events), where `next_id` is the next value to
+ pass as `from_id` (it will either be the stream_ordering of the
+ last returned event, or, if fewer than `limit` events were found,
+ the `current_id`).
+ """
def get_all_new_events_stream_txn(txn):
sql = (
@@ -1238,8 +1247,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
@cached()
async def get_id_for_instance(self, instance_name: str) -> int:
- """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
- """
+ """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
def _get_id_for_instance_txn(txn):
instance_id = self.db_pool.simple_select_one_onecol_txn(
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 9f120d3cb6..50067eabfc 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
return {row["tag"]: db_to_json(row["content"]) for row in rows}
-
-class TagsStore(TagsWorkerStore):
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
@@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
+ assert self._can_write_to_account_data
+
content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
@@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
+ assert self._can_write_to_account_data
def remove_tag_txn(txn, next_id):
sql = (
@@ -250,21 +251,12 @@ class TagsStore(TagsWorkerStore):
room_id: The ID of the room.
            next_id: The revision to advance to.
"""
+ assert self._can_write_to_account_data
txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id
)
- # Note: This is only here for backwards compat to allow admins to
- # roll back to a previous Synapse version. Next time we update the
- # database version we can remove this table.
- update_max_id_sql = (
- "UPDATE account_data_max_stream_id"
- " SET stream_id = ?"
- " WHERE stream_id < ?"
- )
- txn.execute(update_max_id_sql, (next_id, next_id))
-
update_sql = (
"UPDATE room_tags_revisions"
" SET stream_id = ?"
@@ -288,3 +280,7 @@ class TagsStore(TagsWorkerStore):
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass
+
+
+class TagsStore(TagsWorkerStore):
+ pass
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 59207cadd4..b7072f1f5e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -22,7 +22,6 @@ from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
@@ -64,8 +63,7 @@ class TransactionWorkerStore(SQLBaseStore):
class TransactionStore(TransactionWorkerStore):
- """A collection of queries for handling PDUs.
- """
+ """A collection of queries for handling PDUs."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -198,7 +196,7 @@ class TransactionStore(TransactionWorkerStore):
retry_interval: int,
) -> None:
"""Sets the current retry timings for a given destination.
- Both timings should be zero if retrying is no longer occuring.
+ Both timings should be zero if retrying is no longer occurring.
Args:
destination
@@ -299,7 +297,10 @@ class TransactionStore(TransactionWorkerStore):
)
async def store_destination_rooms_entries(
- self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+ self,
+ destinations: Iterable[str],
+ room_id: str,
+ stream_ordering: int,
) -> None:
"""
Updates or creates `destination_rooms` entries in batch for a single event.
@@ -310,49 +311,23 @@ class TransactionStore(TransactionWorkerStore):
stream_ordering: the stream_ordering of the event
"""
- return await self.db_pool.runInteraction(
- "store_destination_rooms_entries",
- self._store_destination_rooms_entries_txn,
- destinations,
- room_id,
- stream_ordering,
+ await self.db_pool.simple_upsert_many(
+ table="destinations",
+ key_names=("destination",),
+ key_values=[(d,) for d in destinations],
+ value_names=[],
+ value_values=[],
+ desc="store_destination_rooms_entries_dests",
)
- def _store_destination_rooms_entries_txn(
- self,
- txn: LoggingTransaction,
- destinations: Iterable[str],
- room_id: str,
- stream_ordering: int,
- ) -> None:
-
- # ensure we have a `destinations` row for this destination, as there is
- # a foreign key constraint.
- if isinstance(self.database_engine, PostgresEngine):
- q = """
- INSERT INTO destinations (destination)
- VALUES (?)
- ON CONFLICT DO NOTHING;
- """
- elif isinstance(self.database_engine, Sqlite3Engine):
- q = """
- INSERT OR IGNORE INTO destinations (destination)
- VALUES (?);
- """
- else:
- raise RuntimeError("Unknown database engine")
-
- txn.execute_batch(q, ((destination,) for destination in destinations))
-
rows = [(destination, room_id) for destination in destinations]
-
- self.db_pool.simple_upsert_many_txn(
- txn,
- "destination_rooms",
- ["destination", "room_id"],
- rows,
- ["stream_ordering"],
- [(stream_ordering,)] * len(rows),
+ await self.db_pool.simple_upsert_many(
+ table="destination_rooms",
+ key_names=("destination", "room_id"),
+ key_values=rows,
+ value_names=["stream_ordering"],
+ value_values=[(stream_ordering,)] * len(rows),
+ desc="store_destination_rooms_entries_rooms",
)
async def get_destination_last_successful_stream_ordering(
@@ -394,7 +369,9 @@ class TransactionStore(TransactionWorkerStore):
)
async def get_catch_up_room_event_ids(
- self, destination: str, last_successful_stream_ordering: int,
+ self,
+ destination: str,
+ last_successful_stream_ordering: int,
) -> List[str]:
"""
Returns at most 50 event IDs and their corresponding stream_orderings
@@ -418,7 +395,9 @@ class TransactionStore(TransactionWorkerStore):
@staticmethod
def _get_catch_up_room_event_ids_txn(
- txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
+ txn: LoggingTransaction,
+ destination: str,
+ last_successful_stream_ordering: int,
) -> List[str]:
q = """
SELECT event_id FROM destination_rooms
@@ -429,7 +408,8 @@ class TransactionStore(TransactionWorkerStore):
LIMIT 50
"""
txn.execute(
- q, (destination, last_successful_stream_ordering),
+ q,
+ (destination, last_successful_stream_ordering),
)
event_ids = [row[0] for row in txn]
return event_ids
@@ -464,19 +444,17 @@ class TransactionStore(TransactionWorkerStore):
txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
) -> List[str]:
q = """
- SELECT destination FROM destinations
- WHERE destination IN (
- SELECT destination FROM destination_rooms
- WHERE destination_rooms.stream_ordering >
- destinations.last_successful_stream_ordering
- )
- AND destination > ?
- AND (
- retry_last_ts IS NULL OR
- retry_last_ts + retry_interval < ?
- )
- ORDER BY destination
- LIMIT 25
+ SELECT DISTINCT destination FROM destinations
+ INNER JOIN destination_rooms USING (destination)
+ WHERE
+ stream_ordering > last_successful_stream_ordering
+ AND destination > ?
+ AND (
+ retry_last_ts IS NULL OR
+ retry_last_ts + retry_interval < ?
+ )
+ ORDER BY destination
+ LIMIT 25
"""
txn.execute(
q,
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 79b7ece330..5473ec1485 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -44,7 +44,11 @@ class UIAuthWorkerStore(SQLBaseStore):
"""
async def create_ui_auth_session(
- self, clientdict: JsonDict, uri: str, method: str, description: str,
+ self,
+ clientdict: JsonDict,
+ uri: str,
+ method: str,
+ description: str,
) -> UIAuthSessionData:
"""
Creates a new user interactive authentication session.
@@ -123,7 +127,10 @@ class UIAuthWorkerStore(SQLBaseStore):
return UIAuthSessionData(session_id, **result)
async def mark_ui_auth_stage_complete(
- self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
+ self,
+ session_id: str,
+ stage_type: str,
+ result: Union[str, bool, JsonDict],
):
"""
Mark a session stage as completed.
@@ -261,10 +268,12 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
async def add_user_agent_ip_to_ui_auth_session(
- self, session_id: str, user_agent: str, ip: str,
+ self,
+ session_id: str,
+ user_agent: str,
+ ip: str,
):
- """Add the given user agent / IP to the tracking table
- """
+ """Add the given user agent / IP to the tracking table"""
await self.db_pool.simple_upsert(
table="ui_auth_sessions_ips",
keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
@@ -273,7 +282,8 @@ class UIAuthWorkerStore(SQLBaseStore):
)
async def get_user_agents_ips_to_ui_auth_session(
- self, session_id: str,
+ self,
+ session_id: str,
) -> List[Tuple[str, str]]:
"""Get the given user agents / IPs used during the ui auth process
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 452d58bd19..1026f321e5 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -17,7 +17,7 @@ import logging
import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
@@ -336,8 +336,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
async def is_room_world_readable_or_publicly_joinable(self, room_id):
- """Check if the room is either world_readable or publically joinable
- """
+ """Check if the room is either world_readable or publically joinable"""
# Create a state filter that only queries join and history state event
types_to_filter = (
@@ -360,7 +359,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
- if hist_vis_ev.content.get("history_visibility") == "world_readable":
+ if (
+ hist_vis_ev.content.get("history_visibility")
+ == HistoryVisibility.WORLD_READABLE
+ ):
return True
return False
@@ -393,9 +395,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute(
@@ -415,9 +417,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
)
"""
txn.execute(
@@ -432,9 +434,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
elif new_entry is False:
sql = """
UPDATE user_directory_search
- SET vector = setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ SET vector = setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
txn.execute(
@@ -495,8 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def add_users_in_public_rooms(
self, room_id: str, user_ids: Iterable[str]
) -> None:
- """Insert entries into the users_who_share_private_rooms table. The first
- user should be a local user.
+ """Insert entries into the users_in_public_rooms table.
Args:
room_id
@@ -513,8 +514,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
async def delete_all_from_user_dir(self) -> None:
- """Delete the entire user directory
- """
+ """Delete the entire user directory"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
@@ -537,7 +537,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+ async def update_user_directory_stream_pos(self, stream_id: int) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
@@ -669,7 +669,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- @cached()
async def get_shared_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
@@ -711,7 +710,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {row["room_id"] for row in rows}
- async def get_user_directory_stream_pos(self) -> int:
+ async def get_user_directory_stream_pos(self) -> Optional[int]:
+ """
+ Get the stream ID of the user directory stream.
+
+ Returns:
+ The stream token or None if the initial background update hasn't happened yet.
+ """
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
@@ -781,7 +786,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
INNER JOIN user_directory AS d USING (user_id)
WHERE
%(where_clause)s
- AND vector @@ to_tsquery('english', ?)
+ AND vector @@ to_tsquery('simple', ?)
ORDER BY
(CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
* (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
@@ -790,13 +795,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
3 * ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
- to_tsquery('english', ?),
+ to_tsquery('simple', ?),
8
)
+ ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
- to_tsquery('english', ?),
+ to_tsquery('simple', ?),
8
)
)
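The switch from the 'english' to the 'simple' text-search configuration applies to both indexing and querying: 'english' stems tokens and drops stop words, which mangles display names and localparts rather than helping, while 'simple' only lowercases them. A standalone illustration (placeholder DSN; output shown in comments):

    # Why 'simple' suits user-directory names better than 'english'.
    import psycopg2

    conn = psycopg2.connect("dbname=example")  # hypothetical DSN
    with conn, conn.cursor() as cur:
        cur.execute("SELECT to_tsvector('english', 'The Running Man')")
        print(cur.fetchone()[0])  # 'man':3 'run':2        -- stemmed, 'The' dropped
        cur.execute("SELECT to_tsvector('simple', 'The Running Man')")
        print(cur.fetchone()[0])  # 'man':3 'running':2 'the':1  -- tokens kept as-is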
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index acb24e33af..1fd333b707 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -27,7 +27,7 @@ MAX_STATE_DELTA_HOPS = 100
class StateGroupBackgroundUpdateStore(SQLBaseStore):
- """Defines functions related to state groups needed to run the state backgroud
+ """Defines functions related to state groups needed to run the state background
updates.
"""
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..97ec65f757 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -48,8 +48,7 @@ class _GetStateGroupDelta(
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
- """A data store for fetching/storing state groups.
- """
+ """A data store for fetching/storing state groups."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -89,7 +88,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
50000,
)
self._state_group_members_cache = DictionaryCache(
- "*stateGroupMembersCache*", 500000,
+ "*stateGroupMembersCache*",
+ 500000,
)
def get_max_state_group_txn(txn: Cursor):
@@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return txn.fetchone()[0]
self._state_group_seq_gen = build_sequence_generator(
- self.database_engine, get_max_state_group_txn, "state_group_id_seq"
- )
- self._state_group_seq_gen.check_consistency(
- db_conn, table="state_groups", id_column="id"
+ db_conn,
+ self.database_engine,
+ get_max_state_group_txn,
+ "state_group_id_seq",
+ table="state_groups",
+ id_column="id",
)
@cached(max_entries=10000, iterable=True)
@@ -181,12 +183,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
- is_all, known_absent, state_dict_ids = cache.get(group)
+ cache_entry = cache.get(group)
+ state_dict_ids = cache_entry.value
- if is_all or state_filter.is_full():
+ if cache_entry.full or state_filter.is_full():
# Either we have everything or want everything, either way
# `is_all` tells us whether we've gotten everything.
- return state_filter.filter_state(state_dict_ids), is_all
+ return state_filter.filter_state(state_dict_ids), cache_entry.full
# tracks whether any of our requested types are missing from the cache
missing_types = False
@@ -200,7 +203,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# There aren't any wild cards, so `concrete_types()` returns the
# complete list of event types we're wanting.
for key in state_filter.concrete_types():
- if key not in state_dict_ids and key not in known_absent:
+ if key not in state_dict_ids and key not in cache_entry.known_absent:
missing_types = True
break
@@ -565,11 +568,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
logger.info("[purge] removing redundant state groups")
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete),
)
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete),
)
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 035f9ea6e9..d15ccfacde 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import platform
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
@@ -28,11 +27,8 @@ def create_engine(database_config) -> BaseDatabaseEngine:
return Sqlite3Engine(sqlite3, database_config)
if name == "psycopg2":
- # pypy requires psycopg2cffi rather than psycopg2
- if platform.python_implementation() == "PyPy":
- import psycopg2cffi as psycopg2 # type: ignore
- else:
- import psycopg2 # type: ignore
+ # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
+ import psycopg2 # type: ignore
return PostgresEngine(psycopg2, database_config)
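The removed branch imported `psycopg2cffi` directly on PyPy; the new comment relies on the `psycopg2cffi-compat` shim instead. Assuming that shim simply performs psycopg2cffi's documented compat registration, it is roughly equivalent to:

    # Roughly what the psycopg2cffi-compat shim does on PyPy (an assumption,
    # shown only to explain why the explicit PyPy branch is no longer needed).
    from psycopg2cffi import compat

    compat.register()   # makes `import psycopg2` resolve to psycopg2cffi

    import psycopg2     # now imports cleanly on PyPy
    print(psycopg2.apilevel)  # "2.0" under either implementation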
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index d6d632dc10..cca839c70f 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -94,14 +94,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def server_version(self) -> str:
- """Gets a string giving the server version. For example: '3.22.0'
- """
+ """Gets a string giving the server version. For example: '3.22.0'"""
...
@abc.abstractmethod
def in_transaction(self, conn: Connection) -> bool:
- """Whether the connection is currently in a transaction.
- """
+ """Whether the connection is currently in a transaction."""
...
@abc.abstractmethod
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 7719ac32f7..80a3558aec 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -138,8 +138,7 @@ class PostgresEngine(BaseDatabaseEngine):
@property
def supports_using_any_list(self):
- """Do we support using `a = ANY(?)` and passing a list
- """
+ """Do we support using `a = ANY(?)` and passing a list"""
return True
def is_deadlock(self, error):
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 5db0f0b520..b87e7798da 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import platform
import struct
import threading
import typing
@@ -28,7 +29,15 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
super().__init__(database_module, database_config)
database = database_config.get("args", {}).get("database")
- self._is_in_memory = database in (None, ":memory:",)
+ self._is_in_memory = database in (
+ None,
+ ":memory:",
+ )
+
+ if platform.python_implementation() == "PyPy":
+ # pypy's sqlite3 module doesn't handle bytearrays, convert them
+ # back to bytes.
+ database_module.register_adapter(bytearray, lambda array: bytes(array))
# The current max state_group, or None if we haven't looked
# in the DB yet.
@@ -57,8 +66,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
@property
def supports_using_any_list(self):
- """Do we support using `a = ANY(?)` and passing a list
- """
+ """Do we support using `a = ANY(?)` and passing a list"""
return False
def check_database(self, db_conn, allow_outdated_version: bool = False):
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index afd10f7bae..c03871f393 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -17,11 +17,12 @@
import logging
import attr
+from signedjson.types import VerifyKey
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
class FetchKeyResult:
- verify_key = attr.ib() # VerifyKey: the key itself
- valid_until_ts = attr.ib() # int: how long we can use this key for
+ verify_key = attr.ib(type=VerifyKey) # the key itself
+ valid_until_ts = attr.ib(type=int) # how long we can use this key for
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 70e636b0ba..3a0d6fb32e 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,14 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ Collection,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -68,6 +75,21 @@ stale_forward_extremities_counter = Histogram(
buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
)
+state_resolutions_during_persistence = Counter(
+ "synapse_storage_events_state_resolutions_during_persistence",
+ "Number of times we had to do state res to calculate new current state",
+)
+
+potential_times_prune_extremities = Counter(
+ "synapse_storage_events_potential_times_prune_extremities",
+ "Number of times we might be able to prune extremities",
+)
+
+times_pruned_extremities = Counter(
+ "synapse_storage_events_times_pruned_extremities",
+ "Number of times we were actually be able to prune extremities",
+)
+
class _EventPeristenceQueue:
"""Queues up events so that they can be persisted in bulk with only one
@@ -389,8 +411,8 @@ class EventsPersistenceStorage:
)
for room_id, ev_ctx_rm in events_by_room.items():
- latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
- room_id
+ latest_event_ids = (
+ await self.main_store.get_latest_event_ids_in_room(room_id)
)
new_latest_event_ids = await self._calculate_new_extremities(
room_id, ev_ctx_rm, latest_event_ids
@@ -454,7 +476,15 @@ class EventsPersistenceStorage:
latest_event_ids,
new_latest_event_ids,
)
- current_state, delta_ids = res
+ current_state, delta_ids, new_latest_event_ids = res
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
# If either are not None then there has been a change,
# and we need to work out the delta (or use that
@@ -573,29 +603,35 @@ class EventsPersistenceStorage:
self,
room_id: str,
events_context: List[Tuple[EventBase, EventContext]],
- old_latest_event_ids: Iterable[str],
- new_latest_event_ids: Iterable[str],
- ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+ old_latest_event_ids: Set[str],
+ new_latest_event_ids: Set[str],
+ ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
"""Calculate the current state dict after adding some new events to
a room
Args:
- room_id (str):
+ room_id:
room to which the events are being added. Used for logging etc
- events_context (list[(EventBase, EventContext)]):
+ events_context:
events and contexts which are being added to the room
- old_latest_event_ids (iterable[str]):
+ old_latest_event_ids:
the old forward extremities for the room.
- new_latest_event_ids (iterable[str]):
+            new_latest_event_ids:
the new forward extremities for the room.
Returns:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
+ Returns a tuple of two state maps and a set of new forward
+ extremities.
+
+ The first state map is the full new current state and the second
+ is the delta to the existing current state. If both are None then
+ there has been no change.
+
+ The function may prune some old entries from the set of new
+ forward extremities if it's safe to do so.
If there has been a change then we only return the delta if its
already been calculated. Conversely if we do know the delta then
@@ -672,7 +708,7 @@ class EventsPersistenceStorage:
# If they old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
- return None, None
+ return None, None, new_latest_event_ids
if len(new_state_groups) == 1 and len(old_state_groups) == 1:
# If we're going from one state group to another, lets check if
@@ -689,7 +725,7 @@ class EventsPersistenceStorage:
# the current state in memory then lets also return that,
# but it doesn't matter if we don't.
new_state = state_groups_map.get(new_state_group)
- return new_state, delta_ids
+ return new_state, delta_ids, new_latest_event_ids
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
@@ -701,7 +737,7 @@ class EventsPersistenceStorage:
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- return state_groups_map[new_state_groups.pop()], None
+ return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids
# Ok, we need to defer to the state handler to resolve our state sets.
@@ -734,7 +770,140 @@ class EventsPersistenceStorage:
state_res_store=StateResolutionStore(self.main_store),
)
- return res.state, None
+ state_resolutions_during_persistence.inc()
+
+ # If the returned state matches the state group of one of the new
+ # forward extremities then we check if we are able to prune some state
+ # extremities.
+ if res.state_group and res.state_group in new_state_groups:
+ new_latest_event_ids = await self._prune_extremities(
+ room_id,
+ new_latest_event_ids,
+ res.state_group,
+ event_id_to_state_group,
+ events_context,
+ )
+
+ return res.state, None, new_latest_event_ids
+
+ async def _prune_extremities(
+ self,
+ room_id: str,
+ new_latest_event_ids: Set[str],
+ resolved_state_group: int,
+ event_id_to_state_group: Dict[str, int],
+ events_context: List[Tuple[EventBase, EventContext]],
+ ) -> Set[str]:
+ """See if we can prune any of the extremities after calculating the
+ resolved state.
+ """
+ potential_times_prune_extremities.inc()
+
+ # We keep all the extremities that have the same state group, and
+ # see if we can drop the others.
+ new_new_extrems = {
+ e
+ for e in new_latest_event_ids
+ if event_id_to_state_group[e] == resolved_state_group
+ }
+
+ dropped_extrems = set(new_latest_event_ids) - new_new_extrems
+
+ logger.debug("Might drop extremities: %s", dropped_extrems)
+
+ # We only drop events from the extremities list if:
+ # 1. we're not currently persisting them;
+ # 2. they're not our own events (or are dummy events); and
+ # 3. they're either:
+ # 1. over N hours old and more than N events ago (we use depth to
+ # calculate); or
+ # 2. we are persisting an event from the same domain and more than
+ # M events ago.
+ #
+ # The idea is that we don't want to drop events that are "legitimate"
+ # extremities (that we would want to include as prev events), only
+ # "stuck" extremities that are e.g. due to a gap in the graph.
+ #
+ # Note that we either drop all of them or none of them. If we only drop
+ # some of the events we don't know if state res would come to the same
+ # conclusion.
+
+ for ev, _ in events_context:
+ if ev.event_id in dropped_extrems:
+ logger.debug(
+ "Not dropping extremities: %s is being persisted", ev.event_id
+ )
+ return new_latest_event_ids
+
+ dropped_events = await self.main_store.get_events(
+ dropped_extrems,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+
+ new_senders = {get_domain_from_id(e.sender) for e, _ in events_context}
+
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ current_depth = max(e.depth for e, _ in events_context)
+ for event in dropped_events.values():
+ # If the event is a local dummy event then we should check it
+ # doesn't reference any local events, as we want to reference those
+ # if we send any new events.
+ #
+ # Note we do this recursively to handle the case where a dummy event
+ # references a dummy event that only references remote events.
+ #
+ # Ideally we'd figure out a way of still being able to drop old
+ # dummy events that reference local events, but this is good enough
+ # as a first cut.
+ events_to_check = [event]
+ while events_to_check:
+ new_events = set()
+ for event_to_check in events_to_check:
+ if self.is_mine_id(event_to_check.sender):
+ if event_to_check.type != EventTypes.Dummy:
+ logger.debug("Not dropping own event")
+ return new_latest_event_ids
+ new_events.update(event_to_check.prev_event_ids())
+
+ prev_events = await self.main_store.get_events(
+ new_events,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ events_to_check = prev_events.values()
+
+ if (
+ event.origin_server_ts < one_day_ago
+ and event.depth < current_depth - 100
+ ):
+ continue
+
+ # We can be less conservative about dropping extremities from the
+ # same domain, though we do want to wait a little bit (otherwise
+ # we'll immediately remove all extremities from a given server).
+ if (
+ get_domain_from_id(event.sender) in new_senders
+ and event.depth < current_depth - 20
+ ):
+ continue
+
+ logger.debug(
+ "Not dropping as too new and not in new_senders: %s",
+ new_senders,
+ )
+
+ return new_latest_event_ids
+
+ times_pruned_extremities.inc()
+
+ logger.info(
+ "Pruning forward extremities in room %s: from %s -> %s",
+ room_id,
+ new_latest_event_ids,
+ new_new_extrems,
+ )
+ return new_new_extrems
async def _calculate_state_delta(
self, room_id: str, current_state: StateMap[str]
@@ -836,7 +1005,10 @@ class EventsPersistenceStorage:
remote_event_ids = [
event_id
- for (typ, state_key,), event_id in current_state.items()
+ for (
+ typ,
+ state_key,
+ ), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
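
The pruning rules added above boil down, for each candidate extremity, to a small age/depth test (leaving aside the "currently being persisted" and "own non-dummy event" checks). A minimal sketch of that test, using illustrative names and the thresholds from the code (one day, depth gaps of 100 and 20), might look like:

    ONE_DAY_MS = 24 * 60 * 60 * 1000

    def extremity_is_droppable(event_ts_ms: int, event_depth: int, event_domain: str,
                               current_depth: int, new_senders: set, now_ms: int) -> bool:
        # Old and far behind the depth of the events being persisted.
        if event_ts_ms < now_ms - ONE_DAY_MS and event_depth < current_depth - 100:
            return True
        # Same domain as an event we are persisting right now: a smaller depth
        # gap suffices, since that server is evidently still sending us events.
        if event_domain in new_senders and event_depth < current_depth - 20:
            return True
        return False

As the comment in the diff notes, extremities are only pruned if every dropped event passes such a test; dropping only a subset could change the result of state resolution.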
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754feab..6c3c2da520 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging
import os
import re
from collections import Counter
-from typing import Optional, TextIO
+from typing import Generator, Iterable, List, Optional, TextIO, Tuple
import attr
+from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
@@ -34,10 +35,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-# XXX: If you're about to bump this to 59 (or higher) please create an update
-# that drops the unused `cache_invalidation_stream` table, as per #7436!
-# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
-SCHEMA_VERSION = 58
+SCHEMA_VERSION = 59
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -70,7 +68,7 @@ def prepare_database(
db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
- databases: Collection[str] = ["main", "state"],
+ databases: Collection[str] = ("main", "state"),
):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -115,7 +113,7 @@ def prepare_database(
# which should be empty.
if config is None:
raise ValueError(
- "config==None in prepare_database, but databse is not empty"
+ "config==None in prepare_database, but database is not empty"
)
# if it's a worker app, refuse to upgrade the database, to avoid multiple
@@ -155,7 +153,9 @@ def prepare_database(
raise
-def _setup_new_database(cur, database_engine, databases):
+def _setup_new_database(
+ cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
+) -> None:
"""Sets up the physical database by finding a base set of "full schemas" and
then applying any necessary deltas, including schemas from the given data
stores.
@@ -188,10 +188,9 @@ def _setup_new_database(cur, database_engine, databases):
folder as well those in the data stores specified.
Args:
- cur (Cursor): a database cursor
- database_engine (DatabaseEngine)
- databases (list[str]): The names of the databases to instantiate
- on the given physical database.
+ cur: a database cursor
+ database_engine
+ databases: The names of the databases to instantiate on the given physical database.
"""
# We're about to set up a brand new database so we check that its
@@ -199,12 +198,11 @@ def _setup_new_database(cur, database_engine, databases):
database_engine.check_new_database(cur)
current_dir = os.path.join(dir_path, "schema", "full_schemas")
- directory_entries = os.listdir(current_dir)
# First we find the highest full schema version we have
valid_versions = []
- for filename in directory_entries:
+ for filename in os.listdir(current_dir):
try:
ver = int(filename)
except ValueError:
@@ -237,7 +235,7 @@ def _setup_new_database(cur, database_engine, databases):
for database in databases
)
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
directory_entries.extend(
_DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +273,15 @@ def _setup_new_database(cur, database_engine, databases):
def _upgrade_existing_database(
- cur,
- current_version,
- applied_delta_files,
- upgraded,
- database_engine,
- config,
- databases,
- is_empty=False,
-):
+ cur: Cursor,
+ current_version: int,
+ applied_delta_files: List[str],
+ upgraded: bool,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str],
+ is_empty: bool = False,
+) -> None:
"""Upgrades an existing physical database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +321,20 @@ def _upgrade_existing_database(
for a version before applying those in the next version.
Args:
- cur (Cursor)
- current_version (int): The current version of the schema.
- applied_delta_files (list): A list of deltas that have already been
- applied.
- upgraded (bool): Whether the current version was generated by having
+ cur
+ current_version: The current version of the schema.
+ applied_delta_files: A list of deltas that have already been applied.
+ upgraded: Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
- database_engine (DatabaseEngine)
- config (synapse.config.homeserver.HomeServerConfig|None):
+ database_engine
+ config:
None if we are initialising a blank database, otherwise the application
config
- databases (list[str]): The names of the databases to instantiate
+ databases: The names of the databases to instantiate
on the given physical database.
- is_empty (bool): Is this a blank database? I.e. do we need to run the
+ is_empty: Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts.
"""
if is_empty:
@@ -358,6 +355,7 @@ def _upgrade_existing_database(
if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade
+ assert config is not None
check_database_before_upgrade(cur, database_engine, config)
start_ver = current_version
@@ -374,7 +372,16 @@ def _upgrade_existing_database(
specific_engine_extensions = (".sqlite", ".postgres")
for v in range(start_ver, SCHEMA_VERSION + 1):
- logger.info("Applying schema deltas for v%d", v)
+ if not is_worker:
+ logger.info("Applying schema deltas for v%d", v)
+
+ cur.execute("DELETE FROM schema_version")
+ cur.execute(
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
+ (v, True),
+ )
+ else:
+ logger.info("Checking schema deltas for v%d", v)
# We need to search both the global and per data store schema
# directories for schema updates.
@@ -388,10 +395,10 @@ def _upgrade_existing_database(
)
# Used to check if we have any duplicate file names
- file_name_counter = Counter()
+ file_name_counter = Counter() # type: CounterType[str]
# Now find which directories have anything of interest.
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
logger.debug("Looking for schema deltas in %s", directory)
try:
@@ -418,7 +425,10 @@ def _upgrade_existing_database(
# We don't support using the same file name in the same delta version.
raise PrepareDatabaseException(
"Found multiple delta files with the same name in v%d: %s"
- % (v, duplicates,)
+ % (
+ v,
+ duplicates,
+ )
)
# We sort to ensure that we apply the delta files in a consistent
@@ -445,11 +455,11 @@ def _upgrade_existing_database(
module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file:
- module = imp.load_source(module_name, absolute_path, python_file)
+ module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
logger.info("Running script %s", relative_path)
- module.run_create(cur, database_engine)
+ module.run_create(cur, database_engine) # type: ignore
if not is_empty:
- module.run_upgrade(cur, database_engine, config=config)
+ module.run_upgrade(cur, database_engine, config=config) # type: ignore
elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@@ -488,23 +498,18 @@ def _upgrade_existing_database(
(v, relative_path),
)
- cur.execute("DELETE FROM schema_version")
- cur.execute(
- "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
- (v, True),
- )
-
logger.info("Schema now up to date")
-def _apply_module_schemas(txn, database_engine, config):
+def _apply_module_schemas(
+ txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
"""Apply the module schemas for the dynamic modules, if any
Args:
cur: database cursor
- database_engine: synapse database engine class
- config (synapse.config.homeserver.HomeServerConfig):
- application config
+ database_engine:
+ config: application config
"""
for (mod, _config) in config.password_providers:
if not hasattr(mod, "get_db_schema_files"):
@@ -515,18 +520,23 @@ def _apply_module_schemas(txn, database_engine, config):
)
-def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+def _apply_module_schema_files(
+ cur: Cursor,
+ database_engine: BaseDatabaseEngine,
+ modname: str,
+ names_and_streams: Iterable[Tuple[str, TextIO]],
+) -> None:
"""Apply the module schemas for a single module
Args:
cur: database cursor
database_engine: synapse database engine class
- modname (str): fully qualified name of the module
- names_and_streams (Iterable[(str, file)]): the names and streams of
- schemas to be applied
+ modname: fully qualified name of the module
+ names_and_streams: the names and streams of schemas to be applied
"""
cur.execute(
- "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
+ "SELECT file FROM applied_module_schemas WHERE module_name = ?",
+ (modname,),
)
applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
@@ -549,7 +559,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
)
-def get_statements(f):
+def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
@@ -594,32 +604,34 @@ def get_statements(f):
statement_buffer = statements[-1].strip()
-def executescript(txn, schema_path):
+def executescript(txn: Cursor, schema_path: str) -> None:
with open(schema_path, "r") as f:
execute_statements_from_stream(txn, f)
-def execute_statements_from_stream(cur: Cursor, f: TextIO):
+def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
for statement in get_statements(f):
cur.execute(statement)
-def _get_or_create_schema_state(txn, database_engine):
+def _get_or_create_schema_state(
+ txn: Cursor, database_engine: BaseDatabaseEngine
+) -> Optional[Tuple[int, List[str], bool]]:
# Bluntly try creating the schema_version tables.
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
- current_version = int(row[0]) if row else None
- upgraded = bool(row[1]) if row else None
- if current_version:
+ if row is not None:
+ current_version = int(row[0])
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
)
applied_deltas = [d for d, in txn]
+ upgraded = bool(row[1])
return current_version, applied_deltas, upgraded
return None
@@ -634,5 +646,5 @@ class _DirectoryListing:
`file_name` attr is kept first.
"""
- file_name = attr.ib()
- absolute_path = attr.ib()
+ file_name = attr.ib(type=str)
+ absolute_path = attr.ib(type=str)
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index bfa0a9fd06..ad954990a7 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,21 +15,24 @@
import itertools
import logging
-from typing import Set
+from typing import TYPE_CHECKING, Set
+
+from synapse.storage.databases import Databases
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class PurgeEventsStorage:
- """High level interface for purging rooms and event history.
- """
+ """High level interface for purging rooms and event history."""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: Databases):
self.stores = stores
- async def purge_room(self, room_id: str):
- """Deletes all record of a room
- """
+ async def purge_room(self, room_id: str) -> None:
+ """Deletes all record of a room"""
state_groups_to_delete = await self.stores.main.purge_room(room_id)
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
@@ -70,9 +73,6 @@ class PurgeEventsStorage:
Returns:
The set of state groups that can be deleted.
"""
- # Graph of state group -> previous group
- graph = {}
-
# Set of events that we have found to be referenced by events
referenced_groups = set()
@@ -108,8 +108,6 @@ class PurgeEventsStorage:
next_to_search |= prevs
state_groups_seen |= prevs
- graph.update(edges)
-
to_delete = state_groups_seen - referenced_groups
return to_delete
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index cec96ad6a7..2564f34b47 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+from typing import Any, Dict, List, Optional, Tuple
import attr
from synapse.api.errors import SynapseError
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -27,18 +29,18 @@ class PaginationChunk:
"""Returned by relation pagination APIs.
Attributes:
- chunk (list): The rows returned by pagination
- next_batch (Any|None): Token to fetch next set of results with, if
+ chunk: The rows returned by pagination
+ next_batch: Token to fetch next set of results with, if
None then there are no more results.
- prev_batch (Any|None): Token to fetch previous set of results with, if
+ prev_batch: Token to fetch previous set of results with, if
None then there are no previous results.
"""
- chunk = attr.ib()
- next_batch = attr.ib(default=None)
- prev_batch = attr.ib(default=None)
+ chunk = attr.ib(type=List[JsonDict])
+ next_batch = attr.ib(type=Optional[Any], default=None)
+ prev_batch = attr.ib(type=Optional[Any], default=None)
- def to_dict(self):
+ def to_dict(self) -> Dict[str, Any]:
d = {"chunk": self.chunk}
if self.next_batch:
@@ -59,25 +61,25 @@ class RelationPaginationToken:
boundaries of the chunk as pagination tokens.
Attributes:
- topological (int): The topological ordering of the boundary event
- stream (int): The stream ordering of the boundary event.
+ topological: The topological ordering of the boundary event
+ stream: The stream ordering of the boundary event.
"""
- topological = attr.ib()
- stream = attr.ib()
+ topological = attr.ib(type=int)
+ stream = attr.ib(type=int)
@staticmethod
- def from_string(string):
+ def from_string(string: str) -> "RelationPaginationToken":
try:
t, s = string.split("-")
return RelationPaginationToken(int(t), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
- def to_string(self):
+ def to_string(self) -> str:
return "%d-%d" % (self.topological, self.stream)
- def as_tuple(self):
+ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
@@ -89,23 +91,23 @@ class AggregationPaginationToken:
aggregation groups, we can just use them as our pagination token.
Attributes:
- count (int): The count of relations in the boundar group.
- stream (int): The MAX stream ordering in the boundary group.
+ count: The count of relations in the boundary group.
+ stream: The MAX stream ordering in the boundary group.
"""
- count = attr.ib()
- stream = attr.ib()
+ count = attr.ib(type=int)
+ stream = attr.ib(type=int)
@staticmethod
- def from_string(string):
+ def from_string(string: str) -> "AggregationPaginationToken":
try:
c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
- def to_string(self):
+ def to_string(self) -> str:
return "%d-%d" % (self.count, self.stream)
- def as_tuple(self):
+ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index f152f63321..d2ff4da6b9 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
)
GetRoomsForUserWithStreamOrdering = namedtuple(
- "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
+ "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..2e277a21c4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
import attr
@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.databases import Databases
+
logger = logging.getLogger(__name__)
# Used for generic functions below
@@ -327,13 +340,14 @@ class StateFilter:
class StateGroupStorage:
- """High level interface to fetching state for event.
- """
+ """High level interface to fetching state for event."""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
- async def get_state_group_delta(self, state_group: int):
+ async def get_state_group_delta(
+ self, state_group: int
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -341,8 +355,8 @@ class StateGroupStorage:
state_group: The state group used to retrieve state deltas.
Returns:
- Tuple[Optional[int], Optional[StateMap[str]]]:
- (prev_group, delta_ids)
+ A tuple of the previous group and a state map of the event IDs which
+ make up the delta between the old and new state groups.
"""
return await self.stores.state.get_state_group_delta(state_group)
@@ -385,7 +399,7 @@ class StateGroupStorage:
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
- """ Get the state groups for the given list of event_ids
+ """Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
@@ -435,8 +449,8 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
- self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+ ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -471,8 +485,8 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
- self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+ ) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
@@ -500,7 +514,7 @@ class StateGroupStorage:
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
@@ -516,7 +530,7 @@ class StateGroupStorage:
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 9cadcba18f..17291c9d5e 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing_extensions import Protocol
@@ -20,23 +20,44 @@ from typing_extensions import Protocol
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
+_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+
class Cursor(Protocol):
- def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
...
- def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
...
- def fetchall(self) -> List[Tuple]:
+ def fetchone(self) -> Optional[Tuple]:
+ ...
+
+ def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
...
- def fetchone(self) -> Tuple:
+ def fetchall(self) -> List[Tuple]:
...
@property
- def description(self) -> Any:
- return None
+ def description(
+ self,
+ ) -> Optional[
+ Sequence[
+ # Note that this is an approximate typing based on sqlite3 and other
+ # drivers, and may not be entirely accurate.
+ Tuple[
+ str,
+ Optional[Any],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ ]
+ ]
+ ]:
+ ...
@property
def rowcount(self) -> int:
@@ -59,7 +80,7 @@ class Connection(Protocol):
def commit(self) -> None:
...
- def rollback(self, *args, **kwargs) -> None:
+ def rollback(self) -> None:
...
def __enter__(self) -> "Connection":
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 02d71302ea..d4643c4fdf 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
import heapq
import logging
import threading
-from collections import deque
+from collections import OrderedDict
from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
import attr
-from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque() # type: Deque[int]
+
+ # We use this as an ordered set, as we want to efficiently append items,
+ # remove items and get the first item. Since we insert IDs in order, the
+ # insertion ordering will ensure it's in the correct order.
+ #
+ # The key and values are the same, but we never look at the values.
+ self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self):
"""
@@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -149,20 +154,20 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
- int
+ The maximum stream id.
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - self._step
+ return next(iter(self._unfinished_ids)) - self._step
return self._current
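
The `OrderedDict` introduced above is used purely as an ordered set: IDs are appended in ascending order, removed in any order as they finish, and the oldest unfinished ID is simply the first key. A standalone sketch (not Synapse code) of that pattern:

    from collections import OrderedDict

    unfinished = OrderedDict()              # type: OrderedDict[int, int]
    for stream_id in (101, 102, 103):
        unfinished[stream_id] = stream_id   # append, O(1)

    unfinished.pop(102)                     # an arbitrary entry finishes, O(1)
    oldest = next(iter(unfinished))         # first key inserted -> 101

This is what lets `get_current_token` replace the old deque lookup `self._unfinished_ids[0]` with `next(iter(self._unfinished_ids))`.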
@@ -186,11 +191,12 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
- stream_name: A name for the stream.
+ stream_name: A name for the stream, for use in the `stream_positions`
+ table. (Does not need to be the same as the replication stream name)
instance_name: The name of this instance.
- table: Database table associated with stream.
- instance_column: Column that stores the row's writer's instance name
- id_column: Column that stores the stream ID.
+ tables: List of tables associated with the stream. Tuple of table
+ name, column name that stores the writer's instance name, and
+ column name that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
writers: A list of known writers to use to populate current positions
@@ -206,9 +212,7 @@ class MultiWriterIdGenerator:
db: DatabasePool,
stream_name: str,
instance_name: str,
- table: str,
- instance_column: str,
- id_column: str,
+ tables: List[Tuple[str, str, str]],
sequence_name: str,
writers: List[str],
positive: bool = True,
@@ -241,7 +245,7 @@ class MultiWriterIdGenerator:
# and b) noting that if we have seen a run of persisted positions
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
#
- # Note: There is no guarentee that the IDs generated by the sequence
+ # Note: There is no guarantee that the IDs generated by the sequence
# will be gapless; gaps can form when e.g. a transaction was rolled
# back. This means that sometimes we won't be able to skip forward the
# position even though everything has been persisted. However, since
@@ -260,15 +264,22 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
- self._sequence_gen.check_consistency(
- db_conn, table=table, id_column=id_column, positive=positive
- )
+ for table, _, id_column in tables:
+ self._sequence_gen.check_consistency(
+ db_conn,
+ table=table,
+ id_column=id_column,
+ stream_name=stream_name,
+ positive=positive,
+ )
# This goes and fills out the above state from the database.
- self._load_current_ids(db_conn, table, instance_column, id_column)
+ self._load_current_ids(db_conn, tables)
def _load_current_ids(
- self, db_conn, table: str, instance_column: str, id_column: str
+ self,
+ db_conn,
+ tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -306,17 +317,22 @@ class MultiWriterIdGenerator:
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
- sql = """
- SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
- FROM %(table)s
- """ % {
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
- cur.execute(sql)
- (stream_id,) = cur.fetchone()
- self._persisted_upto_position = stream_id
+ max_stream_id = 1
+ for table, _, id_column in tables:
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+
+ max_stream_id = max(max_stream_id, stream_id)
+
+ self._persisted_upto_position = max_stream_id
else:
# If we have a min_stream_id then we pull out everything greater
# than it from the DB so that we can prefill
@@ -329,21 +345,31 @@ class MultiWriterIdGenerator:
# stream positions table before restart (or the stream position
# table otherwise got out of date).
- sql = """
- SELECT %(instance)s, %(id)s FROM %(table)s
- WHERE ? %(cmp)s %(id)s
- """ % {
- "id": id_column,
- "table": table,
- "instance": instance_column,
- "cmp": "<=" if self._positive else ">=",
- }
- cur.execute(sql, (min_stream_id * self._return_factor,))
-
self._persisted_upto_position = min_stream_id
+ rows = []
+ for table, instance_column, id_column in tables:
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ cur.execute(sql, (min_stream_id * self._return_factor,))
+
+ rows.extend(cur)
+
+ # Sort so that we handle rows in order for each instance.
+ rows.sort()
+
with self._lock:
- for (instance, stream_id,) in cur:
+ for (
+ instance,
+ stream_id,
+ ) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)
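
With the change above, a single `MultiWriterIdGenerator` can be fed by several `(table, instance_column, id_column)` triples, so the starting "persisted up to" position becomes the maximum over each table's maximum ID, clamped to at least 1. A hypothetical helper mirroring that logic for the positive-stream case:

    def load_persisted_upto_position(cur, tables):
        """Return the largest stream ID found across the given tables, or 1."""
        max_stream_id = 1
        for table, _instance_column, id_column in tables:
            cur.execute("SELECT COALESCE(MAX(%s), 1) FROM %s" % (id_column, table))
            (stream_id,) = cur.fetchone()
            max_stream_id = max(max_stream_id, stream_id)
        return max_stream_id

The row-loading branch likewise collects `(instance, stream_id)` rows from every table and sorts them before replaying, so positions are applied in order for each instance.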
@@ -397,7 +423,7 @@ class MultiWriterIdGenerator:
# bother, as nothing will read it).
#
# We only do this on the success path so that the persisted current
- # position points to a persited row with the correct instance name.
+ # position points to a persisted row with the correct instance name.
if self._writers:
txn.call_after(
run_as_background_process,
@@ -460,8 +486,7 @@ class MultiWriterIdGenerator:
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
- """
+ """Returns the position of the given writer."""
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
@@ -488,7 +513,7 @@ class MultiWriterIdGenerator:
}
def advance(self, instance_name: str, new_id: int):
- """Advance the postion of the named writer to the given ID, if greater
+ """Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
@@ -560,8 +585,7 @@ class MultiWriterIdGenerator:
break
def _update_stream_positions_table_txn(self, txn: Cursor):
- """Update the `stream_positions` table with newly persisted position.
- """
+ """Update the `stream_positions` table with newly persisted position."""
if not self._writers:
return
@@ -601,8 +625,7 @@ class _AsyncCtxManagerWrapper:
@attr.s(slots=True)
class _MultiWriterCtxManager:
- """Async context manager returned by MultiWriterIdGenerator
- """
+ """Async context manager returned by MultiWriterIdGenerator"""
id_gen = attr.ib(type=MultiWriterIdGenerator)
multiple_ids = attr.ib(type=Optional[int], default=None)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 4386b6101e..36a67e7019 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -15,9 +15,8 @@
import abc
import logging
import threading
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
-from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import (
BaseDatabaseEngine,
IncorrectDatabaseSetup,
@@ -25,6 +24,9 @@ from synapse.storage.engines import (
)
from synapse.storage.types import Connection, Cursor
+if TYPE_CHECKING:
+ from synapse.storage.database import LoggingDatabaseConnection
+
logger = logging.getLogger(__name__)
@@ -43,6 +45,21 @@ and run the following SQL:
See docs/postgres.md for more information.
"""
+_INCONSISTENT_STREAM_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated stream position
+of '%(stream_name)s' in the 'stream_positions' table.
+
+This is likely a programming error and should be reported at
+https://github.com/matrix-org/synapse.
+
+A temporary workaround to fix this error is to shut down Synapse (including
+any and all workers) and run the following SQL:
+
+ DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';
+
+This will need to be done every time the server is restarted.
+"""
+
class SequenceGenerator(metaclass=abc.ABCMeta):
"""A class which generates a unique sequence of integers"""
@@ -53,19 +70,30 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
...
@abc.abstractmethod
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ """Get the next `n` IDs in the sequence"""
+ ...
+
+ @abc.abstractmethod
def check_consistency(
self,
- db_conn: LoggingDatabaseConnection,
+ db_conn: "LoggingDatabaseConnection",
table: str,
id_column: str,
+ stream_name: Optional[str] = None,
positive: bool = True,
):
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
- This is to handle various cases where the sequence value can get out
- of sync with the table, e.g. if Synapse gets rolled back to a previous
+ This is to handle various cases where the sequence value can get out of
+ sync with the table, e.g. if Synapse gets rolled back to a previous
version and the rolled forwards again.
+
+ If a stream name is given then this will check that any value in the
+ `stream_positions` table is less than or equal to the current sequence
+ value. If it isn't then it's likely that streams have been crossed
+ somewhere (e.g. two ID generators have the same stream name).
"""
...
@@ -78,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
- return txn.fetchone()[0]
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ return fetch_res[0]
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
@@ -88,11 +118,14 @@ class PostgresSequenceGenerator(SequenceGenerator):
def check_consistency(
self,
- db_conn: LoggingDatabaseConnection,
+ db_conn: "LoggingDatabaseConnection",
table: str,
id_column: str,
+ stream_name: Optional[str] = None,
positive: bool = True,
):
+ """See SequenceGenerator.check_consistency for docstring."""
+
txn = db_conn.cursor(txn_name="sequence.check_consistency")
# First we get the current max ID from the table.
@@ -115,7 +148,21 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute(
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
- last_value, is_called = txn.fetchone()
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ last_value, is_called = fetch_res
+
+ # If we have an associated stream check the stream_positions table.
+ max_in_stream_positions = None
+ if stream_name:
+ txn.execute(
+ "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
+ (stream_name,),
+ )
+ row = txn.fetchone()
+ if row:
+ max_in_stream_positions = row[0]
+
txn.close()
# If `is_called` is False then `last_value` is actually the value that
@@ -136,6 +183,14 @@ class PostgresSequenceGenerator(SequenceGenerator):
% {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
)
+ # If we have values in the stream positions table then they have to be
+ # less than or equal to `last_value`
+ if max_in_stream_positions and max_in_stream_positions > last_value:
+ raise IncorrectDatabaseSetup(
+ _INCONSISTENT_STREAM_ERROR
+ % {"seq": self._sequence_name, "stream_name": stream_name}
+ )
+
GetFirstCallbackType = Callable[[Cursor], int]
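
The new `stream_name` argument lets `check_consistency` catch a second class of problem: not only a sequence that lags behind its table, but also a `stream_positions` entry that is ahead of the sequence (which usually means two ID generators were given the same stream name). Expressed as a plain predicate, under the same assumptions as the code above:

    def sequence_is_consistent(last_value, max_in_table, max_in_stream_positions):
        """A sketch of the two conditions check_consistency enforces."""
        if max_in_table is not None and max_in_table > last_value:
            return False  # sequence lags the table, e.g. after a rollback
        if max_in_stream_positions is not None and max_in_stream_positions > last_value:
            return False  # streams have probably been crossed somewhere
        return True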
@@ -172,17 +227,38 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ first_id = self._current_max_id + 1
+ self._current_max_id += n
+ return [first_id + i for i in range(n)]
+
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: Connection,
+ table: str,
+ id_column: str,
+ stream_name: Optional[str] = None,
+ positive: bool = True,
):
# There is nothing to do for in memory sequences
pass
def build_sequence_generator(
+ db_conn: "LoggingDatabaseConnection",
database_engine: BaseDatabaseEngine,
get_first_callback: GetFirstCallbackType,
sequence_name: str,
+ table: Optional[str],
+ id_column: Optional[str],
+ stream_name: Optional[str] = None,
+ positive: bool = True,
) -> SequenceGenerator:
"""Get the best impl of SequenceGenerator available
@@ -194,8 +270,23 @@ def build_sequence_generator(
get_first_callback: a callback which gets the next sequence ID. Used if
we're on sqlite.
sequence_name: the name of a postgres sequence to use.
+ table, id_column, stream_name, positive: If set then `check_consistency`
+ is called on the created sequence. See docstring for
+ `check_consistency` details.
"""
if isinstance(database_engine, PostgresEngine):
- return PostgresSequenceGenerator(sequence_name)
+ seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator
else:
- return LocalSequenceGenerator(get_first_callback)
+ seq = LocalSequenceGenerator(get_first_callback)
+
+ if table:
+ assert id_column
+ seq.check_consistency(
+ db_conn=db_conn,
+ table=table,
+ id_column=id_column,
+ stream_name=stream_name,
+ positive=positive,
+ )
+
+ return seq
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
index 4589e4539b..a047699cc4 100644
--- a/synapse/third_party_rules/access_rules.py
+++ b/synapse/third_party_rules/access_rules.py
@@ -81,7 +81,9 @@ class RoomAccessRules(object):
"""
def __init__(
- self, config: Dict, module_api: ModuleApi,
+ self,
+ config: Dict,
+ module_api: ModuleApi,
):
self.id_server = config["id_server"]
self.module_api = module_api
@@ -111,7 +113,10 @@ class RoomAccessRules(object):
return config
async def on_create_room(
- self, requester: Requester, config: Dict, is_requester_admin: bool,
+ self,
+ requester: Requester,
+ config: Dict,
+ is_requester_admin: bool,
) -> bool:
"""Implements synapse.events.ThirdPartyEventRules.on_create_room.
@@ -259,7 +264,10 @@ class RoomAccessRules(object):
}
async def check_threepid_can_be_invited(
- self, medium: str, address: str, state_events: StateMap[EventBase],
+ self,
+ medium: str,
+ address: str,
+ state_events: StateMap[EventBase],
) -> bool:
"""Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited.
@@ -309,7 +317,9 @@ class RoomAccessRules(object):
return True
async def check_event_allowed(
- self, event: EventBase, state_events: StateMap[EventBase],
+ self,
+ event: EventBase,
+ state_events: StateMap[EventBase],
) -> bool:
"""Implements synapse.events.ThirdPartyEventRules.check_event_allowed.
@@ -434,7 +444,10 @@ class RoomAccessRules(object):
)
async def _on_membership_or_invite(
- self, event: EventBase, rule: str, state_events: StateMap[EventBase],
+ self,
+ event: EventBase,
+ rule: str,
+ state_events: StateMap[EventBase],
) -> bool:
"""Applies the correct rule for incoming m.room.member and
m.room.third_party_invite events.
@@ -659,7 +672,9 @@ class RoomAccessRules(object):
return True
def _on_membership_or_invite_direct(
- self, event: EventBase, state_events: StateMap[EventBase],
+ self,
+ event: EventBase,
+ state_events: StateMap[EventBase],
) -> bool:
"""Implements the checks and behaviour specified for the "direct" rule.
diff --git a/synapse/types.py b/synapse/types.py
index 7c9f716804..00655f6dd2 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -36,8 +36,17 @@ import attr
from signedjson.key import decode_verify_key_bytes
from six.moves import filter
from unpaddedbase64 import decode_base64
+from zope.interface import Interface
+
+from twisted.internet.interfaces import (
+ IReactorCore,
+ IReactorPluggableNameResolver,
+ IReactorTCP,
+ IReactorTime,
+)
from synapse.api.errors import Codes, SynapseError
+from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
@@ -67,33 +76,40 @@ MutableStateMap = MutableMapping[StateKey, T]
JsonDict = Dict[str, Any]
-class Requester(
- namedtuple(
- "Requester",
- [
- "user",
- "access_token_id",
- "is_guest",
- "shadow_banned",
- "device_id",
- "app_service",
- "authenticated_entity",
- ],
- )
+# Note that this seems to require inheriting *directly* from Interface in order
+# for mypy-zope to realize it is an interface.
+class ISynapseReactor(
+ IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
):
+ """The interfaces necessary for Synapse to function."""
+
+
+@attr.s(frozen=True, slots=True)
+class Requester:
"""
Represents the user making a request
Attributes:
- user (UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
+ user: id of the user making the request
+ access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- shadow_banned (bool): True if the user making this request has been shadow-banned.
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request has been shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
+ authenticated_entity: The entity that authenticated when making the request.
+ This is different to the user_id when an admin user or the server is
+ "puppeting" the user.
"""
+ user = attr.ib(type="UserID")
+ access_token_id = attr.ib(type=Optional[int])
+ is_guest = attr.ib(type=bool)
+ shadow_banned = attr.ib(type=bool)
+ device_id = attr.ib(type=Optional[str])
+ app_service = attr.ib(type=Optional["ApplicationService"])
+ authenticated_entity = attr.ib(type=str)
+
def serialize(self):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -141,23 +157,23 @@ class Requester(
def create_requester(
user_id: Union[str, "UserID"],
access_token_id: Optional[int] = None,
- is_guest: Optional[bool] = False,
- shadow_banned: Optional[bool] = False,
+ is_guest: bool = False,
+ shadow_banned: bool = False,
device_id: Optional[str] = None,
app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
-):
+) -> Requester:
"""
Create a new ``Requester`` object
Args:
- user_id (str|UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
+ user_id: id of the user making the request
+ access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- shadow_banned (bool): True if the user making this request is shadow-banned.
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request is shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
@@ -258,8 +274,13 @@ class DomainSpecificString(
@classmethod
def is_valid(cls: Type[DS], s: str) -> bool:
+ """Parses the input string and attempts to ensure it is valid."""
try:
- cls.from_string(s)
+ obj = cls.from_string(s)
+ # Apply additional validation to the domain. This is only done
+ # during is_valid (and not part of from_string) since it is
+ # possible for invalid data to exist in room-state, etc.
+ parse_and_validate_server_name(obj.domain)
return True
except Exception:
return False
@@ -363,15 +384,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile(
)
-def map_username_to_mxid_localpart(username, case_sensitive=False):
+def map_username_to_mxid_localpart(
+ username: Union[str, bytes], case_sensitive: bool = False
+) -> str:
"""Map a username onto a string suitable for a MXID
This follows the algorithm laid out at
https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
Args:
- username (unicode|bytes): username to be mapped
- case_sensitive (bool): true if TEST and test should be mapped
+ username: username to be mapped
+ case_sensitive: true if TEST and test should be mapped
onto different mxids
Returns:
@@ -475,8 +498,7 @@ class RoomStreamToken:
)
def __attrs_post_init__(self):
- """Validates that both `topological` and `instance_map` aren't set.
- """
+ """Validates that both `topological` and `instance_map` aren't set."""
if self.instance_map and self.topological:
raise ValueError(
@@ -504,7 +526,11 @@ class RoomStreamToken:
instance_name = await store.get_name_from_instance_id(instance_id)
instance_map[instance_name] = pos
- return cls(topological=None, stream=stream, instance_map=instance_map,)
+ return cls(
+ topological=None,
+ stream=stream,
+ instance_map=instance_map,
+ )
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -681,7 +707,7 @@ class PersistedEventPosition:
persisted in the same room after this position will be after the
returned `RoomStreamToken`.
- Note: no guarentees are made about ordering w.r.t. events in other
+ Note: no guarantees are made about ordering w.r.t. events in other
rooms.
"""
# Doing the naive thing satisfies the desired properties described in
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..c3b2d981ea 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
# limitations under the License.
import collections
+import inspect
import logging
from contextlib import contextmanager
from typing import (
Any,
+ Awaitable,
Callable,
Dict,
Hashable,
@@ -74,11 +76,16 @@ class ObservableDeferred:
def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().callback(r)
- except Exception:
- pass
+ observer.callback(r)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .callback(%r), ignoring...",
+ observer,
+ r,
+ exc_info=e,
+ )
return r
def errback(f):
@@ -88,11 +95,16 @@ class ObservableDeferred:
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().errback(f)
- except Exception:
- pass
+ observer.errback(f)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .errback(%r), ignoring...",
+ observer,
+ f,
+ exc_info=e,
+ )
if consumeErrors:
return None
@@ -250,8 +262,7 @@ class Linearizer:
self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry]
def is_queued(self, key: Hashable) -> bool:
- """Checks whether there is a process queued up waiting
- """
+ """Checks whether there is a process queued up waiting"""
entry = self.key_to_defer.get(key)
if not entry:
# No entry so nothing is waiting.
@@ -450,7 +461,9 @@ R = TypeVar("R")
def timeout_deferred(
- deferred: defer.Deferred, timeout: float, reactor: IReactorTime,
+ deferred: defer.Deferred,
+ timeout: float,
+ reactor: IReactorTime,
) -> defer.Deferred:
"""The in built twisted `Deferred.addTimeout` fails to time out deferreds
that have a canceller that throws exceptions. This method creates a new
@@ -483,7 +496,7 @@ def timeout_deferred(
try:
deferred.cancel()
- except: # noqa: E722, if we throw any exception it'll break time outs
+ except Exception: # if we throw any exception it'll break time outs
logger.exception("Canceller failed during timeout")
# the cancel() call should have set off a chain of errbacks which
@@ -495,7 +508,7 @@ def timeout_deferred(
delayed_call = reactor.callLater(timeout, time_it_out)
def convert_cancelled(value: failure.Failure):
- # if the orgininal deferred was cancelled, and our timeout has fired, then
+ # if the original deferred was cancelled, and our timeout has fired, then
# the reason it was cancelled was due to our timeout. Turn the CancelledError
# into a TimeoutError.
if timed_out[0] and value.check(CancelledError):
@@ -527,8 +540,7 @@ def timeout_deferred(
@attr.s(slots=True, frozen=True)
class DoneAwaitable:
- """Simple awaitable that returns the provided value.
- """
+ """Simple awaitable that returns the provided value."""
value = attr.ib()
@@ -542,11 +554,10 @@ class DoneAwaitable:
raise StopIteration(self.value)
-def maybe_awaitable(value):
- """Convert a value to an awaitable if not already an awaitable.
- """
-
- if hasattr(value, "__await__"):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
+ """Convert a value to an awaitable if not already an awaitable."""
+ if inspect.isawaitable(value):
+ assert isinstance(value, Awaitable)
return value
return DoneAwaitable(value)
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 89f0b38535..48f64eeb38 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -25,8 +25,8 @@ from synapse.config.cache import add_resizable_cache
logger = logging.getLogger(__name__)
-caches_by_name = {}
-collectors_by_name = {} # type: Dict
+caches_by_name = {} # type: Dict[str, Sized]
+collectors_by_name = {} # type: Dict[str, CacheMetric]
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
@@ -116,7 +116,7 @@ def register_cache(
"""
if resizable:
if not resize_callback:
- resize_callback = getattr(cache, "set_cache_factor")
+ resize_callback = cache.set_cache_factor # type: ignore
add_resizable_cache(cache_name, resize_callback)
metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
@@ -149,8 +149,7 @@ KNOWN_KEYS = {
def intern_string(string):
- """Takes a (potentially) unicode string and interns it if it's ascii
- """
+ """Takes a (potentially) unicode string and interns it if it's ascii"""
if string is None:
return None
@@ -161,8 +160,7 @@ def intern_string(string):
def intern_dict(dictionary):
- """Takes a dictionary and interns well known keys and their values
- """
+ """Takes a dictionary and interns well known keys and their values"""
return {
KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
for key, value in dictionary.items()
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
new file mode 100644
index 0000000000..3ee0f2317a
--- /dev/null
+++ b/synapse/util/caches/cached_call.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
+
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+
+TV = TypeVar("TV")
+
+
+class CachedCall(Generic[TV]):
+ """A wrapper for asynchronous calls whose results should be shared
+
+ This is useful for wrapping asynchronous functions, where there might be multiple
+ callers, but we only want to call the underlying function once (and have the result
+ returned to all callers).
+
+ Similar results can be achieved via a lock of some form, but that typically requires
+ more boilerplate (and ends up being less efficient).
+
+ Correctly handles Synapse logcontexts (logs and resource usage for the underlying
+ function are logged against the logcontext which is active when get() is first
+ called).
+
+ Example usage:
+
+ _cached_val = CachedCall(_load_prop)
+
+ async def handle_request() -> X:
+ # We can call this multiple times, but it will result in a single call to
+ # _load_prop().
+ return await _cached_val.get()
+
+ async def _load_prop() -> X:
+ await difficult_operation()
+
+
+ The implementation is deliberately single-shot (ie, once the call is initiated,
+ there is no way to ask for it to be run again). This keeps the implementation and
+ semantics simple. If you want to make a new call, simply replace the whole
+ CachedCall object.
+ """
+
+ __slots__ = ["_callable", "_deferred", "_result"]
+
+ def __init__(self, f: Callable[[], Awaitable[TV]]):
+ """
+ Args:
+ f: The underlying function. Only one call to this function will be alive
+ at once (per instance of CachedCall)
+ """
+ self._callable = f # type: Optional[Callable[[], Awaitable[TV]]]
+ self._deferred = None # type: Optional[Deferred]
+ self._result = None # type: Union[None, Failure, TV]
+
+ async def get(self) -> TV:
+ """Kick off the call if necessary, and return the result"""
+
+ # Fire off the callable now if this is our first time
+ if not self._deferred:
+ self._deferred = run_in_background(self._callable)
+
+ # we will never need the callable again, so make sure it can be GCed
+ self._callable = None
+
+ # once the deferred completes, store the result. We cannot simply leave the
+ # result in the deferred, since if it's a Failure, GCing the deferred
+ # would then log a critical error about unhandled Failures.
+ def got_result(r):
+ self._result = r
+
+ self._deferred.addBoth(got_result)
+
+ # TODO: consider cancellation semantics. Currently, if the call to get()
+ # is cancelled, the underlying call will continue (and any future calls
+ # will get the result/exception), which I think is *probably* ok, modulo
+ # the fact the underlying call may be logged to a cancelled logcontext,
+ # and any eventual exception may not be reported.
+
+ # we can now await the deferred, and once it completes, return the result.
+ await make_deferred_yieldable(self._deferred)
+
+ # I *think* this is the easiest way to correctly raise a Failure without having
+ # to gut-wrench into the implementation of Deferred.
+ d = Deferred()
+ d.callback(self._result)
+ return await d
+
+
+class RetryOnExceptionCachedCall(Generic[TV]):
+ """A wrapper around CachedCall which will retry the call if an exception is thrown
+
+ This is used in much the same way as CachedCall, but adds some extra functionality
+ so that if the underlying function throws an exception, then the next call to get()
+ will initiate another call to the underlying function. (Any calls to get() which
+ are already pending will raise the exception.)
+ """
+
+ __slots__ = ["_cachedcall"]
+
+ def __init__(self, f: Callable[[], Awaitable[TV]]):
+ async def _wrapper() -> TV:
+ try:
+ return await f()
+ except Exception:
+ # the call raised an exception: replace the underlying CachedCall to
+ # trigger another call next time get() is called
+ self._cachedcall = CachedCall(_wrapper)
+ raise
+
+ self._cachedcall = CachedCall(_wrapper)
+
+ async def get(self) -> TV:
+ return await self._cachedcall.get()
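
A usage sketch for the new wrapper (the `RemoteConfig` class and `client.get_json` call are hypothetical, purely for illustration):

    from synapse.util.caches.cached_call import RetryOnExceptionCachedCall

    class RemoteConfig:
        def __init__(self, client, url: str):
            self._fetch = RetryOnExceptionCachedCall(lambda: client.get_json(url))

        async def get(self) -> dict:
            # Concurrent callers share one in-flight request; if it raises, the
            # next call to get() kicks off a fresh attempt.
            return await self._fetch.get()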
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 601305487c..1adc92eb90 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -105,7 +105,7 @@ class DeferredCache(Generic[KT, VT]):
keylen=keylen,
cache_name=name,
cache_type=cache_type,
- size_callback=(lambda d: len(d)) if iterable else None,
+ size_callback=(lambda d: len(d) or 1) if iterable else None,
metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
) # type: LruCache[KT, VT]
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index a924140cdf..4e84379914 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -122,7 +122,8 @@ class _LruCachedFunction(Generic[F]):
def lru_cache(
- max_entries: int = 1000, cache_context: bool = False,
+ max_entries: int = 1000,
+ cache_context: bool = False,
) -> Callable[[F], _LruCachedFunction[F]]:
"""A method decorator that applies a memoizing cache around the function.
@@ -156,7 +157,9 @@ def lru_cache(
def func(orig: F) -> _LruCachedFunction[F]:
desc = LruCacheDescriptor(
- orig, max_entries=max_entries, cache_context=cache_context,
+ orig,
+ max_entries=max_entries,
+ cache_context=cache_context,
)
return cast(_LruCachedFunction[F], desc)
@@ -170,14 +173,18 @@ class LruCacheDescriptor(_CacheDescriptorBase):
sentinel = object()
def __init__(
- self, orig, max_entries: int = 1000, cache_context: bool = False,
+ self,
+ orig,
+ max_entries: int = 1000,
+ cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
def __get__(self, obj, owner):
cache = LruCache(
- cache_name=self.orig.__name__, max_size=self.max_entries,
+ cache_name=self.orig.__name__,
+ max_size=self.max_entries,
) # type: LruCache[CacheKey, Any]
get_cache_key = self.cache_key_builder
@@ -212,7 +219,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
class DeferredCacheDescriptor(_CacheDescriptorBase):
- """ A method decorator that applies a memoizing cache around the function.
+ """A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 588d2d49f2..b3b413b02c 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -15,26 +15,38 @@
import enum
import logging
import threading
-from collections import namedtuple
-from typing import Any
+from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+
+import attr
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+# The type of the cache keys.
+KT = TypeVar("KT")
+# The type of the dictionary keys.
+DKT = TypeVar("DKT")
+
+
+@attr.s(slots=True)
+class DictionaryEntry:
"""Returned when getting an entry from the cache
Attributes:
- full (bool): Whether the cache has the full or dict or just some keys.
+ full: Whether the cache has the full dict or just some keys.
If not full then not all requested keys will necessarily be present
in `value`
- known_absent (set): Keys that were looked up in the dict and were not
+ known_absent: Keys that were looked up in the dict and were not
there.
- value (dict): The full or partial dict value
+ value: The full or partial dict value
"""
+ full = attr.ib(type=bool)
+ known_absent = attr.ib()
+ value = attr.ib()
+
def __len__(self):
return len(self.value)
@@ -45,21 +57,21 @@ class _Sentinel(enum.Enum):
sentinel = object()
-class DictionaryCache:
+class DictionaryCache(Generic[KT, DKT]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
- def __init__(self, name, max_entries=1000):
+ def __init__(self, name: str, max_entries: int = 1000):
self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
- ) # type: LruCache[Any, DictionaryEntry]
+ ) # type: LruCache[KT, DictionaryEntry]
self.name = name
self.sequence = 0
- self.thread = None
+ self.thread = None # type: Optional[threading.Thread]
- def check_thread(self):
+ def check_thread(self) -> None:
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
@@ -69,12 +81,14 @@ class DictionaryCache:
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, dict_keys=None):
+ def get(
+ self, key: KT, dict_keys: Optional[Iterable[DKT]] = None
+ ) -> DictionaryEntry:
"""Fetch an entry out of the cache
Args:
key
- dict_key(list): If given a set of keys then return only those keys
+ dict_key: If given a set of keys then return only those keys
that exist in the cache.
Returns:
@@ -95,7 +109,7 @@ class DictionaryCache:
return DictionaryEntry(False, set(), {})
- def invalidate(self, key):
+ def invalidate(self, key: KT) -> None:
self.check_thread()
# Increment the sequence number so that any SELECT statements that
@@ -103,19 +117,25 @@ class DictionaryCache:
self.sequence += 1
self.cache.pop(key, None)
- def invalidate_all(self):
+ def invalidate_all(self) -> None:
self.check_thread()
self.sequence += 1
self.cache.clear()
- def update(self, sequence, key, value, fetched_keys=None):
+ def update(
+ self,
+ sequence: int,
+ key: KT,
+ value: Dict[DKT, Any],
+ fetched_keys: Optional[Set[DKT]] = None,
+ ) -> None:
"""Updates the entry in the cache
Args:
sequence
- key (K)
- value (dict[X,Y]): The value to update the cache with.
- fetched_keys (None|set[X]): All of the dictionary keys which were
+ key
+ value: The value to update the cache with.
+ fetched_keys: All of the dictionary keys which were
fetched from the database.
If None, this is the complete value for key K. Otherwise, it
@@ -131,7 +151,9 @@ class DictionaryCache:
else:
self._update_or_insert(key, value, fetched_keys)
- def _update_or_insert(self, key, value, known_absent):
+ def _update_or_insert(
+ self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
+ ) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
@@ -140,5 +162,5 @@ class DictionaryCache:
entry.known_absent.update(known_absent)
self.cache[key] = entry
- def _insert(self, key, value, known_absent):
+ def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
self.cache[key] = DictionaryEntry(True, known_absent, value)
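
As a rough usage sketch of the partial-dict semantics documented above (illustrative only, and assuming the cache's sequence number has not changed since the database read began):

    from synapse.util.caches.dictionary_cache import DictionaryCache

    cache = DictionaryCache("example")  # type: DictionaryCache[str, str]

    # We fetched keys "k1" and "k2" from the database, but only "k1" had a value.
    seq = cache.sequence
    cache.update(seq, "room1", {"k1": "v1"}, fetched_keys={"k1", "k2"})

    entry = cache.get("room1", dict_keys=["k1", "k2"])
    assert not entry.full                  # only a partial dict is cached
    assert entry.value == {"k1": "v1"}
    assert "k2" in entry.known_absent      # looked up before and known to be missing
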
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 32228f42ee..46ea8e0964 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.util import Clock
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
-if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
-
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
used rather than trying to compute a new response.
"""
- def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
- self.clock = hs.get_clock()
+ self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
self._name = name
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index c541bf4579..644e9e778a 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -84,8 +84,7 @@ class StreamChangeCache:
return False
def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
- """Returns True if the entity may have been updated since stream_pos
- """
+ """Returns True if the entity may have been updated since stream_pos"""
assert isinstance(stream_pos, int)
if stream_pos < self._earliest_known_stream_pos:
@@ -133,8 +132,7 @@ class StreamChangeCache:
return result
def has_any_entity_changed(self, stream_pos: int) -> bool:
- """Returns if any entity has changed
- """
+ """Returns if any entity has changed"""
assert type(stream_pos) is int
if not self._cache:
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 6ce2a3d12b..96a8274940 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -15,6 +15,7 @@
import logging
import time
+from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union
import attr
from sortedcontainers import SortedList
@@ -23,15 +24,19 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-SENTINEL = object()
+SENTINEL = object() # type: Any
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
-class TTLCache:
+
+class TTLCache(Generic[KT, VT]):
"""A key/value cache implementation where each entry has its own TTL"""
- def __init__(self, cache_name, timer=time.time):
+ def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry
- self._data = {}
+ self._data = {} # type: Dict[KT, _CacheEntry]
# the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
@@ -40,26 +45,27 @@ class TTLCache:
self._metrics = register_cache("ttl", cache_name, self, resizable=False)
- def set(self, key, value, ttl):
+ def set(self, key: KT, value: VT, ttl: float) -> None:
"""Add/update an entry in the cache
Args:
key: key for this entry
value: value for this entry
- ttl (float): TTL for this entry, in seconds
+ ttl: TTL for this entry, in seconds
"""
expiry = self._timer() + ttl
self.expire()
e = self._data.pop(key, SENTINEL)
- if e != SENTINEL:
+ if e is not SENTINEL:
+ assert isinstance(e, _CacheEntry)
self._expiry_list.remove(e)
entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
self._data[key] = entry
self._expiry_list.add(entry)
- def get(self, key, default=SENTINEL):
+ def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Get a value from the cache
Args:
@@ -72,23 +78,23 @@ class TTLCache:
"""
self.expire()
e = self._data.get(key, SENTINEL)
- if e == SENTINEL:
+ if e is SENTINEL:
self._metrics.inc_misses()
- if default == SENTINEL:
+ if default is SENTINEL:
raise KeyError(key)
return default
+ assert isinstance(e, _CacheEntry)
self._metrics.inc_hits()
return e.value
- def get_with_expiry(self, key):
+ def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]:
"""Get a value, and its expiry time, from the cache
Args:
key: key to look up
Returns:
- Tuple[Any, float, float]: the value from the cache, the expiry time
- and the TTL
+ A tuple of the value from the cache, the expiry time and the TTL
Raises:
KeyError if the entry is not found
@@ -102,7 +108,7 @@ class TTLCache:
self._metrics.inc_hits()
return e.value, e.expiry_time, e.ttl
- def pop(self, key, default=SENTINEL):
+ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore
"""Remove a value from the cache
If key is in the cache, remove it and return its value, else return default.
@@ -118,29 +124,30 @@ class TTLCache:
"""
self.expire()
e = self._data.pop(key, SENTINEL)
- if e == SENTINEL:
+ if e is SENTINEL:
self._metrics.inc_misses()
- if default == SENTINEL:
+ if default is SENTINEL:
raise KeyError(key)
return default
+ assert isinstance(e, _CacheEntry)
self._expiry_list.remove(e)
self._metrics.inc_hits()
return e.value
- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
return self.get(key)
- def __delitem__(self, key):
+ def __delitem__(self, key: KT) -> None:
self.pop(key)
- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return key in self._data
- def __len__(self):
+ def __len__(self) -> int:
self.expire()
return len(self._data)
- def expire(self):
+ def expire(self) -> None:
"""Run the expiry on the cache. Any entries whose expiry times are due will
be removed
"""
@@ -158,7 +165,7 @@ class _CacheEntry:
"""TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry.
- expiry_time = attr.ib()
- ttl = attr.ib()
+ expiry_time = attr.ib(type=float)
+ ttl = attr.ib(type=float)
key = attr.ib()
value = attr.ib()
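
A small, illustrative use of the now-typed TTLCache, with an injected timer so the expiry behaviour is deterministic:

    from synapse.util.caches.ttlcache import TTLCache

    now = [100.0]
    cache = TTLCache("example", timer=lambda: now[0])  # type: TTLCache[str, int]

    cache.set("answer", 42, ttl=10)
    assert cache.get("answer") == 42

    now[0] += 11                                       # advance past the TTL
    assert cache.get("answer", default=None) is None   # expired on the next access
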
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..3c47285d05 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -105,13 +105,13 @@ class Signal:
async def do(observer):
try:
- result = observer(*args, **kwargs)
- if inspect.isawaitable(result):
- result = await result
- return result
+ return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
logger.warning(
- "%s signal observer %s failed: %r", self.name, observer, e,
+ "%s signal observer %s failed: %r",
+ self.name,
+ observer,
+ e,
)
deferreds = [run_in_background(do, o) for o in self.observers]
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 733f5e26e6..68dc632491 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -83,15 +83,13 @@ class BackgroundFileConsumer:
self._producer.resumeProducing()
def unregisterProducer(self):
- """Part of IProducer interface
- """
+ """Part of IProducer interface"""
self._producer = None
if not self._finished_deferred.called:
self._bytes_queue.put_nowait(None)
def write(self, bytes):
- """Part of IProducer interface
- """
+ """Part of IProducer interface"""
if self._write_exception:
raise self._write_exception
@@ -107,8 +105,7 @@ class BackgroundFileConsumer:
self._producer.pauseProducing()
def _writer(self):
- """This is run in a background thread to write to the file.
- """
+ """This is run in a background thread to write to the file."""
try:
while self._producer or not self._bytes_queue.empty():
# If we've paused the producer check if we should resume the
@@ -135,13 +132,11 @@ class BackgroundFileConsumer:
self._file_obj.close()
def wait(self):
- """Returns a deferred that resolves when finished writing to file
- """
+ """Returns a deferred that resolves when finished writing to file"""
return make_deferred_yieldable(self._finished_deferred)
def _resume_paused_producer(self):
- """Gets called if we should resume producing after being paused
- """
+ """Gets called if we should resume producing after being paused"""
if self._paused_producer and self._producer:
self._paused_producer = False
self._producer.resumeProducing()
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 5f7a6dd1d3..5ca2e71e60 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -36,7 +36,7 @@ def freeze(o):
def unfreeze(o):
if isinstance(o, (dict, frozendict)):
- return dict({k: unfreeze(v) for k, v in o.items()})
+ return {k: unfreeze(v) for k, v in o.items()}
if isinstance(o, (bytes, str)):
return o
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 06faeebe7f..98707c119d 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -13,8 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import heapq
from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+ Dict,
+ Generator,
+ Iterable,
+ Iterator,
+ Mapping,
+ Sequence,
+ Set,
+ Tuple,
+ TypeVar,
+)
+
+from synapse.types import Collection
T = TypeVar("T")
@@ -46,3 +59,42 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
If the input is empty, no chunks are returned.
"""
return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+
+
+def sorted_topologically(
+ nodes: Iterable[T],
+ graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+ """Given a set of nodes and a graph, yield the nodes in topological order.
+
+ For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+ """
+
+ # This is implemented by Kahn's algorithm.
+
+ degree_map = {node: 0 for node in nodes}
+ reverse_graph = {} # type: Dict[T, Set[T]]
+
+ for node, edges in graph.items():
+ if node not in degree_map:
+ continue
+
+ for edge in set(edges):
+ if edge in degree_map:
+ degree_map[node] += 1
+
+ reverse_graph.setdefault(edge, set()).add(node)
+ reverse_graph.setdefault(node, set())
+
+ zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+ heapq.heapify(zero_degree)
+
+ while zero_degree:
+ node = heapq.heappop(zero_degree)
+ yield node
+
+ for edge in reverse_graph.get(node, []):
+ if edge in degree_map:
+ degree_map[edge] -= 1
+ if degree_map[edge] == 0:
+ heapq.heappush(zero_degree, edge)
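
The new helper implements Kahn's algorithm over the dependency graph; a quick usage sketch following the docstring's convention (an entry `a: [b]` in `graph` means `a` depends on `b`):

    from synapse.util.iterutils import sorted_topologically

    # 2 depends on 1, and 3 depends on both 1 and 2, so the order is 1, 2, 3.
    nodes = [3, 2, 1]
    graph = {2: [1], 3: [1, 2]}
    print(list(sorted_topologically(nodes, graph)))  # [1, 2, 3]
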
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 50516926f3..e3a8ed5b2f 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -15,7 +15,7 @@
class JsonEncodedObject:
- """ A common base class for defining protocol units that are represented
+ """A common base class for defining protocol units that are represented
as JSON.
Attributes:
@@ -39,7 +39,7 @@ class JsonEncodedObject:
"""
def __init__(self, **kwargs):
- """ Takes the dict of `kwargs` and loads all keys that are *valid*
+ """Takes the dict of `kwargs` and loads all keys that are *valid*
(i.e., are included in the `valid_keys` list) into the dictionary`
instance variable.
@@ -61,7 +61,7 @@ class JsonEncodedObject:
self.unrecognized_keys[k] = v
def get_dict(self):
- """ Converts this protocol unit into a :py:class:`dict`, ready to be
+ """Converts this protocol unit into a :py:class:`dict`, ready to be
encoded as JSON.
The keys it encodes are: `valid_keys` - `internal_keys`
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 0000000000..12cdd53327
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+ and returns the extracted value.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ MacaroonVerificationFailedException: if there are conflicting values for the
+ caveat in the macaroon, or if the caveat was not found in the macaroon.
+ """
+ prefix = key + " = "
+ result = None # type: Optional[str]
+ for caveat in macaroon.caveats:
+ if not caveat.caveat_id.startswith(prefix):
+ continue
+
+ val = caveat.caveat_id[len(prefix) :]
+
+ if result is None:
+ # first time we found this caveat: record the value
+ result = val
+ elif val != result:
+ # on subsequent occurrences, raise if the value is different.
+ raise MacaroonVerificationFailedException(
+ "Conflicting values for caveat " + key
+ )
+
+ if result is not None:
+ return result
+
+ # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+ # Note that it is insecure to generate a macaroon without all the caveats you
+ # might need (because there is nothing stopping people from adding extra caveats),
+ # so if the caveat isn't there, something odd must be going on.
+ raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+ """Make a macaroon verifier which accepts 'time' caveats
+
+ Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+ the given macaroon verifier.
+
+ Args:
+ v: the macaroon verifier
+ get_time_ms: a callable which will return the timestamp after which the caveat
+ should be considered expired. Normally the current time.
+ """
+
+ def verify_expiry_caveat(caveat: str):
+ time_msec = get_time_ms()
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ return time_msec < expiry
+
+ v.satisfy_general(verify_expiry_caveat)
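
For illustration, here is how the new caveat helper behaves on a macaroon built with pymacaroons (the location/identifier/key values are arbitrary):

    import pymacaroons

    from synapse.util.macaroons import get_value_from_macaroon

    macaroon = pymacaroons.Macaroon(
        location="example.com", identifier="key-id", key="a secret"
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = @alice:example.com")

    assert get_value_from_macaroon(macaroon, "user_id") == "@alice:example.com"
    # A caveat that is absent, or present with conflicting values, raises
    # MacaroonVerificationFailedException rather than returning a default.
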
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ffdea0de8d..1023c856d1 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -108,7 +108,16 @@ class Measure:
def __init__(self, clock, name):
self.clock = clock
self.name = name
- parent_context = current_context()
+ curr_context = current_context()
+ if not curr_context:
+ logger.warning(
+ "Starting metrics collection %r from sentinel context: metrics will be lost",
+ name,
+ )
+ parent_context = None
+ else:
+ assert isinstance(curr_context, LoggingContext)
+ parent_context = curr_context
self._logging_context = LoggingContext(
"Measure[%s]" % (self.name,), parent_context
)
@@ -152,8 +161,7 @@ class Measure:
return self._logging_context.get_resource_usage()
def _update_in_flight(self, metrics):
- """Gets called when processing in flight metrics
- """
+ """Gets called when processing in flight metrics"""
duration = self.clock.time() - self.start
metrics.real_time_max = max(metrics.real_time_max, duration)
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 94b59afb38..d184e2a90c 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,28 +15,57 @@
import importlib
import importlib.util
+import itertools
+from typing import Any, Iterable, Tuple, Type
+
+import jsonschema
from synapse.config._base import ConfigError
+from synapse.config._util import json_error_to_config_error
+
+def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
+ """Loads a synapse module with its config
-def load_module(provider):
- """ Loads a synapse module with its config
- Take a dict with keys 'module' (the module name) and 'config'
- (the config dict).
+ Args:
+ provider: a dict with keys 'module' (the module name) and 'config'
+ (the config dict).
+ config_path: the path within the config file. This will be used as a basis
+ for any error message.
Returns
Tuple of (provider class, parsed config object)
"""
+
+ modulename = provider.get("module")
+ if not isinstance(modulename, str):
+ raise ConfigError(
+ "expected a string", path=itertools.chain(config_path, ("module",))
+ )
+
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = provider["module"].rsplit(".", 1)
+ module, clz = modulename.rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
+ # Load the module config. If None, pass an empty dictionary instead
+ module_config = provider.get("config") or {}
try:
- provider_config = provider_class.parse_config(provider.get("config"))
+ provider_config = provider_class.parse_config(module_config)
+ except jsonschema.ValidationError as e:
+ raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
+ except ConfigError as e:
+ raise _wrap_config_error(
+ "Failed to parse config for module %r" % (modulename,),
+ prefix=itertools.chain(config_path, ("config",)),
+ e=e,
+ )
except Exception as e:
- raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
+ raise ConfigError(
+ "Failed to parse config for module %r" % (modulename,),
+ path=itertools.chain(config_path, ("config",)),
+ ) from e
return provider_class, provider_config
@@ -56,3 +85,27 @@ def load_python_module(location: str):
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
+
+
+def _wrap_config_error(
+ msg: str, prefix: Iterable[str], e: ConfigError
+) -> "ConfigError":
+ """Wrap a relative ConfigError with a new path
+
+ This is useful when we have a ConfigError with a relative path due to a problem
+ parsing part of the config, and we now need to set it in context.
+ """
+ path = prefix
+ if e.path:
+ path = itertools.chain(prefix, e.path)
+
+ e1 = ConfigError(msg, path)
+
+ # ideally we would set the 'cause' of the new exception to the original exception;
+ # however now that we have merged the path into our own, the stringification of
+ # e will be incorrect, so instead we create a new exception with just the "msg"
+ # part.
+
+ e1.__cause__ = Exception(e.msg)
+ e1.__cause__.__cause__ = e.__cause__
+ return e1
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 72574d3af2..d9f9ae99d6 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -204,16 +204,13 @@ def _check_yield_points(f: Callable, changes: List[str]):
# We don't raise here as its perfectly valid for contexts to
# change in a function, as long as it sets the correct context
# on resolving (which is checked separately).
- err = (
- "%s changed context from %s to %s, happened between lines %d and %d in %s"
- % (
- frame.f_code.co_name,
- expected_context,
- current_context(),
- last_yield_line_no,
- frame.f_lineno,
- frame.f_code.co_filename,
- )
+ err = "%s changed context from %s to %s, happened between lines %d and %d in %s" % (
+ frame.f_code.co_name,
+ expected_context,
+ current_context(),
+ last_yield_line_no,
+ frame.f_lineno,
+ frame.f_code.co_filename,
)
changes.append(err)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 61d96a6c28..9ce7873ab5 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -18,13 +18,23 @@ import random
import re
import string
from collections.abc import Iterable
+from typing import Optional, Tuple
from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
+
+# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
+# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
+# says "there is no grammar for media ids"
+#
+# The server_name part of this is purposely lax: use parse_and_validate_mxc for
+# additional validation.
+#
+MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
@@ -32,33 +42,118 @@ client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
rand = random.SystemRandom()
-def random_string(length):
+def random_string(length: int) -> str:
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
-def random_string_with_symbols(length):
+def random_string_with_symbols(length: int) -> str:
return "".join(rand.choice(_string_with_symbols) for _ in range(length))
-def is_ascii(s):
- if isinstance(s, bytes):
- try:
- s.decode("ascii").encode("ascii")
- except UnicodeDecodeError:
- return False
- except UnicodeEncodeError:
- return False
- return True
+def is_ascii(s: bytes) -> bool:
+ try:
+ s.decode("ascii").encode("ascii")
+ except UnicodeDecodeError:
+ return False
+ except UnicodeEncodeError:
+ return False
+ return True
-def assert_valid_client_secret(client_secret):
- """Validate that a given string matches the client_secret regex defined by the spec"""
- if client_secret_regex.match(client_secret) is None:
+def assert_valid_client_secret(client_secret: str) -> None:
+ """Validate that a given string matches the client_secret format defined by the spec"""
+ if (
+ len(client_secret) <= 0
+ or len(client_secret) > 255
+ or CLIENT_SECRET_REGEX.match(client_secret) is None
+ ):
raise SynapseError(
400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
)
+def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+ """Split a server name into host/port parts.
+
+ Args:
+ server_name: server name to parse
+
+ Returns:
+ host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ try:
+ if server_name[-1] == "]":
+ # ipv6 literal, hopefully
+ return server_name, None
+
+ domain_port = server_name.rsplit(":", 1)
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+ return domain, port
+ except Exception:
+ raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
+
+
+def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+ """Split a server name into host/port parts and do some basic validation.
+
+ Args:
+ server_name: server name to parse
+
+ Returns:
+ host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ host, port = parse_server_name(server_name)
+
+ # these tests don't need to be bulletproof as we'll find out soon enough
+ # if somebody is giving us invalid data. What we *do* need is to be sure
+ # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+ # look for ipv6 literals
+ if host[0] == "[":
+ if host[-1] != "]":
+ raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
+ return host, port
+
+ # otherwise it should only be alphanumerics.
+ if not VALID_HOST_REGEX.match(host):
+ raise ValueError(
+ "Server name '%s' contains invalid characters" % (server_name,)
+ )
+
+ return host, port
+
+
+def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
+ """Parse the given string as an MXC URI
+
+ Checks that the "server name" part is a valid server name
+
+ Args:
+ mxc: the (alleged) MXC URI to be checked
+ Returns:
+ hostname, port, media id
+ Raises:
+ ValueError if the URI cannot be parsed
+ """
+ m = MXC_REGEX.match(mxc)
+ if not m:
+ raise ValueError("mxc URI %r did not match expected format" % (mxc,))
+ server_name = m.group(1)
+ media_id = m.group(2)
+ host, port = parse_and_validate_server_name(server_name)
+ return host, port, media_id
+
+
def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
"""If iterable has maxitems or fewer, return the stringification of a list
containing those items.
@@ -75,3 +170,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
if len(items) <= maxitems:
return str(items)
return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+
+
+def strtobool(val: str) -> bool:
+ """Convert a string representation of truth to True or False
+
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+ are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
+ 'val' is anything else.
+
+ This is lifted from distutils.util.strtobool, with the exception that it actually
+ returns a bool, rather than an int.
+ """
+ val = val.lower()
+ if val in ("y", "yes", "t", "true", "on", "1"):
+ return True
+ elif val in ("n", "no", "f", "false", "off", "0"):
+ return False
+ else:
+ raise ValueError("invalid truth value %r" % (val,))
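
A couple of illustrative calls to the new parsing helpers (the values are arbitrary):

    from synapse.util.stringutils import parse_and_validate_mxc_uri, strtobool

    host, port, media_id = parse_and_validate_mxc_uri("mxc://example.com:8448/abc123")
    assert (host, port, media_id) == ("example.com", 8448, "abc123")

    assert strtobool("yes") is True
    assert strtobool("off") is False
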
diff --git a/synapse/util/templates.py b/synapse/util/templates.py
new file mode 100644
index 0000000000..392dae4a40
--- /dev/null
+++ b/synapse/util/templates.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for dealing with jinja2 templates"""
+
+import time
+import urllib.parse
+from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
+
+import jinja2
+
+if TYPE_CHECKING:
+ from synapse.config.homeserver import HomeServerConfig
+
+
+def build_jinja_env(
+ template_search_directories: Iterable[str],
+ config: "HomeServerConfig",
+ autoescape: Union[bool, Callable[[str], bool], None] = None,
+) -> jinja2.Environment:
+ """Set up a Jinja2 environment to load templates from the given search path
+
+ The returned environment defines the following filters:
+ - format_ts: formats timestamps as strings in the server's local timezone
+ (XXX: why is that useful??)
+ - mxc_to_http: converts mxc: uris to http URIs. Args are:
+ (uri, width, height, resize_method="crop")
+
+ and the following global variables:
+ - server_name: matrix server name
+
+ Args:
+ template_search_directories: directories to search for templates
+
+ config: homeserver config, for things like `server_name` and `public_baseurl`
+
+ autoescape: whether template variables should be autoescaped. bool, or
+ a function mapping from template name to bool. Defaults to escaping templates
+ whose names end in .html, .xml or .htm.
+
+ Returns:
+ jinja environment
+ """
+
+ if autoescape is None:
+ autoescape = jinja2.select_autoescape()
+
+ loader = jinja2.FileSystemLoader(template_search_directories)
+ env = jinja2.Environment(loader=loader, autoescape=autoescape)
+
+ # Update the environment with our custom filters
+ env.filters.update(
+ {
+ "format_ts": _format_ts_filter,
+ "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl),
+ }
+ )
+
+ # common variables for all templates
+ env.globals.update({"server_name": config.server_name})
+
+ return env
+
+
+def _create_mxc_to_http_filter(
+ public_baseurl: Optional[str],
+) -> Callable[[str, int, int, str], str]:
+ """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+ Args:
+ public_baseurl: The public, accessible base URL of the homeserver
+ """
+
+ def mxc_to_http_filter(
+ value: str, width: int, height: int, resize_method: str = "crop"
+ ) -> str:
+ if not public_baseurl:
+ raise RuntimeError(
+ "public_baseurl must be set in the homeserver config to convert MXC URLs to HTTP URLs."
+ )
+
+ if value[0:6] != "mxc://":
+ return ""
+
+ server_and_media_id = value[6:]
+ fragment = None
+ if "#" in server_and_media_id:
+ server_and_media_id, fragment = server_and_media_id.split("#", 1)
+ fragment = "#" + fragment
+
+ params = {"width": width, "height": height, "method": resize_method}
+ return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+ public_baseurl,
+ server_and_media_id,
+ urllib.parse.urlencode(params),
+ fragment or "",
+ )
+
+ return mxc_to_http_filter
+
+
+def _format_ts_filter(value: int, format: str):
+ return time.strftime(format, time.localtime(value / 1000))
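
To show what the mxc_to_http filter produces, a small sketch calling the helper directly (assuming a public_baseurl ending in a trailing slash, since the filter concatenates the path onto it verbatim):

    from synapse.util.templates import _create_mxc_to_http_filter

    mxc_to_http = _create_mxc_to_http_filter("https://matrix.example.com/")
    print(mxc_to_http("mxc://example.com/abcdef", 64, 64))
    # https://matrix.example.com/_matrix/media/v1/thumbnail/example.com/abcdef?width=64&height=64&method=crop
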
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 527365498e..ff53a49b3a 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -12,20 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-import operator
+from typing import Dict, FrozenSet, List, Optional
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import (
+ AccountDataTypes,
+ EventTypes,
+ HistoryVisibility,
+ Membership,
+)
+from synapse.events import EventBase
from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.state import StateFilter
-from synapse.types import get_domain_from_id
+from synapse.types import StateMap, get_domain_from_id
logger = logging.getLogger(__name__)
-VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")
+VISIBILITY_PRIORITY = (
+ HistoryVisibility.WORLD_READABLE,
+ HistoryVisibility.SHARED,
+ HistoryVisibility.INVITED,
+ HistoryVisibility.JOINED,
+)
MEMBERSHIP_PRIORITY = (
@@ -39,38 +49,39 @@ MEMBERSHIP_PRIORITY = (
async def filter_events_for_client(
storage: Storage,
- user_id,
- events,
- is_peeking=False,
- always_include_ids=frozenset(),
- filter_send_to_client=True,
-):
+ user_id: str,
+ events: List[EventBase],
+ is_peeking: bool = False,
+ always_include_ids: FrozenSet[str] = frozenset(),
+ filter_send_to_client: bool = True,
+) -> List[EventBase]:
"""
Check which events a user is allowed to see. If the user can see the event but its
sender asked for their data to be erased, prune the content of the event.
Args:
storage
- user_id(str): user id to be checked
- events(list[synapse.events.EventBase]): sequence of events to be checked
- is_peeking(bool): should be True if:
+ user_id: user id to be checked
+ events: sequence of events to be checked
+ is_peeking: should be True if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
- always_include_ids (set(event_id)): set of event ids to specifically
+ always_include_ids: set of event ids to specifically
include (unless sender is ignored)
- filter_send_to_client (bool): Whether we're checking an event that's going to be
+ filter_send_to_client: Whether we're checking an event that's going to be
sent to a client. This might not always be the case since this function can
also be called to check whether a user can see the state at a given point.
Returns:
- list[synapse.events.EventBase]
+ The filtered events.
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
+
event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
@@ -80,7 +91,7 @@ async def filter_events_for_client(
AccountDataTypes.IGNORED_USER_LIST, user_id
)
- ignore_list = frozenset()
+ ignore_list = frozenset() # type: FrozenSet[str]
if ignore_dict_content:
ignored_users_dict = ignore_dict_content.get("ignored_users", {})
if isinstance(ignored_users_dict, dict):
@@ -97,26 +108,25 @@ async def filter_events_for_client(
room_id
] = await storage.main.get_retention_policy_for_room(room_id)
- def allowed(event):
+ def allowed(event: EventBase) -> Optional[EventBase]:
"""
Args:
- event (synapse.events.EventBase): event to check
+ event: event to check
Returns:
- None|EventBase:
- None if the user cannot see this event at all
+ None if the user cannot see this event at all
- a redacted copy of the event if they can only see a redacted
- version
+ a redacted copy of the event if they can only see a redacted
+ version
- the original event if they can see it as normal.
+ the original event if they can see it as normal.
"""
# Only run some checks if these events aren't about to be sent to clients. This is
# because, if this is not the case, we're probably only checking if the users can
# see events in the room at that point in the DAG, and that shouldn't be decided
# on those checks.
if filter_send_to_client:
- if event.type == "org.matrix.dummy_event":
+ if event.type == EventTypes.Dummy:
return None
if not event.is_state() and event.sender in ignore_list:
@@ -150,12 +160,14 @@ async def filter_events_for_client(
# get the room_visibility at the time of the event.
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
- visibility = visibility_event.content.get("history_visibility", "shared")
+ visibility = visibility_event.content.get(
+ "history_visibility", HistoryVisibility.SHARED
+ )
else:
- visibility = "shared"
+ visibility = HistoryVisibility.SHARED
if visibility not in VISIBILITY_PRIORITY:
- visibility = "shared"
+ visibility = HistoryVisibility.SHARED
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
@@ -165,7 +177,7 @@ async def filter_events_for_client(
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
- prev_visibility = "shared"
+ prev_visibility = HistoryVisibility.SHARED
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
@@ -210,19 +222,19 @@ async def filter_events_for_client(
# otherwise, it depends on the room visibility.
- if visibility == "joined":
+ if visibility == HistoryVisibility.JOINED:
# we weren't a member at the time of the event, so we can't
# see this event.
return None
- elif visibility == "invited":
+ elif visibility == HistoryVisibility.INVITED:
# user can also see the event if they were *invited* at the time
# of the event.
return event if membership == Membership.INVITE else None
- elif visibility == "shared" and is_peeking:
+ elif visibility == HistoryVisibility.SHARED and is_peeking:
# if the visibility is shared, users cannot see the event unless
- # they have *subequently* joined the room (or were members at the
+ # they have *subsequently* joined the room (or were members at the
# time, of course)
#
# XXX: if the user has subsequently joined and then left again,
@@ -240,52 +252,52 @@ async def filter_events_for_client(
return event
- # check each event: gives an iterable[None|EventBase]
+ # Check each event: gives an iterable of None or (a potentially modified)
+ # EventBase.
filtered_events = map(allowed, events)
- # remove the None entries
- filtered_events = filter(operator.truth, filtered_events)
-
- # we turn it into a list before returning it.
- return list(filtered_events)
+ # Turn it into a list and remove None entries before returning.
+ return [ev for ev in filtered_events if ev]
async def filter_events_for_server(
storage: Storage,
- server_name,
- events,
- redact=True,
- check_history_visibility_only=False,
-):
+ server_name: str,
+ events: List[EventBase],
+ redact: bool = True,
+ check_history_visibility_only: bool = False,
+) -> List[EventBase]:
"""Filter a list of events based on whether given server is allowed to
see them.
Args:
storage
- server_name (str)
- events (iterable[FrozenEvent])
- redact (bool): Whether to return a redacted version of the event, or
+ server_name
+ events
+ redact: Whether to return a redacted version of the event, or
to filter them out entirely.
- check_history_visibility_only (bool): Whether to only check the
+ check_history_visibility_only: Whether to only check the
history visibility, rather than things like if the sender has been
erased. This is used e.g. during pagination to decide whether to
backfill or not.
Returns
- list[FrozenEvent]
+ The filtered events.
"""
- def is_sender_erased(event, erased_senders):
+ def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
if erased_senders and erased_senders[event.sender]:
logger.info("Sender of %s has been erased, redacting", event.event_id)
return True
return False
- def check_event_is_visible(event, state):
+ def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool:
history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if history:
- visibility = history.content.get("history_visibility", "shared")
- if visibility in ["invited", "joined"]:
+ visibility = history.content.get(
+ "history_visibility", HistoryVisibility.SHARED
+ )
+ if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
@@ -305,7 +317,7 @@ async def filter_events_for_server(
if memtype == Membership.JOIN:
return True
elif memtype == Membership.INVITE:
- if visibility == "invited":
+ if visibility == HistoryVisibility.INVITED:
return True
else:
# server has no users in the room: redact
@@ -336,7 +348,8 @@ async def filter_events_for_server(
else:
event_map = await storage.main.get_events(visibility_ids)
all_open = all(
- e.content.get("history_visibility") in (None, "shared", "world_readable")
+ e.content.get("history_visibility")
+ in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
for e in event_map.values()
)
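
The "least restrictive wins" rule for history-visibility events on boundaries boils down to comparing indices in VISIBILITY_PRIORITY; a standalone sketch of just that comparison (not the full filter, which also needs the storage layer):

    # lower index = less restrictive
    VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")

    def effective_visibility(prev_visibility: str, visibility: str) -> str:
        old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
        new_priority = VISIBILITY_PRIORITY.index(visibility)
        return prev_visibility if old_priority < new_priority else visibility

    assert effective_visibility("shared", "joined") == "shared"
    assert effective_visibility("joined", "invited") == "invited"
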
diff --git a/synctl b/synctl
index cfa9cec0c4..56c0e3940f 100755
--- a/synctl
+++ b/synctl
@@ -30,7 +30,7 @@ import yaml
from synapse.config import find_config_files
-SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
+SYNAPSE = [sys.executable, "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m"
YELLOW = "\x1b[1;33m"
@@ -117,7 +117,6 @@ def start_worker(app: str, configfile: str, worker_configfile: str) -> bool:
args = [
sys.executable,
- "-B",
"-m",
app,
"-c",
diff --git a/synmark/__main__.py b/synmark/__main__.py
index de13c1a909..f55968a5a4 100644
--- a/synmark/__main__.py
+++ b/synmark/__main__.py
@@ -96,5 +96,6 @@ if __name__ == "__main__":
runner.args.loops = orig_loops
loops = "auto"
runner.bench_time_func(
- suite.__name__ + "_" + str(loops), make_test(suite.main),
+ suite.__name__ + "_" + str(loops),
+ make_test(suite.main),
)
diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py
index c9d9cf761e..c306891b27 100644
--- a/synmark/suites/logging.py
+++ b/synmark/suites/logging.py
@@ -98,7 +98,9 @@ async def main(reactor, loops):
logger = logging.getLogger("synapse.logging.test_terse_json")
_setup_stdlib_logging(
- hs_config, log_config, logBeginner=beginner,
+ hs_config,
+ log_config,
+ logBeginner=beginner,
)
# Wait for it to connect...
diff --git a/test_postgresql.sh b/test_postgresql.sh
index 1ffcaabd31..c10828fbbc 100755
--- a/test_postgresql.sh
+++ b/test_postgresql.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
# This script builds the Docker image to run the PostgreSQL tests, and then runs
# the tests.
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index ee5217b074..34f72ae795 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -17,8 +17,6 @@ from mock import Mock
import pymacaroons
-from twisted.internet import defer
-
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -33,19 +31,17 @@ from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
-from tests.utils import mock_getRawHeaders, setup_test_homeserver
+from tests.test_utils import simple_async_mock
+from tests.utils import mock_getRawHeaders
-class AuthTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.state_handler = Mock()
+class AuthTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = Mock()
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.hs.get_datastore = Mock(return_value=self.store)
- self.hs.get_auth_handler().store = self.store
- self.auth = Auth(self.hs)
+ hs.get_datastore = Mock(return_value=self.store)
+ hs.get_auth_handler().store = self.store
+ self.auth = Auth(hs)
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
@@ -57,64 +53,59 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
- self.store.is_support_user = Mock(return_value=defer.succeed(False))
+ self.store.insert_client_ip = simple_async_mock(None)
+ self.store.is_support_user = simple_async_mock(False)
- @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
- self.store.get_user_by_access_token = Mock(
- return_value=defer.succeed(user_info)
- )
+ self.store.get_user_by_access_token = simple_async_mock(user_info)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+ requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, InvalidClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), InvalidClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
- self.store.get_user_by_access_token = Mock(
- return_value=defer.succeed(user_info)
- )
+ self.store.get_user_by_access_token = simple_async_mock(user_info)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, MissingClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), MissingClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
- @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+ requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
- @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_good_ip(self):
from netaddr import IPSet
@@ -125,13 +116,13 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+ requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@@ -144,42 +135,44 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, InvalidClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), InvalidClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, InvalidClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), InvalidClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, MissingClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), MissingClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
- @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
@@ -188,17 +181,15 @@ class AuthTestCase(unittest.TestCase):
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
- self.store.get_user_by_id = Mock(
- return_value=defer.succeed({"is_guest": False})
- )
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+ requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
@@ -210,22 +201,18 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- self.failureResultOf(d, AuthError)
+ self.get_failure(self.auth.get_user_by_req(request), AuthError)
- @defer.inlineCallbacks
def test_get_user_from_macaroon(self):
- self.store.get_user_by_access_token = Mock(
- return_value=defer.succeed(
- TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
- )
+ self.store.get_user_by_access_token = simple_async_mock(
+ TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
user_id = "@baldrick:matrix.org"
@@ -237,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
- user_info = yield defer.ensureDeferred(
+ user_info = self.get_success(
self.auth.get_user_by_access_token(macaroon.serialize())
)
self.assertEqual(user_id, user_info.user_id)
@@ -246,10 +233,9 @@ class AuthTestCase(unittest.TestCase):
# from the db.
self.assertEqual(user_info.device_id, "device")
- @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
- self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_id = simple_async_mock({"is_guest": True})
+ self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -263,20 +249,17 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
- user_info = yield defer.ensureDeferred(
- self.auth.get_user_by_access_token(serialized)
- )
+ user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
self.assertEqual(user_id, user_info.user_id)
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
- @defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
- self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
- self.store.get_device = Mock(return_value=defer.succeed(None))
+ self.store.add_access_token_to_user = simple_async_mock(None)
+ self.store.get_device = simple_async_mock(None)
- token = yield defer.ensureDeferred(
+ token = self.get_success(
self.hs.get_auth_handler().get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
@@ -289,25 +272,24 @@ class AuthTestCase(unittest.TestCase):
puppets_user_id=None,
)
- def get_user(tok):
+ async def get_user(tok):
if token != tok:
- return defer.succeed(None)
- return defer.succeed(
- TokenLookupResult(
- user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
- )
+ return None
+ return TokenLookupResult(
+ user_id=USER_ID,
+ is_guest=False,
+ token_id=1234,
+ device_id="DEVICE",
)
self.store.get_user_by_access_token = get_user
- self.store.get_user_by_id = Mock(
- return_value=defer.succeed({"is_guest": False})
- )
+ self.store.get_user_by_id = simple_async_mock({"is_guest": False})
# check the token works
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(
+ requester = self.get_success(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(UserID.from_string(USER_ID), requester.user)
@@ -323,17 +305,16 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [guest_tok.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- with self.assertRaises(InvalidClientCredentialsError) as cm:
- yield defer.ensureDeferred(
- self.auth.get_user_by_req(request, allow_guest=True)
- )
+ cm = self.get_failure(
+ self.auth.get_user_by_req(request, allow_guest=True),
+ InvalidClientCredentialsError,
+ )
- self.assertEqual(401, cm.exception.code)
- self.assertEqual("Guest access token used for regular user", cm.exception.msg)
+ self.assertEqual(401, cm.value.code)
+ self.assertEqual("Guest access token used for regular user", cm.value.msg)
self.store.get_user_by_id.assert_called_with(USER_ID)
- @defer.inlineCallbacks
def test_blocking_mau(self):
self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50
@@ -341,77 +322,61 @@ class AuthTestCase(unittest.TestCase):
small_number_of_users = 1
# Ensure no error thrown
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
+ self.get_success(self.auth.check_auth_blocking())
self.auth_blocking._limit_usage_by_mau = True
- self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(lots_of_users)
- )
+ self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
- with self.assertRaises(ResourceLimitError) as e:
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.assertEquals(e.exception.code, 403)
+ e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ self.assertEquals(e.value.code, 403)
# Ensure does not throw an error
- self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(small_number_of_users)
- )
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
+ self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
+ self.get_success(self.auth.check_auth_blocking())
- @defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
- self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ self.store.get_monthly_active_count = simple_async_mock(100)
# Support users allowed
- yield defer.ensureDeferred(
- self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
- )
- self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
+ self.store.get_monthly_active_count = simple_async_mock(100)
# Bots not allowed
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth.check_auth_blocking(user_type=UserTypes.BOT)
- )
- self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ self.get_failure(
+ self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
+ )
+ self.store.get_monthly_active_count = simple_async_mock(100)
# Real users not allowed
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
+ self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- @defer.inlineCallbacks
def test_reserved_threepid(self):
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
- self.store.get_monthly_active_count = lambda: defer.succeed(2)
+ self.store.get_monthly_active_count = simple_async_mock(2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
+ self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth.check_auth_blocking(threepid=unknown_threepid)
- )
+ self.get_failure(
+ self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
+ )
- yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
+ self.get_success(self.auth.check_auth_blocking(threepid=threepid))
- @defer.inlineCallbacks
def test_hs_disabled(self):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- with self.assertRaises(ResourceLimitError) as e:
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.assertEquals(e.exception.code, 403)
+ e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ self.assertEquals(e.value.code, 403)
- @defer.inlineCallbacks
def test_hs_disabled_no_server_notices_user(self):
"""Check that 'hs_disabled_message' works correctly when there is no
server_notices user.
@@ -422,16 +387,14 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- with self.assertRaises(ResourceLimitError) as e:
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.assertEquals(e.exception.code, 403)
+ e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ self.assertEquals(e.value.code, 403)
- @defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self):
self.auth_blocking._hs_disabled = True
user = "@user:server"
self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
+ self.get_success(self.auth.check_auth_blocking(user))
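Note: the hunks above repeatedly swap `Mock(return_value=defer.succeed(...))` for the `simple_async_mock` and `make_awaitable` helpers from `tests.test_utils`. Their definitions are not part of this diff; the sketch below is inferred only from how they are used here and may differ from the real helpers in tests/test_utils/__init__.py.

    from asyncio import Future
    from typing import Any, Awaitable, Optional

    from mock import Mock


    def make_awaitable(result: Any) -> Awaitable[Any]:
        # An already-resolved Future can be awaited more than once, so it is
        # safe to hand out as a Mock's return_value.
        future = Future()  # type: Future[Any]
        future.set_result(result)
        return future


    def simple_async_mock(
        return_value: Any = None, raises: Optional[Exception] = None
    ) -> Mock:
        # Roughly mimics unittest.mock.AsyncMock: each call to the mock returns
        # a fresh coroutine resolving to `return_value` (or raising `raises`).
        async def cb(*args, **kwargs):
            if raises:
                raise raises
            return return_value

        return Mock(side_effect=cb)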
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index bcf1a8010e..ab7d290724 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -16,19 +16,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
-
import jsonschema
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from tests import unittest
-from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
user_localpart = "test_user"
@@ -41,22 +36,9 @@ def MockEvent(**kwargs):
return make_event_from_dict(kwargs)
-class FilteringTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_federation_resource = MockHttpResource()
-
- self.mock_http_client = Mock(spec=[])
- self.mock_http_client.put_json = DeferredMockCallable()
-
- hs = yield setup_test_homeserver(
- self.addCleanup,
- federation_http_client=self.mock_http_client,
- keyring=Mock(),
- )
-
+class FilteringTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.filtering = hs.get_filtering()
-
self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
@@ -365,10 +347,9 @@ class FilteringTestCase(unittest.TestCase):
self.assertTrue(Filter(definition).check(event))
- @defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@@ -376,7 +357,7 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
- user_filter = yield defer.ensureDeferred(
+ user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@@ -385,11 +366,10 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_presence(events=events)
self.assertEquals(events, results)
- @defer.inlineCallbacks
def test_filter_presence_no_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json
)
@@ -401,7 +381,7 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield defer.ensureDeferred(
+ user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id
)
@@ -410,10 +390,9 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_presence(events=events)
self.assertEquals([], results)
- @defer.inlineCallbacks
def test_filter_room_state_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@@ -421,7 +400,7 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
- user_filter = yield defer.ensureDeferred(
+ user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@@ -430,10 +409,9 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_room_state(events=events)
self.assertEquals(events, results)
- @defer.inlineCallbacks
def test_filter_room_state_no_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@@ -443,7 +421,7 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield defer.ensureDeferred(
+ user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@@ -468,11 +446,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
- @defer.inlineCallbacks
def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.filtering.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@@ -482,7 +459,7 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(
user_filter_json,
(
- yield defer.ensureDeferred(
+ self.get_success(
self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0
)
@@ -490,17 +467,16 @@ class FilteringTestCase(unittest.TestCase):
),
)
- @defer.inlineCallbacks
def test_get_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
)
- filter = yield defer.ensureDeferred(
+ filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index fe504d0869..483418192c 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,11 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
+ None,
+ "example.com",
+ id="foo",
+ rate_limited=True,
+ sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -68,7 +72,11 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
+ None,
+ "example.com",
+ id="foo",
+ rate_limited=False,
+ sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -113,12 +121,18 @@ class TestRatelimiter(unittest.TestCase):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
# First attempt should be allowed
- allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,)
+ allowed, time_allowed = limiter.can_do_action(
+ ("test_id",),
+ _time_now_s=0,
+ )
self.assertTrue(allowed)
self.assertEqual(10.0, time_allowed)
# Second attempt, 1s later, will fail
- allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,)
+ allowed, time_allowed = limiter.can_do_action(
+ ("test_id",),
+ _time_now_s=1,
+ )
self.assertFalse(allowed)
self.assertEqual(10.0, time_allowed)
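For reference, the numbers asserted in this hunk follow directly from the limiter's parameters: with burst_count=1 and rate_hz=0.1 the bucket holds a single action and refills at one action per 10 seconds, so after the action at _time_now_s=0 the next attempt is not permitted before 0 + 1 / 0.1 = 10.0. That is why both calls report time_allowed == 10.0 and the second attempt at t=1 is rejected.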
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 43fef5d64a..e0ca288829 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -57,7 +57,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+ channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
@@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+ channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index b260ab734d..467033e201 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -73,7 +73,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
return
raise
- _, channel = make_request(
+ channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
@@ -121,7 +121,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
return
raise
- _, channel = make_request(
+ channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index d3ec24c975..2b7f09c14b 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -127,8 +127,7 @@ class CacheConfigTests(TestCase):
self.assertEqual(cache.max_size, 150)
def test_cache_with_asterisk_in_name(self):
- """Some caches have asterisks in their name, test that they are set correctly.
- """
+ """Some caches have asterisks in their name, test that they are set correctly."""
config = {
"caches": {
@@ -164,7 +163,8 @@ class CacheConfigTests(TestCase):
t.read_config(config, config_dir_path="", data_dir_path="")
cache = LruCache(
- max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
+ max_size=t.caches.event_cache_size,
+ apply_cache_factor_from_config=False,
)
add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index a10d017120..98af7aa675 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -15,7 +15,8 @@
import yaml
-from synapse.config.server import ServerConfig, is_threepid_reserved
+from synapse.config._base import ConfigError
+from synapse.config.server import ServerConfig, generate_ip_set, is_threepid_reserved
from tests import unittest
@@ -128,3 +129,61 @@ class ServerConfigTestCase(unittest.TestCase):
)
self.assertEqual(conf["listeners"], expected_listeners)
+
+
+class GenerateIpSetTestCase(unittest.TestCase):
+ def test_empty(self):
+ ip_set = generate_ip_set(())
+ self.assertFalse(ip_set)
+
+ ip_set = generate_ip_set((), ())
+ self.assertFalse(ip_set)
+
+ def test_generate(self):
+ """Check adding IPv4 and IPv6 addresses."""
+ # IPv4 address
+ ip_set = generate_ip_set(("1.2.3.4",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ # IPv4 CIDR
+ ip_set = generate_ip_set(("1.2.3.4/24",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ # IPv6 address
+ ip_set = generate_ip_set(("2001:db8::8a2e:370:7334",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+ # IPv6 CIDR
+ ip_set = generate_ip_set(("2001:db8::/104",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+ # The addresses can overlap OK.
+ ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ def test_extra(self):
+ """Extra IP addresses are treated the same."""
+ ip_set = generate_ip_set((), ("1.2.3.4",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ ip_set = generate_ip_set(("1.1.1.1",), ("1.2.3.4",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 8)
+
+ # They can duplicate without error.
+ ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ def test_bad_value(self):
+ """An error should be raised if a bad value is passed in."""
+ with self.assertRaises(ConfigError):
+ generate_ip_set(("not-an-ip",))
+
+ with self.assertRaises(ConfigError):
+ generate_ip_set(("1.2.3.4/128",))
+
+ with self.assertRaises(ConfigError):
+ generate_ip_set((":::",))
+
+ # The following get treated as empty data.
+ self.assertFalse(generate_ip_set(None))
+ self.assertFalse(generate_ip_set({}))
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
new file mode 100644
index 0000000000..10363e3765
--- /dev/null
+++ b/tests/config/test_util.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config import ConfigError
+from synapse.config._util import validate_config
+
+from tests.unittest import TestCase
+
+
+class ValidateConfigTestCase(TestCase):
+ """Test cases for synapse.config._util.validate_config"""
+
+ def test_bad_object_in_array(self):
+ """malformed objects within an array should be validated correctly"""
+
+ # consider a structure:
+ #
+ # array_of_objs:
+ # - r: 1
+ # foo: 2
+ #
+ # - r: 2
+ # bar: 3
+ #
+ # ... where each entry must contain an "r": check that the path
+        # to the required item is correctly reported.
+
+ schema = {
+ "type": "object",
+ "properties": {
+ "array_of_objs": {
+ "type": "array",
+ "items": {"type": "object", "required": ["r"]},
+ },
+ },
+ }
+
+ with self.assertRaises(ConfigError) as c:
+ validate_config(schema, {"array_of_objs": [{}]}, ("base",))
+
+ self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
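The new test case relies on validate_config turning jsonschema validation failures into a ConfigError whose path names the offending item. The helper itself is not shown in this diff; the following is only a rough sketch of the behaviour the test expects, with the path-building details assumed rather than taken from synapse/config/_util.py.

    import jsonschema

    from synapse.config import ConfigError


    def validate_config(json_schema, config, config_path):
        # Validate `config` against `json_schema`, reporting any failure as a
        # ConfigError rooted at `config_path` (e.g. ("base",)).
        try:
            jsonschema.validate(config, json_schema)
        except jsonschema.ValidationError as e:
            # Render integer list indices as "<item N>" so the reported path
            # reads like ["base", "array_of_objs", "<item 0>"].
            path = list(config_path)
            for p in e.absolute_path:
                path.append("<item %i>" % p if isinstance(p, int) else str(p))
            raise ConfigError(e.message, path)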
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index d146f2254f..30fcc4c1bf 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock()
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
@@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
"""
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring(
@@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
}
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
@@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
}
- mock_fetcher1 = keyring.KeyFetcher()
+ mock_fetcher1 = Mock()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
- mock_fetcher2 = keyring.KeyFetcher()
+ mock_fetcher2 = Mock()
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
@@ -400,7 +400,10 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
)
def build_perspectives_response(
- self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
+ self,
+ server_name: str,
+ signing_key: SigningKey,
+ valid_until_ts: int,
) -> dict:
"""
Build a valid perspectives server response to a request for the given key
@@ -455,7 +458,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
VALID_UNTIL_TS = 200 * 1000
response = self.build_perspectives_response(
- SERVER_NAME, testkey, VALID_UNTIL_TS,
+ SERVER_NAME,
+ testkey,
+ VALID_UNTIL_TS,
)
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 3a80626224..ec85324c0c 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -43,7 +43,10 @@ class TestEventContext(unittest.HomeserverTestCase):
event, context = self.get_success(
create_event(
- self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
)
)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index c1274c14af..8ba36c6074 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -34,11 +34,17 @@ def MockEvent(**kwargs):
class PruneEventTestCase(unittest.TestCase):
- """ Asserts that a new event constructed with `evdict` will look like
- `matchdict` when it is redacted. """
-
def run_test(self, evdict, matchdict, **kwargs):
- self.assertEquals(
+ """
+ Asserts that a new event constructed with `evdict` will look like
+ `matchdict` when it is redacted.
+
+ Args:
+ evdict: The dictionary to build the event from.
+ matchdict: The expected resulting dictionary.
+ kwargs: Additional keyword arguments used to create the event.
+ """
+ self.assertEqual(
prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
)
@@ -55,54 +61,80 @@ class PruneEventTestCase(unittest.TestCase):
)
def test_basic_keys(self):
+ """Ensure that the keys that should be untouched are kept."""
+ # Note that some of the values below don't really make sense, but the
+ # pruning of events doesn't worry about the values of any fields (with
+ # the exception of the content field).
self.run_test(
{
+ "event_id": "$3:domain",
"type": "A",
"room_id": "!1:domain",
"sender": "@2:domain",
- "event_id": "$3:domain",
+ "state_key": "B",
+ "content": {"other_key": "foo"},
+ "hashes": "hashes",
+ "signatures": {"domain": {"algo:1": "sigs"}},
+ "depth": 4,
+ "prev_events": "prev_events",
+ "prev_state": "prev_state",
+ "auth_events": "auth_events",
"origin": "domain",
+ "origin_server_ts": 1234,
+ "membership": "join",
+ # Also include a key that should be removed.
+ "other_key": "foo",
},
{
+ "event_id": "$3:domain",
"type": "A",
"room_id": "!1:domain",
"sender": "@2:domain",
- "event_id": "$3:domain",
+ "state_key": "B",
+ "hashes": "hashes",
+ "depth": 4,
+ "prev_events": "prev_events",
+ "prev_state": "prev_state",
+ "auth_events": "auth_events",
"origin": "domain",
+ "origin_server_ts": 1234,
+ "membership": "join",
"content": {},
- "signatures": {},
+ "signatures": {"domain": {"algo:1": "sigs"}},
"unsigned": {},
},
)
- def test_unsigned_age_ts(self):
+        # As of MSC2176 we now redact the membership and prev_state keys.
self.run_test(
- {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}},
- {
- "type": "B",
- "event_id": "$test:domain",
- "content": {},
- "signatures": {},
- "unsigned": {"age_ts": 20},
- },
+ {"type": "A", "prev_state": "prev_state", "membership": "join"},
+ {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
+ room_version=RoomVersions.MSC2176,
)
+ def test_unsigned(self):
+ """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
self.run_test(
{
"type": "B",
"event_id": "$test:domain",
- "unsigned": {"other_key": "here"},
+ "unsigned": {
+ "age_ts": 20,
+ "replaces_state": "$test2:domain",
+ "other_key": "foo",
+ },
},
{
"type": "B",
"event_id": "$test:domain",
"content": {},
"signatures": {},
- "unsigned": {},
+ "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"},
},
)
def test_content(self):
+ """The content dictionary should be stripped in most cases."""
self.run_test(
{"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
{
@@ -114,11 +146,35 @@ class PruneEventTestCase(unittest.TestCase):
},
)
+ # Some events keep a single content key/value.
+ EVENT_KEEP_CONTENT_KEYS = [
+ ("member", "membership", "join"),
+ ("join_rules", "join_rule", "invite"),
+ ("history_visibility", "history_visibility", "shared"),
+ ]
+ for event_type, key, value in EVENT_KEEP_CONTENT_KEYS:
+ self.run_test(
+ {
+ "type": "m.room." + event_type,
+ "event_id": "$test:domain",
+ "content": {key: value, "other_key": "foo"},
+ },
+ {
+ "type": "m.room." + event_type,
+ "event_id": "$test:domain",
+ "content": {key: value},
+ "signatures": {},
+ "unsigned": {},
+ },
+ )
+
+ def test_create(self):
+ """Create events are partially redacted until MSC2176."""
self.run_test(
{
"type": "m.room.create",
"event_id": "$test:domain",
- "content": {"creator": "@2:domain", "other_field": "here"},
+ "content": {"creator": "@2:domain", "other_key": "foo"},
},
{
"type": "m.room.create",
@@ -129,6 +185,68 @@ class PruneEventTestCase(unittest.TestCase):
},
)
+ # After MSC2176, create events get nothing redacted.
+ self.run_test(
+ {"type": "m.room.create", "content": {"not_a_real_key": True}},
+ {
+ "type": "m.room.create",
+ "content": {"not_a_real_key": True},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.MSC2176,
+ )
+
+ def test_power_levels(self):
+ """Power level events keep a variety of content keys."""
+ self.run_test(
+ {
+ "type": "m.room.power_levels",
+ "event_id": "$test:domain",
+ "content": {
+ "ban": 1,
+ "events": {"m.room.name": 100},
+ "events_default": 2,
+ "invite": 3,
+ "kick": 4,
+ "redact": 5,
+ "state_default": 6,
+ "users": {"@admin:domain": 100},
+ "users_default": 7,
+ "other_key": 8,
+ },
+ },
+ {
+ "type": "m.room.power_levels",
+ "event_id": "$test:domain",
+ "content": {
+ "ban": 1,
+ "events": {"m.room.name": 100},
+ "events_default": 2,
+ # Note that invite is not here.
+ "kick": 4,
+ "redact": 5,
+ "state_default": 6,
+ "users": {"@admin:domain": 100},
+ "users_default": 7,
+ },
+ "signatures": {},
+ "unsigned": {},
+ },
+ )
+
+ # After MSC2176, power levels events keep the invite key.
+ self.run_test(
+ {"type": "m.room.power_levels", "content": {"invite": 75}},
+ {
+ "type": "m.room.power_levels",
+ "content": {"invite": 75},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.MSC2176,
+ )
+
def test_alias_event(self):
"""Alias events have special behavior up through room version 6."""
self.run_test(
@@ -146,8 +264,7 @@ class PruneEventTestCase(unittest.TestCase):
},
)
- def test_msc2432_alias_event(self):
- """After MSC2432, alias events have no special behavior."""
+ # After MSC2432, alias events have no special behavior.
self.run_test(
{"type": "m.room.aliases", "content": {"aliases": ["test"]}},
{
@@ -159,6 +276,32 @@ class PruneEventTestCase(unittest.TestCase):
room_version=RoomVersions.V6,
)
+ def test_redacts(self):
+        """Redaction events have no special behavior until MSC2174/MSC2176."""
+
+ self.run_test(
+ {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+ {
+ "type": "m.room.redaction",
+ "content": {},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.V6,
+ )
+
+ # After MSC2174, redaction events keep the redacts content key.
+ self.run_test(
+ {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+ {
+ "type": "m.room.redaction",
+ "content": {"redacts": "$test2:domain"},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.MSC2176,
+ )
+
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0187f56e21..8186b8ca01 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -48,7 +48,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
)
# Get the room complexity
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEquals(200, channel.code)
@@ -60,7 +60,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEquals(200, channel.code)
@@ -150,8 +150,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
)
# Artificially raise the complexity
- self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
- 600
+ self.hs.get_datastore().get_current_state_event_counts = (
+ lambda x: make_awaitable(600)
)
d = handler._remote_join(
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 1a3ccb263d..95eac6a5a3 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -2,11 +2,13 @@ from typing import List, Tuple
from mock import Mock
+from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
+from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -49,7 +51,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
else:
data = json_cb()
self.failed_pdus.extend(data["pdus"])
- raise IOError("Failed to connect because this is a test!")
+ raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination)
def get_destination_room(self, room: str, destination: str = "host2") -> dict:
"""
@@ -420,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertNotIn("zzzerver", woken)
# - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1])
+
+ @override_config({"send_federation": True})
+ def test_not_latest_event(self):
+        """Test that we send the latest event in the room even if it's not ours."""
+
+ per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+ # Make a room with a local user, and two servers. One will go offline
+ # and one will send some events.
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ event_1 = self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
+ )
+
+ # First we send something from the local server, so that we notice the
+ # remote is down and go into catchup mode.
+ self.helper.send(room_1, "you hear me!!", tok=u1_token)
+
+ # Now simulate us receiving an event from the still online remote.
+ event_2 = self.get_success(
+ event_injection.inject_event(
+ self.hs,
+ type=EventTypes.Message,
+ sender="@user:host3",
+ room_id=room_1,
+ content={"msgtype": "m.text", "body": "Hello"},
+ )
+ )
+
+ self.get_success(
+ self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+ "host2", event_1.internal_metadata.stream_ordering
+ )
+ )
+
+ self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+ # We expect only the last message from the remote, event_2, to have been
+ # sent, rather than the last *local* event that was sent.
+ self.assertEqual(len(sent_pdus), 1)
+ self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
+ self.assertFalse(per_dest_queue._catching_up)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 917762e6b6..ecc3faa572 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -279,7 +279,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
ret = self.get_success(
e2e_handler.upload_signatures_for_device_keys(
- u1, {u1: {"D1": d1_json, "D2": d2_json}},
+ u1,
+ {u1: {"D1": d1_json, "D2": d2_json}},
)
)
self.assertEqual(ret["failures"], {})
@@ -486,9 +487,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertGreaterEqual(content["stream_id"], prev_stream_id)
return content["stream_id"]
- def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
- """Check that the txn has an EDU with a signing key update.
- """
+ def check_signing_key_update_txn(
+ self,
+ txn: JsonDict,
+ ) -> None:
+ """Check that the txn has an EDU with a signing key update."""
edus = txn["edus"]
self.assertEqual(len(edus), 1)
@@ -502,7 +505,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.get_success(
self.hs.get_e2e_keys_handler().upload_keys_for_user(
- user_id, device_id, {"device_keys": device_dict},
+ user_id,
+ device_id,
+ {"device_keys": device_dict},
)
)
return sk
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 3009fbb6c4..cfeccc0577 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -46,7 +46,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/get_missing_events/(?P<room_id>[^/]*)/?"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
@@ -95,7 +95,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.inject_room_member(room_1, "@user:other.example.com", "join")
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.assertEquals(200, channel.code, channel.result)
@@ -127,7 +127,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/federation/transport/__init__.py
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 38dc0e5aa2..fbce19f38b 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -34,7 +34,10 @@ KNOCK_UNSTABLE_IDENTIFIER = "xyz.amorgan.knock"
class KnockingStrippedStateEventHelperMixin(TestCase):
def send_example_state_events_to_room(
- self, hs: "HomeServer", room_id: str, sender: str,
+ self,
+ hs: "HomeServer",
+ room_id: str,
+ sender: str,
) -> OrderedDict:
"""Adds some state to a room. State events are those that should be sent to a knocking
user after they knock on the room, as well as some state that *shouldn't* be sent
@@ -133,7 +136,9 @@ class KnockingStrippedStateEventHelperMixin(TestCase):
return room_state
def check_knock_room_state_against_room_state(
- self, knock_room_state: List[Dict], expected_room_state: Dict,
+ self,
+ knock_room_state: List[Dict],
+ expected_room_state: Dict,
) -> None:
"""Test a list of stripped room state events received over federation against a
dict of expected state events.
@@ -227,7 +232,7 @@ class FederationKnockingTestCase(
self.hs, room_id, user_id
)
- _, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/federation/unstable/%s/make_knock/%s/%s?ver=%s"
% (
@@ -271,7 +276,7 @@ class FederationKnockingTestCase(
)
# Send the signed knock event into the room
- _, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/federation/unstable/%s/send_knock/%s/%s"
% (KNOCK_UNSTABLE_IDENTIFIER, room_id, signed_knock_event.event_id),
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index f9e3c7a51f..85500e169c 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,38 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
from tests import unittest
from tests.unittest import override_config
-class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
- class Authenticator:
- def authenticate_request(self, request, content):
- return defer.succeed("otherserver.nottld")
-
- ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
- server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
-
+class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
@override_config({"allow_public_rooms_over_federation": False})
def test_blocked_public_room_list_over_federation(self):
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/publicRooms"
+        """Test that unauthenticated requests to the public rooms directory return a 403 when
+ allow_public_rooms_over_federation is False.
+ """
+ channel = self.make_request(
+ "GET",
+ "/_matrix/federation/v1/publicRooms",
+ federation_auth_origin=b"example.com",
)
self.assertEquals(403, channel.code)
@override_config({"allow_public_rooms_over_federation": True})
def test_open_public_room_list_over_federation(self):
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/publicRooms"
+        """Test that unauthenticated requests to the public rooms directory return a 200 when
+ allow_public_rooms_over_federation is True.
+ """
+ channel = self.make_request(
+ "GET",
+ "/_matrix/federation/v1/publicRooms",
+ federation_auth_origin=b"example.com",
)
self.assertEquals(200, channel.code)
diff --git a/tests/handlers/oidc_test_key.p8 b/tests/handlers/oidc_test_key.p8
new file mode 100644
index 0000000000..bb92976333
--- /dev/null
+++ b/tests/handlers/oidc_test_key.p8
@@ -0,0 +1,5 @@
+-----BEGIN PRIVATE KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
+Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
+P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
+-----END PRIVATE KEY-----
diff --git a/tests/handlers/oidc_test_key.pub.pem b/tests/handlers/oidc_test_key.pub.pem
new file mode 100644
index 0000000000..176d4a4b4b
--- /dev/null
+++ b/tests/handlers/oidc_test_key.pub.pem
@@ -0,0 +1,4 @@
+-----BEGIN PUBLIC KEY-----
+MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
+QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
+-----END PUBLIC KEY-----
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 5c2b4de1a6..a01fdd0839 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -44,8 +44,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.token2 = self.login("user2", "password")
def test_single_public_joined_room(self):
- """Test that we write *all* events for a public room
- """
+ """Test that we write *all* events for a public room"""
room_id = self.helper.create_room_as(
self.user1, tok=self.token1, is_public=True
)
@@ -116,8 +115,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
def test_single_left_room(self):
- """Tests that we don't see events in the room after we leave.
- """
+ """Tests that we don't see events in the room after we leave."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
self.helper.join(room_id, self.user2, tok=self.token2)
@@ -190,8 +188,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
def test_invite(self):
- """Tests that pending invites get handled correctly.
- """
+ """Tests that pending invites get handled correctly."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
self.helper.invite(room_id, self.user1, self.user2, tok=self.token1)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 53763cd0f9..d5d3fdd99a 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastore.return_value = self.mock_store
- self.mock_store.get_received_ts.return_value = defer.succeed(0)
- self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
+ self.mock_store.get_received_ts.return_value = make_awaitable(0)
+ self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested=False),
]
- self.mock_as_api.query_user.return_value = defer.succeed(True)
+ self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = defer.succeed([])
+ self.mock_store.get_user_by_id.return_value = make_awaitable([])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
self.mock_store.get_new_events_for_appservice.side_effect = [
- defer.succeed((0, [event])),
- defer.succeed((0, [])),
+ make_awaitable((0, [event])),
+ make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = defer.succeed(None)
+ self.mock_store.get_user_by_id.return_value = make_awaitable(None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = defer.succeed(True)
+ self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
- defer.succeed((0, [event])),
- defer.succeed((0, [])),
+ make_awaitable((0, [event])),
+ make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
+ self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = defer.succeed(True)
+ self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
- defer.succeed((0, [event])),
- defer.succeed((0, [])),
+ make_awaitable((0, [event])),
+ make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"query_user called when it shouldn't have been.",
)
- @defer.inlineCallbacks
def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar"
room_alias = Mock()
@@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
Mock(room_id=room_id, servers=servers)
)
- result = yield defer.ensureDeferred(
- self.handler.query_room_alias_exists(room_alias)
+ result = self.successResultOf(
+ defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
)
self.mock_as_api.query_alias.assert_called_once_with(
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index e24ce81284..c9f889b511 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -16,28 +16,21 @@ from mock import Mock
import pymacaroons
-from twisted.internet import defer
-
-import synapse
-import synapse.api.errors
-from synapse.api.errors import ResourceLimitError
+from synapse.api.errors import AuthError, ResourceLimitError
from tests import unittest
from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
-class AuthTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.auth_handler = self.hs.get_auth_handler()
- self.macaroon_generator = self.hs.get_macaroon_generator()
+class AuthTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.auth_handler = hs.get_auth_handler()
+ self.macaroon_generator = hs.get_macaroon_generator()
# MAU tests
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
- self.auth_blocking = self.hs.get_auth()._auth_blocking
+ self.auth_blocking = hs.get_auth()._auth_blocking
self.auth_blocking._max_mau_value = 50
self.small_number_of_users = 1
@@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
- self.hs.get_clock().now = 5000
-
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -76,135 +67,126 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
- @defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
- self.hs.get_clock().now = 1000
-
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
)
- self.assertEqual("a_user", user_id)
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("", res.auth_provider_id)
# when we advance the clock, the token should be rejected
- self.hs.get_clock().now = 6000
- with self.assertRaises(synapse.api.errors.AuthError):
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
- )
+ self.reactor.advance(6)
+ self.get_failure(
+ self.auth_handler.validate_short_term_login_token(token),
+ AuthError,
+ )
+
+ def test_short_term_login_token_gives_auth_provider(self):
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", auth_provider_id="my_idp"
+ )
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
- @defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
+ )
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- )
+ res = self.get_success(
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize())
)
- self.assertEqual("a_user", user_id)
+ self.assertEqual("a_user", res.user_id)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
- with self.assertRaises(synapse.api.errors.AuthError):
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- )
- )
+ self.get_failure(
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
+ AuthError,
+ )
- @defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.get_success(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
- @defer.inlineCallbacks
def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
- )
- )
+ self.get_failure(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ ),
+ ResourceLimitError,
+ )
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
- )
- )
+ self.get_failure(
+ self.auth_handler.validate_short_term_login_token(
+ self._get_macaroon().serialize()
+ ),
+ ResourceLimitError,
+ )
- @defer.inlineCallbacks
def test_mau_limits_parity(self):
+ # Ensure we're not at the unix epoch.
+ self.reactor.advance(1)
self.auth_blocking._limit_usage_by_mau = True
- # If not in monthly active cohort
+ # Set the server to be at the edge of too many users.
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
- )
- )
- self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=make_awaitable(self.auth_blocking._max_mau_value)
+ # If not in monthly active cohort
+ self.get_failure(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ ),
+ ResourceLimitError,
)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
- )
- )
+ self.get_failure(
+ self.auth_handler.validate_short_term_login_token(
+ self._get_macaroon().serialize()
+ ),
+ ResourceLimitError,
+ )
+
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=make_awaitable(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.clock.time_msec())
)
- self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=make_awaitable(self.auth_blocking._max_mau_value)
- )
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
- self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=make_awaitable(self.hs.get_clock().time_msec())
- )
- self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=make_awaitable(self.auth_blocking._max_mau_value)
- )
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.get_success(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
- @defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
self.auth_blocking._limit_usage_by_mau = True
@@ -212,7 +194,7 @@ class AuthTestCase(unittest.TestCase):
return_value=make_awaitable(self.small_number_of_users)
)
# Ensure does not raise exception
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
@@ -221,12 +203,14 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users)
)
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.get_success(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
def _get_macaroon(self):
- token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
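+ # The second argument (empty here) is assumed to be the SSO auth provider ID.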
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "user_a", "", 5000
+ )
return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
new file mode 100644
index 0000000000..7975af243c
--- /dev/null
+++ b/tests/handlers/test_cas.py
@@ -0,0 +1,169 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.handlers.cas_handler import CasResponse
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase, override_config
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+SERVER_URL = "https://issuer/"
+
+
+class CasHandlerTestCase(HomeserverTestCase):
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ cas_config = {
+ "enabled": True,
+ "server_url": SERVER_URL,
+ "service_url": BASE_URL,
+ }
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ cas_config.update(config.get("cas_config", {}))
+ config["cas_config"] = cas_config
+
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ self.handler = hs.get_cas_handler()
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler = hs.get_sso_handler()
+ sso_handler._MAP_USERNAME_RETRIES = 3
+
+ return hs
+
+ def test_map_cas_user_to_user(self):
+ """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
+ )
+
+ def test_map_cas_user_to_existing_user(self):
+ """Existing users can log in with CAS account."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # Map a user via SSO.
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
+ )
+
+ # Subsequent calls should map to the same mxid.
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
+ )
+
+ def test_map_cas_user_to_invalid_localpart(self):
+ """CAS automaps invalid characters to base-64 encoding."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
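+ # "föö" contains characters that are not allowed in an MXID localpart, so they get escaped.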
+ cas_response = CasResponse("föö", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
+ )
+
+ @override_config(
+ {
+ "cas_config": {
+ "required_attributes": {"userGroup": "staff", "department": None}
+ }
+ }
+ )
+ def test_required_attributes(self):
+ """The required attributes must be met from the CAS response."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # The response doesn't have the proper userGroup or department.
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # The response doesn't have any department.
+ cas_response = CasResponse("test_user", {"userGroup": "staff"})
+ request.reset_mock()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # Add the proper attributes and it should succeed.
+ cas_response = CasResponse(
+ "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
+ )
+ request.reset_mock()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
+ )
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 5dfeccfeb6..821629bc38 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -260,7 +260,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# Create a new login for the user and dehydrate the device
device_id, access_token = self.get_success(
self.registration.register_device(
- user_id=user_id, device_id=None, initial_display_name="new device",
+ user_id=user_id,
+ device_id=None,
+ initial_display_name="new device",
)
)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 2f1f2a5517..fadec16e13 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -133,7 +133,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
"""A user can create an alias for a room they're in."""
self.get_success(
self.handler.create_association(
- create_requester(self.test_user), self.room_alias, self.room_id,
+ create_requester(self.test_user),
+ self.room_alias,
+ self.room_id,
)
)
@@ -145,7 +147,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.get_failure(
self.handler.create_association(
- create_requester(self.test_user), self.room_alias, other_room_id,
+ create_requester(self.test_user),
+ self.room_alias,
+ other_room_id,
),
synapse.api.errors.SynapseError,
)
@@ -158,7 +162,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.get_success(
self.handler.create_association(
- create_requester(self.admin_user), self.room_alias, other_room_id,
+ create_requester(self.admin_user),
+ self.room_alias,
+ other_room_id,
)
)
@@ -277,8 +283,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
class CanonicalAliasTestCase(unittest.HomeserverTestCase):
- """Test modifications of the canonical alias when delete aliases.
- """
+ """Test modifications of the canonical alias when delete aliases."""
servlets = [
synapse.rest.admin.register_servlets,
@@ -319,7 +324,10 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_canonical_alias(self, content):
"""Configure the canonical alias state on the room."""
self.helper.send_state(
- self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
+ self.room_id,
+ "m.room.canonical_alias",
+ content,
+ tok=self.admin_user_tok,
)
def _get_canonical_alias(self):
@@ -407,7 +415,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
def test_denied(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
b"directory/room/%23test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -417,7 +425,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
def test_allowed(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
b"directory/room/%23unofficial_test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -433,7 +441,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.assertEquals(200, channel.code, channel.result)
@@ -448,7 +456,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = True
# Room list is enabled so we should get some results
- request, channel = self.make_request("GET", b"publicRooms")
+ channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) > 0)
@@ -456,13 +464,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = False
# Room list disabled so we should get no results
- request, channel = self.make_request("GET", b"publicRooms")
+ channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) == 0)
# Room list disabled so we shouldn't be allowed to publish rooms
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 924f29f051..5e86c5e56b 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -18,42 +18,26 @@ import mock
from signedjson import key as key, sign as sign
-from twisted.internet import defer
-
-import synapse.handlers.e2e_keys
-import synapse.storage
-from synapse.api import errors
from synapse.api.constants import RoomEncryptionAlgorithms
+from synapse.api.errors import Codes, SynapseError
-from tests import unittest, utils
+from tests import unittest
-class E2eKeysHandlerTestCase(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.hs = None # type: synapse.server.HomeServer
- self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
- self.store = None # type: synapse.storage.Storage
+class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(federation_client=mock.Mock())
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, federation_client=mock.Mock()
- )
- self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastore()
- @defer.inlineCallbacks
def test_query_local_devices_no_devices(self):
- """If the user has no devices, we expect an empty list.
- """
+ """If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname
- res = yield defer.ensureDeferred(
- self.handler.query_local_devices({local_user: None})
- )
+ res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- @defer.inlineCallbacks
def test_reupload_one_time_keys(self):
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
@@ -64,7 +48,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
@@ -73,14 +57,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
- @defer.inlineCallbacks
def test_change_one_time_keys(self):
"""attempts to change one-time-keys should be rejected"""
@@ -92,75 +75,66 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
- try:
- yield defer.ensureDeferred(
- self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
- )
- )
- self.fail("No error when changing string key")
- except errors.SynapseError:
- pass
-
- try:
- yield defer.ensureDeferred(
- self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
- )
- )
- self.fail("No error when replacing dict key with string")
- except errors.SynapseError:
- pass
-
- try:
- yield defer.ensureDeferred(
- self.handler.upload_keys_for_user(
- local_user,
- device_id,
- {"one_time_keys": {"alg1:k1": {"key": "key"}}},
- )
- )
- self.fail("No error when replacing string key with dict")
- except errors.SynapseError:
- pass
-
- try:
- yield defer.ensureDeferred(
- self.handler.upload_keys_for_user(
- local_user,
- device_id,
- {
- "one_time_keys": {
- "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
- }
- },
- )
- )
- self.fail("No error when replacing dict key")
- except errors.SynapseError:
- pass
+ # Error when changing string key
+ self.get_failure(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ ),
+ SynapseError,
+ )
+
+ # Error when replacing dict key with string
+ self.get_failure(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ ),
+ SynapseError,
+ )
+
+ # Error when replacing string key with dict
+ self.get_failure(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+ ),
+ SynapseError,
+ )
+
+ # Error when replacing dict key
+ self.get_failure(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {
+ "one_time_keys": {
+ "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+ }
+ },
+ ),
+ SynapseError,
+ )
- @defer.inlineCallbacks
def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
- res2 = yield defer.ensureDeferred(
+ res2 = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
@@ -173,7 +147,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
)
- @defer.inlineCallbacks
def test_fallback_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
@@ -181,12 +154,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
@@ -195,14 +168,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
# we should now have an unused alg1 key
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback
# key
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
@@ -213,13 +186,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
# we shouldn't have any unused fallback keys again
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
# claiming an OTK again should return the same fallback key
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
@@ -231,22 +204,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": otk}
)
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
@@ -256,7 +230,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
- @defer.inlineCallbacks
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
@@ -270,9 +243,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield defer.ensureDeferred(
- self.handler.upload_signing_keys_for_user(local_user, keys1)
- )
+ self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
keys2 = {
"master_key": {
@@ -284,16 +255,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield defer.ensureDeferred(
- self.handler.upload_signing_keys_for_user(local_user, keys2)
- )
+ self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
- devices = yield defer.ensureDeferred(
+ devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
- @defer.inlineCallbacks
def test_reupload_signatures(self):
"""re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname
@@ -326,9 +294,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
)
- yield defer.ensureDeferred(
- self.handler.upload_signing_keys_for_user(local_user, keys1)
- )
+ self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
@@ -358,12 +324,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
}
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user, "abc", {"device_keys": device_key_1}
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user, "def", {"device_keys": device_key_2}
)
@@ -372,7 +338,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# sign the first device key and upload it
del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1}}
)
@@ -383,7 +349,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# signature for it
del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
)
@@ -391,7 +357,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
- devices = yield defer.ensureDeferred(
+ devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
@@ -399,7 +365,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
- @defer.inlineCallbacks
def test_self_signing_key_doesnt_show_up_as_device(self):
"""signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname
@@ -413,29 +378,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield defer.ensureDeferred(
- self.handler.upload_signing_keys_for_user(local_user, keys1)
- )
-
- res = None
- try:
- yield defer.ensureDeferred(
- self.hs.get_device_handler().check_device_registered(
- user_id=local_user,
- device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
- initial_device_display_name="new display name",
- )
- )
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 400)
+ self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
- res = yield defer.ensureDeferred(
- self.handler.query_local_devices({local_user: None})
+ e = self.get_failure(
+ self.hs.get_device_handler().check_device_registered(
+ user_id=local_user,
+ device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ initial_device_display_name="new display name",
+ ),
+ SynapseError,
)
+ res = e.value.code
+ self.assertEqual(res, 400)
+
+ res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- @defer.inlineCallbacks
def test_upload_signatures(self):
"""should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will
@@ -458,7 +416,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"device_keys": device_key}
)
@@ -501,7 +459,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key,
}
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
)
@@ -515,14 +473,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
}
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_signing_keys_for_user(
other_user, {"master_key": other_master_key}
)
)
# test various signature failures (see below)
- ret = yield defer.ensureDeferred(
+ ret = self.get_success(
self.handler.upload_signatures_for_device_keys(
local_user,
{
@@ -602,20 +560,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
user_failures = ret["failures"][local_user]
+ self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE)
self.assertEqual(
- user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
+ user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE
)
- self.assertEqual(
- user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
- )
- self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
+ self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
other_user_failures = ret["failures"][other_user]
+ self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
self.assertEqual(
- other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
- )
- self.assertEqual(
- other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
+ other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN
)
# test successful signatures
@@ -623,7 +577,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key)
- ret = yield defer.ensureDeferred(
+ ret = self.get_success(
self.handler.upload_signatures_for_device_keys(
local_user,
{
@@ -636,7 +590,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there
- ret = yield defer.ensureDeferred(
+ ret = self.get_success(
self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
)
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 45f201a399..d7498aa51a 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -19,14 +19,9 @@ import copy
import mock
-from twisted.internet import defer
+from synapse.api.errors import SynapseError
-import synapse.api.errors
-import synapse.handlers.e2e_room_keys
-import synapse.storage
-from synapse.api import errors
-
-from tests import unittest, utils
+from tests import unittest
# sample room_key data for use in the tests
room_keys = {
@@ -45,51 +40,38 @@ room_keys = {
}
-class E2eRoomKeysHandlerTestCase(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.hs = None # type: synapse.server.HomeServer
- self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(replication_layer=mock.Mock())
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, replication_layer=mock.Mock()
- )
- self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
- self.local_user = "@boris:" + self.hs.hostname
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_e2e_room_keys_handler()
+ self.local_user = "@boris:" + hs.hostname
- @defer.inlineCallbacks
def test_get_missing_current_version_info(self):
"""Check that we get a 404 if we ask for info about the current version
when there is no version.
"""
- res = None
- try:
- yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.get_version_info(self.local_user), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_get_missing_version_info(self):
"""Check that we get a 404 if we ask for info about a specific version
when it doesn't exist.
"""
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.get_version_info(self.local_user, "bogus_version")
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.get_version_info(self.local_user, "bogus_version"),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_create_version(self):
- """Check that we can create and then retrieve versions.
- """
- res = yield defer.ensureDeferred(
+ """Check that we can create and then retrieve versions."""
+ res = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -101,7 +83,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "1")
# check we can retrieve it as the current version
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
version_etag = res["etag"]
self.assertIsInstance(version_etag, str)
del res["etag"]
@@ -116,9 +98,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as a specific version
- res = yield defer.ensureDeferred(
- self.handler.get_version_info(self.local_user, "1")
- )
+ res = self.get_success(self.handler.get_version_info(self.local_user, "1"))
self.assertEqual(res["etag"], version_etag)
del res["etag"]
self.assertDictEqual(
@@ -132,7 +112,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# upload a new one...
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -144,7 +124,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "2")
# check we can retrieve it as the current version
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -156,11 +136,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
},
)
- @defer.inlineCallbacks
def test_update_version(self):
- """Check that we can update versions.
- """
- version = yield defer.ensureDeferred(
+ """Check that we can update versions."""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -171,7 +149,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.update_version(
self.local_user,
version,
@@ -185,7 +163,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {})
# check we can retrieve it as the current version
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -197,32 +175,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
},
)
- @defer.inlineCallbacks
def test_update_missing_version(self):
- """Check that we get a 404 on updating nonexistent versions
- """
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.update_version(
- self.local_user,
- "1",
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "1",
- },
- )
- )
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on updating nonexistent versions"""
+ e = self.get_failure(
+ self.handler.update_version(
+ self.local_user,
+ "1",
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "1",
+ },
+ ),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_update_omitted_version(self):
- """Check that the update succeeds if the version is missing from the body
- """
- version = yield defer.ensureDeferred(
+ """Check that the update succeeds if the version is missing from the body"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -233,7 +205,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.update_version(
self.local_user,
version,
@@ -245,7 +217,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as the current version
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
@@ -257,11 +229,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
},
)
- @defer.inlineCallbacks
def test_update_bad_version(self):
- """Check that we get a 400 if the version in the body doesn't match
- """
- version = yield defer.ensureDeferred(
+ """Check that we get a 400 if the version in the body doesn't match"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -272,52 +242,38 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "incorrect",
- },
- )
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "incorrect",
+ },
+ ),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 400)
- @defer.inlineCallbacks
def test_delete_missing_version(self):
- """Check that we get a 404 on deleting nonexistent versions
- """
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.delete_version(self.local_user, "1")
- )
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on deleting nonexistent versions"""
+ e = self.get_failure(
+ self.handler.delete_version(self.local_user, "1"), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_delete_missing_current_version(self):
- """Check that we get a 404 on deleting nonexistent current version
- """
- res = None
- try:
- yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on deleting nonexistent current version"""
+ e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_delete_version(self):
- """Check that we can create and then delete versions.
- """
- res = yield defer.ensureDeferred(
+ """Check that we can create and then delete versions."""
+ res = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -329,36 +285,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "1")
# check we can delete it
- yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
+ self.get_success(self.handler.delete_version(self.local_user, "1"))
# check that it's gone
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.get_version_info(self.local_user, "1")
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.get_version_info(self.local_user, "1"), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_get_missing_backup(self):
- """Check that we get a 404 on querying missing backup
- """
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, "bogus_version")
- )
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on querying missing backup"""
+ e = self.get_failure(
+ self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_get_missing_room_keys(self):
- """Check we get an empty response from an empty backup
- """
- version = yield defer.ensureDeferred(
+ """Check we get an empty response from an empty backup"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -369,33 +315,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest
- @defer.inlineCallbacks
def test_upload_room_keys_no_versions(self):
- """Check that we get a 404 on uploading keys when no versions are defined
- """
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
- )
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on uploading keys when no versions are defined"""
+ e = self.get_failure(
+ self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_upload_room_keys_bogus_version(self):
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
- version = yield defer.ensureDeferred(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -406,22 +345,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.upload_room_keys(
- self.local_user, "bogus_version", room_keys
- )
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_upload_room_keys_wrong_version(self):
- """Check that we get a 403 on uploading keys for an old version
- """
- version = yield defer.ensureDeferred(
+ """Check that we get a 403 on uploading keys for an old version"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -432,7 +365,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- version = yield defer.ensureDeferred(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -443,20 +376,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "2")
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.upload_room_keys(self.local_user, "1", room_keys)
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 403)
- @defer.inlineCallbacks
def test_upload_room_keys_insert(self):
- """Check that we can insert and retrieve keys for a session
- """
- version = yield defer.ensureDeferred(
+ """Check that we can insert and retrieve keys for a session"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -467,17 +395,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org"
)
@@ -485,18 +411,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
self.assertDictEqual(res, room_keys)
- @defer.inlineCallbacks
def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
- version = yield defer.ensureDeferred(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -507,12 +432,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
# get the etag to compare to future versions
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"]
self.assertEqual(res["count"], 1)
@@ -522,37 +447,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new"
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
# the etag should be the same since the session did not change
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should NOT be equal now, since the key changed
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"]
@@ -560,28 +481,24 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other"
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should be the same since the session did not change
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here
- @defer.inlineCallbacks
def test_delete_room_keys(self):
- """Check that we can insert and delete keys for a session
- """
- version = yield defer.ensureDeferred(
+ """Check that we can insert and delete keys for a session"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -593,13 +510,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(version, "1")
# check for bulk-delete
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
- yield defer.ensureDeferred(
- self.handler.delete_room_keys(self.local_user, version)
- )
- res = yield defer.ensureDeferred(
+ self.get_success(self.handler.delete_room_keys(self.local_user, version))
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
@@ -607,15 +522,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org"
)
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
@@ -623,15 +538,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d0452e1490..3af361195b 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
from unittest import TestCase
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -191,6 +191,58 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invite_by_user_ratelimit(self):
+ """Tests that invites from federation to a particular user are
+ actually rate-limited.
+ """
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ def create_invite():
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+ return event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": other_user,
+ "state_key": "@user:test",
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
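+ # burst_count is 3 in the config above, so the first three invites should be allowed.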
+ for i in range(3):
+ event = create_invite()
+ self.get_success(
+ self.handler.on_invite_request(
+ other_server,
+ event,
+ event.room_version,
+ )
+ )
+
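+ # A fourth invite in quick succession should be rejected by the rate limiter.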
+ event = create_invite()
+ self.get_failure(
+ self.handler.on_invite_request(
+ other_server,
+ event,
+ event.room_version,
+ ),
+ exc=LimitExceededError,
+ )
+
def _build_and_send_join_event(self, other_server, other_user, room_id):
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
@@ -198,7 +250,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
- with LoggingContext(request="send_join"):
+ with LoggingContext("send_join"):
d = run_in_background(
self.handler.on_send_join_request, other_server, join_event
)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index af42775815..a0d1ebdbe3 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -44,7 +44,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.info = self.get_success(
- self.hs.get_datastore().get_user_by_access_token(self.access_token,)
+ self.hs.get_datastore().get_user_by_access_token(
+ self.access_token,
+ )
)
self.token_id = self.info.token_id
@@ -169,8 +171,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
def test_allow_server_acl(self):
- """Test that sending an ACL that blocks everyone but ourselves works.
- """
+ """Test that sending an ACL that blocks everyone but ourselves works."""
self.helper.send_state(
self.room_id,
@@ -181,8 +182,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
)
def test_deny_server_acl_block_outselves(self):
- """Test that sending an ACL that blocks ourselves does not work.
- """
+ """Test that sending an ACL that blocks ourselves does not work."""
self.helper.send_state(
self.room_id,
EventTypes.ServerACL,
@@ -192,8 +192,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
)
def test_deny_redact_server_acl(self):
- """Test that attempting to redact an ACL is blocked.
- """
+ """Test that attempting to redact an ACL is blocked."""
body = self.helper.send_state(
self.room_id,
@@ -206,7 +205,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
# Redaction of event should fail.
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", path, content={}, access_token=self.access_token
)
self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..c7796fb837 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,32 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+import os
from urllib.parse import parse_qs, urlparse
-from mock import Mock, patch
+from mock import ANY, Mock, patch
-import attr
import pymacaroons
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
+from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
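+# authlib is an optional dependency; if it is not installed, the OIDC tests below are skipped.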
+try:
+ import authlib # noqa: F401
-@attr.s
-class FakeResponse:
- code = attr.ib()
- body = attr.ib()
- phrase = attr.ib()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
+ HAS_OIDC = True
+except ImportError:
+ HAS_OIDC = False
# These are a few constants that are used as config parameters in the tests.
@@ -46,7 +41,7 @@ ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
CLIENT_SECRET = "test-client-secret"
BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]
AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -56,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
JWKS_URI = ISSUER + ".well-known/jwks.json"
# config for common cases
-COMMON_CONFIG = {
+DEFAULT_CONFIG = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+}
+
+# extends the default config with explicit OAuth2 endpoints instead of using discovery
+EXPLICIT_ENDPOINT_CONFIG = {
+ **DEFAULT_CONFIG,
"discover": False,
"authorization_endpoint": AUTHORIZATION_ENDPOINT,
"token_endpoint": TOKEN_ENDPOINT,
@@ -64,17 +70,14 @@ COMMON_CONFIG = {
}
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
@staticmethod
def parse_config(config):
return
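+ # This does not subclass OidcMappingProvider, so it needs its own constructor taking the parsed config.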
+ def __init__(self, config):
+ pass
+
def get_remote_user_id(self, userinfo):
return userinfo["sub"]
@@ -97,16 +100,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-def simple_async_mock(return_value=None, raises=None):
- # AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
- if raises:
- raise raises
- return return_value
-
- return Mock(side_effect=cb)
-
-
async def get_json(url):
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
@@ -126,28 +119,42 @@ async def get_json(url):
return {"keys": []}
+def _key_file_path() -> str:
+ """path to a file containing the private half of a test key"""
+
+ # this key was generated with:
+ # openssl ecparam -name prime256v1 -genkey -noout |
+ # openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
+ #
+ # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
+ # that's what Apple use, and we want to be sure that we work with Apple's keys.
+ #
+ # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
+ # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
+ # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
+ # really need to care about any of that.)
+ return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
+
+
+def _public_key_file_path() -> str:
+ """path to a file containing the public half of a test key"""
+ # this was generated with:
+ # openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
+ #
+ # See above about where oidc_test_key.p8 came from
+ return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
+
+
class OidcHandlerTestCase(HomeserverTestCase):
+ if not HAS_OIDC:
+ skip = "requires OIDC"
+
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = {
- "enabled": True,
- "client_id": CLIENT_ID,
- "client_secret": CLIENT_SECRET,
- "issuer": ISSUER,
- "scopes": SCOPES,
- "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
- }
-
- # Update this config with what's in the default config so that
- # override_config works as expected.
- oidc_config.update(config.get("oidc_config", {}))
- config["oidc_config"] = oidc_config
-
return config
def make_homeserver(self, reactor, clock):
-
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
@@ -155,6 +162,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
+ self.provider = self.handler._providers["oidc"]
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
@@ -166,27 +174,38 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def metadata_edit(self, values):
- return patch.dict(self.handler._provider_metadata, values)
+ """Modify the result that will be returned by the well-known query"""
+
+ async def patched_get_json(uri):
+ res = await get_json(uri)
+ if uri == WELL_KNOWN:
+ res.update(values)
+ return res
+
+ return patch.object(self.http_client, "get_json", patched_get_json)
def assertRenderedError(self, error, error_description=None):
+ self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
self.render_error.reset_mock()
+ return args
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
- self.assertEqual(self.handler._callback_url, CALLBACK_URL)
- self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
- self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+ self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+ self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
- @override_config({"oidc_config": {"discover": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
- metadata = self.get_success(self.handler.load_metadata())
+ metadata = self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -198,92 +217,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
- self.get_success(self.handler.load_metadata())
+ self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
- @override_config({"oidc_config": COMMON_CONFIG})
+ @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
- self.get_success(self.handler.load_metadata())
+ self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
- @override_config({"oidc_config": COMMON_CONFIG})
+ @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
- jwks = self.get_success(self.handler.load_jwks())
+ jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
- self.get_success(self.handler.load_jwks())
+ self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
- self.get_success(self.handler.load_jwks(force=True))
+ self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
- with self.metadata_edit({"jwks_uri": None}):
- self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+ original = self.provider.load_metadata
+
+ async def patched_load_metadata():
+ m = (await original()).copy()
+ m.update({"jwks_uri": None})
+ return m
+
+ with patch.object(self.provider, "load_metadata", patched_load_metadata):
+ self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
- self.handler._scopes = [] # not asking the openid scope
+ self.provider._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
- jwks = self.get_success(self.handler.load_jwks(force=True))
+ jwks = self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
- @override_config({"oidc_config": COMMON_CONFIG})
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
- h = self.handler
+ h = self.provider
+
+ def force_load_metadata():
+ async def force_load():
+ return await h.load_metadata(force=True)
+
+ return get_awaitable_result(force_load())
# Default test config does not throw
- h._validate_metadata()
+ force_load_metadata()
with self.metadata_edit({"issuer": None}):
- self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"issuer": "http://insecure/"}):
- self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
- self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"authorization_endpoint": None}):
self.assertRaisesRegex(
- ValueError, "authorization_endpoint", h._validate_metadata
+ ValueError, "authorization_endpoint", force_load_metadata
)
with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
self.assertRaisesRegex(
- ValueError, "authorization_endpoint", h._validate_metadata
+ ValueError, "authorization_endpoint", force_load_metadata
)
with self.metadata_edit({"token_endpoint": None}):
- self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
- self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
with self.metadata_edit({"jwks_uri": None}):
- self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
- self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
with self.metadata_edit({"response_types_supported": ["id_token"]}):
self.assertRaisesRegex(
- ValueError, "response_types_supported", h._validate_metadata
+ ValueError, "response_types_supported", force_load_metadata
)
with self.metadata_edit(
{"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
):
# should not throw, as client_secret_basic is the default auth method
- h._validate_metadata()
+ force_load_metadata()
with self.metadata_edit(
{"token_endpoint_auth_methods_supported": ["client_secret_post"]}
@@ -291,7 +323,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRaisesRegex(
ValueError,
"token_endpoint_auth_methods_supported",
- h._validate_metadata,
+ force_load_metadata,
)
# Tests for configs that require the userinfo endpoint
@@ -300,30 +332,33 @@ class OidcHandlerTestCase(HomeserverTestCase):
h._user_profile_method = "userinfo_endpoint"
self.assertTrue(h._uses_userinfo)
- # Revert the profile method and do not request the "openid" scope.
+ # Revert the profile method and do not request the "openid" scope: this should
+ # mean that we check for a userinfo endpoint
h._user_profile_method = "auto"
h._scopes = []
self.assertTrue(h._uses_userinfo)
- self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+ with self.metadata_edit({"userinfo_endpoint": None}):
+ self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata)
- with self.metadata_edit(
- {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
- ):
- # Shouldn't raise with a valid userinfo, even without
- h._validate_metadata()
+ with self.metadata_edit({"jwks_uri": None}):
+ # Shouldn't raise with a valid userinfo, even without jwks
+ force_load_metadata()
- @override_config({"oidc_config": {"skip_verification": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
- self.handler._validate_metadata()
+ get_awaitable_result(self.provider.load_metadata())
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
- req = Mock(spec=["addCookie"])
+ req = Mock(spec=["cookies"])
+ req.cookies = []
+
url = self.get_success(
- self.handler.handle_redirect_request(req, b"http://client/redirect")
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -340,28 +375,27 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(len(params["state"]), 1)
self.assertEqual(len(params["nonce"]), 1)
- # Check what is in the cookie
- # note: python3.5 mock does not have the .called_once() method
- calls = req.addCookie.call_args_list
- self.assertEqual(len(calls), 1) # called once
- # For some reason, call.args does not work with python3.5
- args = calls[0][0]
- kwargs = calls[0][1]
- self.assertEqual(args[0], COOKIE_NAME)
- self.assertEqual(kwargs["path"], COOKIE_PATH)
- cookie = args[1]
+ # Check what is in the cookies
+ self.assertEqual(len(req.cookies), 2) # two cookies
+ cookie_header = req.cookies[0]
+
+ # The cookie name and path don't really matter, just that they have to be consistent
+ # between the callback & redirect handlers.
+ parts = [p.strip() for p in cookie_header.split(b";")]
+ self.assertIn(b"Path=/_synapse/client/oidc", parts)
+ name, cookie = parts[0].split(b"=")
+ self.assertEqual(name, b"oidc_session")
macaroon = pymacaroons.Macaroon.deserialize(cookie)
- state = self.handler._get_value_from_macaroon(macaroon, "state")
- nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
- redirect = self.handler._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
+ state = get_value_from_macaroon(macaroon, "state")
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
self.assertEqual(params["state"], [state])
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
@@ -373,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self):
"""Code callback works and display errors if something went wrong.
@@ -384,31 +419,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
- when the userinfo fetching fails
- when the code exchange fails
"""
+
+ # ensure that we are correctly testing the fallback when "get_extra_attributes"
+ # is not implemented.
+ mapping_provider = self.provider._user_mapping_provider
+ with self.assertRaises(AttributeError):
+ _ = mapping_provider.get_extra_attributes
+
token = {
"type": "bearer",
"id_token": "id_token",
"access_token": "access_token",
}
+ username = "bar"
userinfo = {
"sub": "foo",
- "preferred_username": "bar",
- }
- user_id = "@foo:domain.org"
- self.handler._exchange_code = simple_async_mock(return_value=token)
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ "username": username,
+ }
+ expected_user_id = "@%s:%s" % (username, self.hs.hostname)
+ self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
code = "code"
state = "state"
@@ -416,82 +449,70 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
- state=state,
- nonce=nonce,
- client_redirect_url=client_redirect_url,
- ui_auth_session_id=None,
+ session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+ request = _build_callback_request(
+ code, state, session, user_agent=user_agent, ip_address=ip_address
)
- request.args = {}
- request.args[b"code"] = [code.encode("utf-8")]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = ip_address
- request.get_user_agent.return_value = user_agent
-
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
- )
- self.handler._exchange_code.assert_called_once_with(code)
- self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
)
- self.handler._fetch_userinfo.assert_not_called()
+ self.provider._exchange_code.assert_called_once_with(code)
+ self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.provider._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
- self.handler._map_userinfo_to_user = simple_async_mock(
- raises=MappingException()
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("mapping_error")
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ with patch.object(
+ self.provider,
+ "_remote_id_from_userinfo",
+ new=Mock(side_effect=MappingException()),
+ ):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- self.handler._auth_handler.complete_sso_login.reset_mock()
- self.handler._exchange_code.reset_mock()
- self.handler._parse_id_token.reset_mock()
- self.handler._map_userinfo_to_user.reset_mock()
- self.handler._fetch_userinfo.reset_mock()
+ auth_handler.complete_sso_login.reset_mock()
+ self.provider._exchange_code.reset_mock()
+ self.provider._parse_id_token.reset_mock()
+ self.provider._fetch_userinfo.reset_mock()
# With userinfo fetching
- self.handler._scopes = [] # do not ask the "openid" scope
+ self.provider._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
- )
- self.handler._exchange_code.assert_called_once_with(code)
- self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
)
- self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.provider._exchange_code.assert_called_once_with(code)
+ self.provider._parse_id_token.assert_not_called()
+ self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
- self.handler._exchange_code = simple_async_mock(
+ from synapse.handlers.oidc_handler import OidcError
+
+ self.provider._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(spec=["args", "getCookie", "cookies"])
# Missing cookie
request.args = {}
@@ -513,11 +534,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_session")
# Mismatching session
- session = self.handler._generate_oidc_session_token(
+ session = self._generate_oidc_session_token(
state="state",
nonce="nonce",
client_redirect_url="http://client/redirect",
- ui_auth_session_id=None,
)
request.args = {}
request.args[b"state"] = [b"mismatching state"]
@@ -532,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
- @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @override_config(
+ {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
+ )
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
@@ -541,7 +563,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
- ret = self.get_success(self.handler._exchange_code(code))
+ ret = self.get_success(self.provider._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -563,17 +585,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ from synapse.handlers.oidc_handler import OidcError
+
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "foo")
self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
self.http_client.request = simple_async_mock(
return_value=FakeResponse(
- code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+ code=500,
+ phrase=b"Internal Server Error",
+ body=b"Not JSON",
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
@@ -585,31 +611,133 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
- return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+ return_value=FakeResponse(
+ code=400,
+ phrase=b"Bad request",
+ body=b"{}",
+ )
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(
- code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+ code=200,
+ phrase=b"OK",
+ body=b'{"error": "some_error"}',
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@override_config(
{
"oidc_config": {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "issuer": ISSUER,
+ "client_auth_method": "client_secret_post",
+ "client_secret_jwt_key": {
+ "key_file": _key_file_path(),
+ "jwt_header": {"alg": "ES256", "kid": "ABC789"},
+ "jwt_payload": {"iss": "DEFGHI"},
+ },
+ }
+ }
+ )
+ def test_exchange_code_jwt_key(self):
+ """Test that code exchange works with a JWK client secret."""
+ from authlib.jose import jwt
+
+ token = {"type": "bearer"}
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+ )
+ )
+ code = "code"
+
+ # advance the clock a bit before we start, so we aren't working with zero
+ # timestamps.
+ self.reactor.advance(1000)
+ start_time = self.reactor.seconds()
+ ret = self.get_success(self.provider._exchange_code(code))
+
+ self.assertEqual(ret, token)
+
+ # the request should have hit the token endpoint
+ kwargs = self.http_client.request.call_args[1]
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ # the client secret provided to the token endpoint should be a JWT which can be
+ # checked with the public key
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ secret = args["client_secret"][0]
+ with open(_public_key_file_path()) as f:
+ key = f.read()
+ claims = jwt.decode(secret, key)
+ self.assertEqual(claims.header["kid"], "ABC789")
+ self.assertEqual(claims["aud"], ISSUER)
+ self.assertEqual(claims["iss"], "DEFGHI")
+ self.assertEqual(claims["sub"], CLIENT_ID)
+ self.assertEqual(claims["iat"], start_time)
+ self.assertGreater(claims["exp"], start_time)
+
+ # check the rest of the POSTed data
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ @override_config(
+ {
+ "oidc_config": {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "issuer": ISSUER,
+ "client_auth_method": "none",
+ }
+ }
+ )
+ def test_exchange_code_no_auth(self):
+ """Test that code exchange works with no client secret."""
+ token = {"type": "bearer"}
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+ )
+ )
+ code = "code"
+ ret = self.get_success(self.provider._exchange_code(code))
+
+ self.assertEqual(ret, token)
+
+ # the request should have hit the token endpoint
+ kwargs = self.http_client.request.call_args[1]
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ # check the POSTed data
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderExtra"
- }
+ },
}
}
)
@@ -624,72 +752,60 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
userinfo = {
"sub": "foo",
+ "username": "foo",
"phone": "1234567",
}
- user_id = "@foo:domain.org"
- self.handler._exchange_code = simple_async_mock(return_value=token)
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
state = "state"
client_redirect_url = "http://client/redirect"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ session = self._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
- ui_auth_session_id=None,
)
-
- request.args = {}
- request.args[b"code"] = [b"code"]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = "10.0.0.1"
- request.get_user_agent.return_value = "Browser"
+ request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {"phone": "1234567"},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@foo:test",
+ "oidc",
+ request,
+ client_redirect_url,
+ {"phone": "1234567"},
+ new_user=True,
)
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
userinfo = {
"sub": "test_user",
"username": "test_user",
}
- # The token doesn't matter with the default user mapping provider.
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "oidc", ANY, ANY, None, new_user=True
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
)
- self.assertEqual(mxid, "@test_user_2:test")
+ auth_handler.complete_sso_login.reset_mock()
# Test if the mxid is already taken
store = self.hs.get_datastore()
@@ -698,17 +814,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
- self.assertEqual(
- str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error",
+ "Mapping provider does not support de-duplicating Matrix IDs",
)
- @override_config({"oidc_config": {"allow_existing_users": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
@@ -717,26 +830,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -747,13 +860,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -770,14 +881,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ args = self.assertRenderedError("mapping_error")
self.assertTrue(
- str(e.value).startswith(
+ args[2].startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
)
)
@@ -788,40 +896,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@TEST_USER_2:test")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
- userinfo = {
- "sub": "test2",
- "username": "föö",
- }
- token = {}
-
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(
+ _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
)
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
{
"oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderFailures"
- }
+ },
}
}
)
def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -830,14 +932,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -853,12 +954,266 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+
+ @override_config({"oidc_config": DEFAULT_CONFIG})
+ def test_empty_localpart(self):
+ """Attempts to map onto an empty localpart should be rejected."""
+ userinfo = {
+ "sub": "tester",
+ "username": "",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "user_mapping_provider": {
+ "config": {"localpart_template": "{{ user.username }}"}
+ },
+ }
+ }
+ )
+ def test_null_localpart(self):
+ """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+ userinfo = {
+ "sub": "tester",
+ "username": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements(self):
+ """The required attributes must be met from the OIDC userinfo response."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # userinfo lacking "test": "foobar" attribute should fail.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": "foobar" attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": "foobar",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
)
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements_contains(self):
+ """Test that auth succeeds if userinfo attribute CONTAINS required value"""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["foobar", "foo", "bar"],
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements_mismatch(self):
+ """
+ Test that auth fails if attributes exist but don't match,
+ or are non-string values.
+ """
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ # userinfo with "test": "not_foobar" attribute should fail
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": "not_foobar",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": ["foo", "bar"] attribute should fail
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["foo", "bar"],
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": False attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": False,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": None attribute should fail
+ # a value of None breaks the OIDC spec, but it's important to not crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": 1 attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": 1,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": 3.14 attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": 3.14,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ def _generate_oidc_session_token(
+ self,
+ state: str,
+ nonce: str,
+ client_redirect_url: str,
+ ui_auth_session_id: str = "",
+ ) -> str:
+ from synapse.handlers.oidc_handler import OidcSessionData
+
+ return self.handler._token_generator.generate_oidc_session_token(
+ state=state,
+ session_data=OidcSessionData(
+ idp_id="oidc",
+ nonce=nonce,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=ui_auth_session_id,
+ ),
+ )
+
+
+async def _make_callback_with_userinfo(
+ hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+ """Mock up an OIDC callback with the given userinfo dict
+
+ We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+ and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
+
+ Args:
+ hs: the HomeServer impl to send the callback to.
+ userinfo: the OIDC userinfo dict
+ client_redirect_url: the URL to redirect to on success.
+ """
+ from synapse.handlers.oidc_handler import OidcSessionData
+
+ handler = hs.get_oidc_handler()
+ provider = handler._providers["oidc"]
+ provider._exchange_code = simple_async_mock(return_value={})
+ provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+
+ state = "state"
+ session = handler._token_generator.generate_oidc_session_token(
+ state=state,
+ session_data=OidcSessionData(
+ idp_id="oidc",
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id="",
+ ),
+ )
+ request = _build_callback_request("code", state, session)
+
+ await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+ code: str,
+ state: str,
+ session: str,
+ user_agent: str = "Browser",
+ ip_address: str = "10.0.0.1",
+):
+ """Builds a fake SynapseRequest to mock the browser callback
+
+ Returns a Mock object which looks like the SynapseRequest we get from a browser
+ after SSO (before we return to the client)
+
+ Args:
+ code: the authorization code which would have been returned by the OIDC
+ provider
+ state: the "state" param which would have been passed around in the
+ query param. Should be the same as was embedded in the session in
+ _generate_oidc_session_token.
+ session: the "session" which would have been passed around in the cookie.
+ user_agent: the user-agent to present
+ ip_address: the IP address to pretend the request came from
+ """
+ request = Mock(
+ spec=[
+ "args",
+ "getCookie",
+ "cookies",
+ "requestHeaders",
+ "getClientIP",
+ "getHeader",
+ ]
+ )
+
+ request.cookies = []
+ request.getCookie.return_value = session
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+ request.getClientIP.return_value = ip_address
+ return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..a98a65ae67 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -231,8 +231,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_no_local_user_fallback_login(self):
- """localdb_enabled can block login with the local password
- """
+ """localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
# check_password must return an awaitable
@@ -251,8 +250,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_no_local_user_fallback_ui_auth(self):
- """localdb_enabled can block ui auth with the local password
- """
+ """localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
# allow login via the auth provider
@@ -432,6 +430,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **providers_config(CustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled(self):
+ """Check the localdb_enabled == enabled == False
+
+ Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+ that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+ cause an exception.
+ """
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
@@ -528,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
def _get_login_flows(self) -> JsonDict:
- _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["flows"]
@@ -537,7 +558,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def _send_login(self, type, user, **params) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
- _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+ channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
def _start_delete_device_session(self, access_token, device_id) -> str:
@@ -571,10 +592,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
def _delete_device(
- self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+ self,
+ access_token: str,
+ device: str,
+ body: Union[JsonDict, bytes] = b"",
) -> FakeChannel:
"""Delete an individual device."""
- _, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=access_token
)
return channel
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 0794b32c9c..77330f59a9 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -310,6 +310,26 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
+ def test_busy_no_idle(self):
+ """
+ Tests that a user setting their presence to busy but idling doesn't turn their
+ presence state into unavailable.
+ """
+ user_id = "@foo:bar"
+ now = 5000000
+
+ state = UserPresenceState.default(user_id)
+ state = state.copy_and_replace(
+ state=PresenceState.BUSY,
+ last_active_ts=now - IDLE_TIMER - 1,
+ last_user_sync_ts=now,
+ )
+
+ new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+
+ self.assertIsNotNone(new_state)
+ self.assertEquals(new_state.state, PresenceState.BUSY)
+
def test_sync_timeout(self):
user_id = "@foo:bar"
now = 5000000
@@ -521,7 +541,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
- destinations=["server2"], states=[expected_state]
+ destinations=["server2"], states={expected_state}
)
#
@@ -533,7 +553,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.federation_sender.send_presence.assert_not_called()
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
- destinations=["server3"], states=[expected_state]
+ destinations=["server3"], states={expected_state}
)
def test_remote_gets_presence_when_local_user_joins(self):
@@ -584,13 +604,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.presence_handler.current_state_for_user("@test2:server")
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
- self.federation_sender.send_presence_to_destinations.assert_called_once_with(
- destinations={"server2", "server3"}, states=[expected_state]
+ self.assertEqual(
+ self.federation_sender.send_presence_to_destinations.call_count, 2
+ )
+ self.federation_sender.send_presence_to_destinations.assert_any_call(
+ destinations=["server3"], states={expected_state}
+ )
+ self.federation_sender.send_presence_to_destinations.assert_any_call(
+ destinations=["server2"], states={expected_state}
)
def _add_new_user(self, room_id, user_id):
- """Add new user to the room by creating an event and poking the federation API.
- """
+ """Add new user to the room by creating an event and poking the federation API."""
hostname = get_domain_from_id(user_id)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 300a625972..cbbe7280c7 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -13,25 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from mock import Mock
-from twisted.internet import defer
-
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID
from tests import unittest
from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
-class ProfileTestCase(unittest.TestCase):
+class ProfileTestCase(unittest.HomeserverTestCase):
""" Tests profile management. """
- @defer.inlineCallbacks
- def setUp(self):
+ def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
self.mock_registry = Mock()
@@ -42,41 +37,35 @@ class ProfileTestCase(unittest.TestCase):
self.mock_registry.register_query_handler = register_query_handler
- hs = yield setup_test_homeserver(
- self.addCleanup,
- federation_http_client=None,
- resource_for_federation=Mock(),
+ hs = self.setup_test_homeserver(
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
)
+ return hs
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test")
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
+ self.get_success(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler()
- self.hs = hs
- @defer.inlineCallbacks
def test_get_my_name(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
- displayname = yield defer.ensureDeferred(
- self.handler.get_displayname(self.frank)
- )
+ displayname = self.get_success(self.handler.get_displayname(self.frank))
self.assertEquals("Frank", displayname)
- @defer.inlineCallbacks
def test_set_my_name(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
@@ -84,7 +73,7 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
)
),
@@ -92,7 +81,7 @@ class ProfileTestCase(unittest.TestCase):
)
# Set displayname again
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
@@ -100,25 +89,35 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)
- @defer.inlineCallbacks
+ # Set displayname to an empty string
+ self.get_success(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), ""
+ )
+ )
+
+ self.assertIsNone(
+ (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
+ )
+
def test_set_my_name_if_disabled(self):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
self.assertEquals(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
)
),
@@ -126,33 +125,27 @@ class ProfileTestCase(unittest.TestCase):
)
# Setting displayname a second time is forbidden
- d = defer.ensureDeferred(
+ self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
- )
+ ),
+ SynapseError,
)
- yield self.assertFailure(d, SynapseError)
-
- @defer.inlineCallbacks
def test_set_my_name_noauth(self):
- d = defer.ensureDeferred(
+ self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
- )
+ ),
+ AuthError,
)
- yield self.assertFailure(d, AuthError)
-
- @defer.inlineCallbacks
def test_get_other_name(self):
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
- displayname = yield defer.ensureDeferred(
- self.handler.get_displayname(self.alice)
- )
+ displayname = self.get_success(self.handler.get_displayname(self.alice))
self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with(
@@ -162,35 +155,34 @@ class ProfileTestCase(unittest.TestCase):
ignore_backoff=True,
)
- @defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield defer.ensureDeferred(self.store.create_profile("caroline"))
- yield defer.ensureDeferred(
- self.store.set_profile_displayname("caroline", "Caroline", 1)
- )
+ self.get_success(self.store.create_profile("caroline"))
+ self.get_success(self.store.set_profile_displayname("caroline", "Caroline", 1))
- response = yield defer.ensureDeferred(
+ response = self.get_success(
self.query_handlers["profile"](
- {"user_id": "@caroline:test", "field": "displayname"}
+ {
+ "user_id": "@caroline:test",
+ "field": "displayname",
+ "origin": "servername.tld",
+ }
)
)
self.assertEquals({"displayname": "Caroline"}, response)
- @defer.inlineCallbacks
def test_get_my_avatar(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png", 1
)
)
- avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
+ avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
self.assertEquals("http://my.server/me.png", avatar_url)
- @defer.inlineCallbacks
def test_set_my_avatar(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
@@ -199,16 +191,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (
- yield defer.ensureDeferred(
- self.store.get_profile_avatar_url(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
"http://my.server/pic.gif",
)
# Set avatar again
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
@@ -217,41 +205,44 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (
- yield defer.ensureDeferred(
- self.store.get_profile_avatar_url(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
"http://my.server/me.png",
)
- @defer.inlineCallbacks
+ # Set avatar to an empty string
+ self.get_success(
+ self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "",
+ )
+ )
+
+ self.assertIsNone(
+ (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+ )
+
def test_set_my_avatar_if_disabled(self):
self.hs.config.enable_set_avatar_url = False
# Setting the avatar for the first time is allowed
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png", 1
)
)
self.assertEquals(
- (
- yield defer.ensureDeferred(
- self.store.get_profile_avatar_url(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
"http://my.server/me.png",
)
# Setting the avatar a second time is forbidden
- d = defer.ensureDeferred(
+ self.get_failure(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
- )
+ ),
+ SynapseError,
)
-
- yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index d979aadcd6..00a0bc5274 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,6 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
-from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha.register import (
_map_email_to_displayname,
register_servlets,
@@ -527,6 +526,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ def test_spam_checker_receives_sso_type(self):
+ """Test rejecting registration based on SSO type"""
+
+ class BanBadIdPUser:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS and whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+ # Configure a spam checker that denies a certain user on a specific IdP
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanBadIdPUser()]
+
+ f = self.get_failure(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
+ SynapseError,
+ )
+ exception = f.value
+
+ # We return 429 from the spam checker for denied registrations
+ self.assertIsInstance(exception, SynapseError)
+ self.assertEqual(exception.code, 429)
+
+ # Check the same username can register using SAML
+ self.get_success(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
+ )
+
def test_email_to_displayname_mapping(self):
"""Test that custom emails are mapped to new user displaynames correctly"""
self._check_mapping(
@@ -580,7 +610,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Mock Synapse's http json post method to check for the internal bind call
post_json_get_json = Mock(return_value=make_awaitable(None))
- self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
+ self.hs.get_identity_handler().http_client.post_json_get_json = (
+ post_json_get_json
+ )
# Retrieve a UIA session ID
channel = self.uia_register(
@@ -617,11 +649,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def uia_register(self, expected_response: int, body: dict) -> FakeChannel:
"""Make a register request."""
- request, channel = self.make_request(
- "POST", "register", body
- ) # type: SynapseRequest, FakeChannel
+ channel = self.make_request("POST", "register", body)
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
async def get_or_create_user(
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 45dc17aba5..30efd43b40 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,13 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
+from mock import Mock
+
import attr
from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
+from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
+# Check if we have the dependencies to run the tests.
+try:
+ import saml2.config
+ from saml2.sigver import SigverError
+
+ has_saml2 = True
+
+ # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
+ config = saml2.config.SPConfig()
+ try:
+ config.load({"metadata": {}})
+ has_xmlsec1 = True
+ except SigverError:
+ has_xmlsec1 = False
+except ImportError:
+ has_saml2 = False
+ has_xmlsec1 = False
+
# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
@@ -26,6 +48,8 @@ BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
+ assertions = attr.ib(type=list, factory=list)
+ in_response_to = attr.ib(type=Optional[str], default=None)
class TestMappingProvider:
@@ -86,17 +110,29 @@ class SamlHandlerTestCase(HomeserverTestCase):
return hs
+ if not has_saml2:
+ skip = "Requires pysaml2"
+ elif not has_xmlsec1:
+ skip = "Requires xmlsec1"
+
def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- # The redirect_url doesn't matter with the default user mapping provider.
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
- self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
@@ -106,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
store.register_user(user_id="@test_user:test", password_hash=None)
)
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "saml", request, "", None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "saml", request, "", None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # mock out the error renderer too
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
- redirect_url = ""
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ )
+ sso_handler.render_error.assert_called_once_with(
+ request, "mapping_error", "localpart is invalid: föö"
)
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+ # stub out the auth handler and error renderer
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
+ # register a user to occupy the first-choice MXID
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
+
+ # send the fake SAML response
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", "saml", request, "", None, new_user=True
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular SAML username.
self.get_success(
@@ -165,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ sso_handler.render_error.assert_called_once_with(
+ request,
+ "mapping_error",
+ "Unable to generate a Matrix ID from the SSO response",
)
+ auth_handler.complete_sso_login.assert_not_called()
@override_config(
{
@@ -185,12 +249,71 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
)
def test_map_saml_response_redirect(self):
+ """Test a mapping provider that raises a RedirectException"""
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
+ request = _mock_request()
e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
+ self.handler._handle_authn_response(request, saml_response, ""),
RedirectException,
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+ @override_config(
+ {
+ "saml2_config": {
+ "attribute_requirements": [
+ {"attribute": "userGroup", "value": "staff"},
+ {"attribute": "department", "value": "sales"},
+ ],
+ },
+ }
+ )
+ def test_attribute_requirements(self):
+ """The required attributes must be met from the SAML response."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # The response doesn't have the proper userGroup or department.
+ saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # The response doesn't have the proper department.
+ saml_response = FakeAuthnResponse(
+ {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
+ )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # Add the proper attributes and it should succeed.
+ saml_response = FakeAuthnResponse(
+ {
+ "uid": "test_user",
+ "username": "test_user",
+ "userGroup": ["staff", "admin"],
+ "department": ["sales"],
+ }
+ )
+ request.reset_mock()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
+ )
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 36086ca836..24e7138196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
import json
+from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.web.resource import Resource
from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
-from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- servlets = [register_federation_servlets]
-
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return hs
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
+
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -138,14 +143,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
- ([], 0)
+ self.datastore.get_new_device_msgs_for_remote = (
+ lambda *args, **kargs: make_awaitable(([], 0))
)
- self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
- None
+ self.datastore.delete_device_msgs_for_remote = (
+ lambda *args, **kargs: make_awaitable(None)
)
- self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
- None
+ self.datastore.set_received_txn_response = (
+ lambda *args, **kwargs: make_awaitable(None)
)
def test_started_typing_local(self):
@@ -215,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- (request, channel) = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index dcd02932b2..dbe68bb058 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -57,6 +57,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
)
+ regular_user_id = "@regular:test"
+ self.get_success(
+ self.store.register_user(user_id=regular_user_id, password_hash=None)
+ )
self.get_success(
self.handler.handle_local_profile_change(support_user_id, None)
@@ -66,13 +70,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
display_name = "display_name"
profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
- regular_user_id = "@regular:test"
self.get_success(
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
self.assertTrue(profile["display_name"] == display_name)
+ def test_handle_local_profile_change_with_deactivated_user(self):
+ # create user
+ r_user_id = "@regular:test"
+ self.get_success(
+ self.store.register_user(user_id=r_user_id, password_hash=None)
+ )
+
+ # update profile
+ display_name = "Regular User"
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+ self.get_success(
+ self.handler.handle_local_profile_change(r_user_id, profile_info)
+ )
+
+ # profile is in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile["display_name"] == display_name)
+
+ # deactivate user
+ self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
+ self.get_success(self.handler.handle_user_deactivated(r_user_id))
+
+ # profile is not in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile is None)
+
+ # update profile after deactivation
+ self.get_success(
+ self.handler.handle_local_profile_change(r_user_id, profile_info)
+ )
+
+ # profile is still not in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile is None)
+
def test_handle_user_deactivated_support_user(self):
s_user_id = "@support:test"
self.get_success(
@@ -165,7 +203,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Check that the room has an encryption state event
event_content = self.helper.get_state(
- room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
)
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
@@ -174,7 +214,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Check that the room has an encryption state event
event_content = self.helper.get_state(
- room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
)
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
@@ -192,7 +234,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Check that the room has an encryption state event
event_content = self.helper.get_state(
- room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
)
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
@@ -273,7 +317,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
spam_checker = self.hs.get_spam_checker()
class AllowAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -286,7 +330,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that filters all users.
class BlockAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
@@ -575,7 +619,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
[self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]]
def _add_user_to_room(
- self, room_id: str, room_version: RoomVersion, user_id: str,
+ self,
+ room_id: str,
+ room_version: RoomVersion,
+ user_id: str,
):
# Add a user to the room.
builder = self.event_builder_factory.for_room_version(
@@ -628,7 +675,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
self.helper.join(room, user=u2)
# Assert user directory is not empty
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.assertEquals(200, channel.code, channel.result)
@@ -636,7 +683,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
# Disable user directory and check search returns nothing
self.config.user_directory_search_enabled = False
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.assertEquals(200, channel.code, channel.result)
@@ -672,7 +719,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
user_one, user_two, user_three, user_three_token = self.setup_test_users()
# Request info about each user from user_three
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
path="/_matrix/client/unstable/users/info",
content={"user_ids": [user_one, user_two, user_three]},
@@ -703,7 +750,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
user_one, user_two, user_three, user_three_token = self.setup_test_users()
# Request information about our local users from the perspective of a remote server
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
path="/_matrix/federation/unstable/users/info",
content={"user_ids": [user_one, user_two, user_three]},
@@ -756,9 +803,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
"expiration_ts": 0,
"enable_renewal_emails": False,
}
- request, channel = self.make_request(
- "POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request("POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
def deactivate(self, user_id, tok):
@@ -766,7 +811,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
"auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
"erase": False,
}
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index aac8c7d623..3972abb038 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -36,6 +36,7 @@ from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server
from synapse.http.federation.well_known_resolver import (
+ WELL_KNOWN_MAX_SIZE,
WellKnownResolver,
_cache_period_from_headers,
)
@@ -517,8 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.successResultOf(test_d)
def test_get_well_known(self):
- """Test the behaviour when the .well-known delegates elsewhere
- """
+ """Test the behaviour when the .well-known delegates elsewhere"""
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -1094,7 +1094,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# Expire both caches and repeat the request
self.reactor.pump((10000.0,))
- # Repated the request, this time it should fail if the lookup fails.
+ # Repeat the request, this time it should fail if the lookup fails.
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
@@ -1107,9 +1107,34 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, None)
+ def test_well_known_too_large(self):
+ """A well-known query that returns a result which is too large should be rejected."""
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ response_headers={b"Cache-Control": b"max-age=1000"},
+ content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
+ )
+
+ # The result is successful, but delegation is disabled.
+ r = self.successResultOf(fetch_d)
+ self.assertIsNone(r.delegated_server)
+
def test_srv_fallbacks(self):
- """Test that other SRV results are tried if the first one fails.
- """
+ """Test that other SRV results are tried if the first one fails."""
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[
Server(host=b"target.com", port=8443),
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 05e9c449be..453391a5a5 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -46,16 +46,16 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
- request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+ channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
- request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+ channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
new file mode 100644
index 0000000000..0ce181a51e
--- /dev/null
+++ b/tests/http/test_client.py
@@ -0,0 +1,243 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from io import BytesIO
+
+from mock import Mock
+
+from netaddr import IPSet
+
+from twisted.internet.error import DNSLookupError
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web.client import Agent, ResponseDone
+from twisted.web.iweb import UNKNOWN_LENGTH
+
+from synapse.api.errors import SynapseError
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ BlacklistingReactorWrapper,
+ BodyExceededMaxSize,
+ read_body_with_max_size,
+)
+
+from tests.server import FakeTransport, get_clock
+from tests.unittest import TestCase
+
+
+class ReadBodyWithMaxSizeTests(TestCase):
+ def _build_response(self, length=UNKNOWN_LENGTH):
+ """Start reading the body, returns the response, result and proto"""
+ response = Mock(length=length)
+ result = BytesIO()
+ deferred = read_body_with_max_size(response, result, 6)
+
+ # Fish the protocol out of the response.
+ protocol = response.deliverBody.call_args[0][0]
+ protocol.transport = Mock()
+
+ return result, deferred, protocol
+
+ def _assert_error(self, deferred, protocol):
+ """Ensure that the expected error is received."""
+ self.assertIsInstance(deferred.result, Failure)
+ self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
+ protocol.transport.abortConnection.assert_called_once()
+
+ def _cleanup_error(self, deferred):
+ """Ensure that the error in the Deferred is handled gracefully."""
+ called = [False]
+
+ def errback(f):
+ called[0] = True
+
+ deferred.addErrback(errback)
+ self.assertTrue(called[0])
+
+ def test_no_error(self):
+ """A response that is NOT too large."""
+ result, deferred, protocol = self._build_response()
+
+ # Start sending data.
+ protocol.dataReceived(b"12345")
+ # Close the connection.
+ protocol.connectionLost(Failure(ResponseDone()))
+
+ self.assertEqual(result.getvalue(), b"12345")
+ self.assertEqual(deferred.result, 5)
+
+ def test_too_large(self):
+ """A response which is too large raises an exception."""
+ result, deferred, protocol = self._build_response()
+
+ # Start sending data.
+ protocol.dataReceived(b"1234567890")
+
+ self.assertEqual(result.getvalue(), b"1234567890")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ def test_multiple_packets(self):
+ """Data should be accumulated through mutliple packets."""
+ result, deferred, protocol = self._build_response()
+
+ # Start sending data.
+ protocol.dataReceived(b"12")
+ protocol.dataReceived(b"34")
+ # Close the connection.
+ protocol.connectionLost(Failure(ResponseDone()))
+
+ self.assertEqual(result.getvalue(), b"1234")
+ self.assertEqual(deferred.result, 4)
+
+ def test_additional_data(self):
+ """A connection can receive data after being closed."""
+ result, deferred, protocol = self._build_response()
+
+ # Start sending data.
+ protocol.dataReceived(b"1234567890")
+ self._assert_error(deferred, protocol)
+
+ # More data might have come in.
+ protocol.dataReceived(b"1234567890")
+
+ self.assertEqual(result.getvalue(), b"1234567890")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ def test_content_length(self):
+ """The body shouldn't be read (at all) if the Content-Length header is too large."""
+ result, deferred, protocol = self._build_response(length=10)
+
+ # Deferred shouldn't be called yet.
+ self.assertFalse(deferred.called)
+
+ # Start sending data.
+ protocol.dataReceived(b"12345")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ # The data is never consumed.
+ self.assertEqual(result.getvalue(), b"")
+
+
+class BlacklistingAgentTest(TestCase):
+ def setUp(self):
+ self.reactor, self.clock = get_clock()
+
+ self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+ self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+ self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+ # Configure the reactor's DNS resolver.
+ for (domain, ip) in (
+ (self.safe_domain, self.safe_ip),
+ (self.unsafe_domain, self.unsafe_ip),
+ (self.allowed_domain, self.allowed_ip),
+ ):
+ self.reactor.lookups[domain.decode()] = ip.decode()
+ self.reactor.lookups[ip.decode()] = ip.decode()
+
+ self.ip_whitelist = IPSet([self.allowed_ip.decode()])
+ self.ip_blacklist = IPSet(["5.0.0.0/8"])
+
+ def test_reactor(self):
+ """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
+ agent = Agent(
+ BlacklistingReactorWrapper(
+ self.reactor,
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ ),
+ )
+
+ # The unsafe domains and IPs should be rejected.
+ for domain in (self.unsafe_domain, self.unsafe_ip):
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + domain), DNSLookupError
+ )
+
+ # The safe domains and IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
+
+ def test_agent(self):
+ """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
+ agent = BlacklistingAgentWrapper(
+ Agent(self.reactor),
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ )
+
+ # The unsafe IPs should be rejected.
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
+ )
+
+ # The safe and unsafe domains and safe IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.unsafe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index b2e9533b07..d06ea518ce 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name
+from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name
from tests import unittest
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 212484a7fe..9c52c8fdca 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -560,4 +560,4 @@ class FederationClientTests(HomeserverTestCase):
self.pump()
f = self.failureResultOf(test_d)
- self.assertIsInstance(f.value, ValueError)
+ self.assertIsInstance(f.value, RequestSendFailed)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 25d07704e1..3ea8b5bec7 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -184,7 +184,9 @@ class MatrixFederationAgentTests(TestCase):
)
def test_https_request_via_no_proxy(self):
agent = ProxyAgent(
- self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ use_proxy=True,
)
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
@@ -196,7 +198,9 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
def test_https_request_via_no_proxy_star(self):
agent = ProxyAgent(
- self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ use_proxy=True,
)
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
@@ -252,10 +256,13 @@ class MatrixFederationAgentTests(TestCase):
self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
def _do_https_request_via_proxy(
- self, auth_credentials: Optional[str] = None,
+ self,
+ auth_credentials: Optional[str] = None,
):
agent = ProxyAgent(
- self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 73f469b802..48a74e2eee 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -18,30 +18,35 @@ import logging
from io import StringIO
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin
from tests.unittest import TestCase
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
+ def setUp(self):
+ self.output = StringIO()
+
+ def get_log_line(self):
+ # One log message, with a single trailing newline.
+ data = self.output.getvalue()
+ logs = data.splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(data.count("\n"), 1)
+ return json.loads(logs[0])
+
def test_terse_json_output(self):
"""
The Terse JSON formatter converts log messages to JSON.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -57,9 +62,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
Additional information can be included in the structured logging.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
@@ -67,12 +70,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
)
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -96,26 +94,44 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
The Terse JSON formatter converts log messages to JSON.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
+
+ # The JSON formatter should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+
+ def test_with_context(self):
+ """
+ The logging context should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ handler.addFilter(LoggingContextFilter())
+ logger = self.get_logger(handler)
+
+ with LoggingContext(request="test"):
+ logger.info("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
"log",
"level",
"namespace",
+ "request",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertEqual(log["request"], "test")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 27206ca3db..edacd1b566 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -100,7 +100,10 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
- expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
+ expected_requester,
+ event_dict,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
# Create and send a state event
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index bcdcafa5a9..941cf42429 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -21,6 +21,7 @@ import pkg_resources
from twisted.internet.defer import Deferred
import synapse.rest.admin
+from synapse.api.errors import Codes, SynapseError
from synapse.rest.client.v1 import login, room
from tests.unittest import HomeserverTestCase
@@ -100,12 +101,19 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token)
)
- token_id = user_tuple.token_id
+ self.token_id = user_tuple.token_id
+
+ # We need to add an email address to the account before we can create an email pusher.
+ self.get_success(
+ hs.get_datastore().user_add_threepid(
+ self.user_id, "email", "a@example.com", 0, 0
+ )
+ )
self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=self.user_id,
- access_token=token_id,
+ access_token=self.token_id,
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
@@ -116,6 +124,28 @@ class EmailPusherTests(HomeserverTestCase):
)
)
+ def test_need_validated_email(self):
+ """Test that we can only add an email pusher if the user has validated
+ their email.
+ """
+ with self.assertRaises(SynapseError) as cm:
+ self.get_success_or_raise(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=self.user_id,
+ access_token=self.token_id,
+ kind="email",
+ app_id="m.email",
+ app_display_name="Email Notifications",
+ device_display_name="b@example.com",
+ pushkey="b@example.com",
+ lang=None,
+ data={},
+ )
+ )
+
+ self.assertEqual(400, cm.exception.code)
+ self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode)
+
def test_simple_sends_email(self):
# Create a simple room with two users
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
@@ -124,13 +154,18 @@ class EmailPusherTests(HomeserverTestCase):
)
self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
- # The other user sends some messages
+ # The other user sends a single message.
self.helper.send(room, body="Hi!", tok=self.others[0].token)
- self.helper.send(room, body="There!", tok=self.others[0].token)
# We should get emailed about that message
self._check_for_mail()
+ # The other user sends multiple messages.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+ self.helper.send(room, body="There!", tok=self.others[0].token)
+
+ self._check_for_mail()
+
def test_invite_sends_email(self):
# Create a room and invite the user to it
room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
@@ -187,6 +222,75 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_multiple_rooms(self):
+ # We want to test multiple notifications from multiple rooms, so we pause
+ # processing of push while we send messages.
+ self.pusher._pause_processing()
+
+ # Create a simple room with multiple other users
+ rooms = [
+ self.helper.create_room_as(self.user_id, tok=self.access_token),
+ self.helper.create_room_as(self.user_id, tok=self.access_token),
+ ]
+
+ for r, other in zip(rooms, self.others):
+ self.helper.invite(
+ room=r, src=self.user_id, tok=self.access_token, targ=other.id
+ )
+ self.helper.join(room=r, user=other.id, tok=other.token)
+
+ # The other users send some messages
+ self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token)
+ self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+ self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+
+ # Nothing should have happened yet, as we're paused.
+ assert not self.email_attempts
+
+ self.pusher._resume_processing()
+
+ # We should get emailed about those messages
+ self._check_for_mail()
+
+ def test_empty_room(self):
+ """All users leaving a room shouldn't cause the pusher to break."""
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+ # The other user sends a single message.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+ # Leave the room before the message is processed.
+ self.helper.leave(room, self.user_id, tok=self.access_token)
+ self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+ # We should get emailed about that message
+ self._check_for_mail()
+
+ def test_empty_room_multiple_messages(self):
+ """All users leaving a room shouldn't cause the pusher to break."""
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+ # The other user sends multiple messages.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+ self.helper.send(room, body="There!", tok=self.others[0].token)
+
+ # Leave the room before the message is processed.
+ self.helper.leave(room, self.user_id, tok=self.access_token)
+ self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+ # We should get emailed about those messages
+ self._check_for_mail()
+
def test_encrypted_message(self):
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
@@ -209,7 +313,7 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ last_stream_ordering = pushers[0].last_stream_ordering
# Advance time a bit, so the pusher will register something has happened
self.pump(10)
@@ -220,7 +324,7 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+ self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
# One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1)
@@ -238,4 +342,7 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+
+ # Reset the attempts.
+ self.email_attempts = []
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 870850d1ec..a3b304d316 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
+from synapse.push import PusherConfigException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import receipts
@@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
+ def default_config(self):
+ config = super().default_config()
+ config["start_pushers"] = True
+ return config
+
def make_homeserver(self, reactor, clock):
self.push_attempts = []
@@ -46,14 +52,48 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
- config = self.default_config()
- config["start_pushers"] = True
+ hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
+
+ return hs
- hs = self.setup_test_homeserver(
- config=config, proxied_blacklisted_http_client=m
+ def test_invalid_configuration(self):
+ """Invalid push configurations should be rejected."""
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
)
+ token_id = user_tuple.token_id
- return hs
+ def test_data(data):
+ self.get_failure(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data=data,
+ ),
+ PusherConfigException,
+ )
+
+ # Data must be provided with a URL.
+ test_data(None)
+ test_data({})
+ test_data({"url": 1})
+ # A bare domain name isn't accepted.
+ test_data({"url": "example.com"})
+ # A URL without a path isn't accepted.
+ test_data({"url": "http://example.com"})
+ # A URL with an incorrect path isn't accepted.
+ test_data({"url": "http://example.com/foo"})
def test_sends_http(self):
"""
@@ -84,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -104,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ last_stream_ordering = pushers[0].last_stream_ordering
# Advance time a bit, so the pusher will register something has happened
self.pump()
@@ -115,11 +155,13 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+ self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
# One push was attempted to be sent -- it'll be the first message
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(
self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
)
@@ -134,12 +176,14 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+ last_stream_ordering = pushers[0].last_stream_ordering
# Now it'll try and send the second push message, which will be the second one
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(
self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
)
@@ -154,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
def test_sends_high_priority_for_encrypted(self):
"""
@@ -196,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -232,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Add yet another person — we want to make this room not a 1:1
@@ -270,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_one_to_one_only(self):
@@ -312,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -328,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority — this is a one-to-one room
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Yet another user joins
@@ -347,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -394,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -410,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time with no mention
@@ -419,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -467,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -487,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time as someone without the power of @room
@@ -498,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -572,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -591,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
# Check that the unread count for the room is 0
#
@@ -605,7 +667,7 @@ class HTTPPusherTests(HomeserverTestCase):
# This will actually trigger a new notification to be sent out so that
# even if the user does not receive another message, their unread
# count goes down
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
{},
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
new file mode 100644
index 0000000000..aff563919d
--- /dev/null
+++ b/tests/push/test_presentable_names.py
@@ -0,0 +1,229 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.presentable_names import calculate_room_name
+from synapse.types import StateKey, StateMap
+
+from tests import unittest
+
+
+class MockDataStore:
+ """
+ A fake data store which stores a mapping of state key to event content.
+ (I.e. the state key is used as the event ID.)
+ """
+
+ def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
+ """
+ Args:
+ events: A mapping of state keys to event contents.
+ """
+ self._events = {}
+
+ for i, (event_id, content) in enumerate(events):
+ self._events[event_id] = FrozenEvent(
+ {
+ "event_id": "$event_id",
+ "type": event_id[0],
+ "sender": "@user:test",
+ "state_key": event_id[1],
+ "room_id": "#room:test",
+ "content": content,
+ "origin_server_ts": i,
+ },
+ RoomVersions.V1,
+ )
+
+ async def get_event(
+ self, event_id: StateKey, allow_none: bool = False
+ ) -> Optional[FrozenEvent]:
+ assert allow_none, "Mock not configured for allow_none = False"
+
+ return self._events.get(event_id)
+
+ async def get_events(self, event_ids: Iterable[StateKey]):
+ # This is cheating since it just returns all events.
+ return self._events
+
+
+class PresentableNamesTestCase(unittest.HomeserverTestCase):
+ USER_ID = "@test:test"
+ OTHER_USER_ID = "@user:test"
+
+ def _calculate_room_name(
+ self,
+ events: StateMap[dict],
+ user_id: str = "",
+ fallback_to_members: bool = True,
+ fallback_to_single_member: bool = True,
+ ):
+ # This isn't 100% accurate, but works with MockDataStore.
+ room_state_ids = {k[0]: k[0] for k in events}
+
+ return self.get_success(
+ calculate_room_name(
+ MockDataStore(events),
+ room_state_ids,
+ user_id or self.USER_ID,
+ fallback_to_members,
+ fallback_to_single_member,
+ )
+ )
+
+ def test_name(self):
+ """A room name event should be used."""
+ events = [
+ ((EventTypes.Name, ""), {"name": "test-name"}),
+ ]
+ self.assertEqual("test-name", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.Name, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.Name, ""), {"name": 1})]
+ self.assertEqual(1, self._calculate_room_name(events))
+
+ def test_canonical_alias(self):
+ """An canonical alias should be used."""
+ events = [
+ ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
+ ]
+ self.assertEqual("#test-name:test", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_invite(self):
+ """An invite has special behaviour."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
+ ]
+ self.assertEqual("Invite from Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+ # Ensure this logic is skipped if we don't fall back to members.
+ self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
+
+ # Check if the event content has garbage.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ]
+ self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
+
+ # No member event for sender.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ]
+ self.assertEqual("Room Invite", self._calculate_room_name(events))
+
+ def test_no_members(self):
+ """Behaviour of an empty room."""
+ events = []
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ # Note that events with invalid (or missing) membership are ignored.
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
+ ]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_no_other_members(self):
+ """Behaviour of a room with no other members in it."""
+ events = [
+ (
+ (EventTypes.Member, self.USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Me"},
+ ),
+ ]
+ self.assertEqual("Me", self._calculate_room_name(events))
+
+ # Check if the event content has no displayname.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("@test:test", self._calculate_room_name(events))
+
+ # 3pid invite, use the other user (who is set as the sender).
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual(
+ "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
+ )
+
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
+ ]
+ self.assertEqual(
+ "Inviting email address",
+ self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
+ )
+
+ def test_one_other_member(self):
+ """Behaviour of a room with a single other member."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ]
+ self.assertEqual("Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+
+ # Check if the event content has no displayname and is an invite.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.INVITE},
+ ),
+ ]
+ self.assertEqual("@user:test", self._calculate_room_name(events))
+
+ def test_other_members(self):
+ """Behaviour of a room with multiple other members."""
+ # Two other members.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
+
+ # Three or more other members.
+ events.append(
+ ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
+ )
+ self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1f4b5ca2ac..4a841f5bb8 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"type": "m.room.history_visibility",
"sender": "@user:test",
"state_key": "",
- "room_id": "@room:test",
+ "room_id": "#room:test",
"content": content,
},
RoomVersions.V1,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 728de28277..1d4a592862 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, List, Optional, Tuple
-
-import attr
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
+from twisted.web.server import Request, Site
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,10 +28,13 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+ ReplicationStreamProtocolFactory,
+ ServerReplicationStreamProtocol,
+)
from synapse.server import HomeServer
from synapse.util import Clock
@@ -41,7 +44,7 @@ from tests.server import FakeTransport
try:
import hiredis
except ImportError:
- hiredis = None
+ hiredis = None # type: ignore
logger = logging.getLogger(__name__)
@@ -54,15 +57,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- servlets = [
- streams.register_servlets,
- ]
-
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
- self.server = server_factory.buildProtocol(None)
+ self.server = server_factory.buildProtocol(
+ None
+ ) # type: ServerReplicationStreamProtocol
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -82,12 +83,21 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
- self.worker_hs, "client", "test", clock, repl_handler,
+ self.worker_hs,
+ "client",
+ "test",
+ clock,
+ repl_handler,
)
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
@@ -146,12 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor)
- channel.requestFactory = request_factory
- channel.site = self.site
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -173,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return request_factory.request
+ return channel.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -182,8 +188,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
fetching updates for given stream.
"""
+ path = request.path # type: bytes # type: ignore
self.assertRegex(
- request.path,
+ path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
@@ -210,6 +217,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
+ # We may have an attempt to connect to redis for the external cache already.
+ self.connect_any_redis_attempts()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
@@ -223,7 +233,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- "localhost", 6379, self.connect_any_redis_attempts,
+ b"localhost",
+ 6379,
+ self.connect_any_redis_attempts,
)
self.hs.get_tcp_replication().start_replication(self.hs)
@@ -241,8 +253,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
)
def create_test_resource(self):
- """Overrides `HomeserverTestCase.create_test_resource`.
- """
+ """Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
# subclasses.
@@ -291,7 +302,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if instance_loc.host not in self.reactor.lookups:
raise Exception(
"Host does not have an IP for instance_map[%r].host = %r"
- % (instance_name, instance_loc.host,)
+ % (
+ instance_name,
+ instance_loc.host,
+ )
)
self.reactor.add_tcp_client_callback(
@@ -310,7 +324,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if not worker_hs.config.redis_enabled:
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
+ worker_hs,
+ "client",
+ "test",
+ self.clock,
+ repl_handler,
)
server = self.server_factory.buildProtocol(None)
@@ -370,12 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor)
- channel.requestFactory = request_factory
- channel.site = self._hs_to_site[hs]
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -399,25 +413,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
- self.assertEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
- self.assertEqual(port, 6379)
+ while clients:
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, b"localhost")
+ self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
- client_to_server_transport = FakeTransport(
- server_protocol, self.reactor, client_protocol
- )
- client_protocol.makeConnection(client_to_server_transport)
-
- server_to_client_transport = FakeTransport(
- client_protocol, self.reactor, server_protocol
- )
- server_protocol.makeConnection(server_to_client_transport)
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
- return client_to_server_transport, server_to_client_transport
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -435,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))
-@attr.s()
-class OneShotRequestFactory:
- """A simple request factory that generates a single `SynapseRequest` and
- stores it for future use. Can only be used once.
- """
-
- request = attr.ib(default=None)
-
- def __call__(self, *args, **kwargs):
- assert self.request is None
-
- self.request = SynapseRequest(*args, **kwargs)
- return self.request
-
-
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
@@ -460,9 +457,13 @@ class _PushHTTPChannel(HTTPChannel):
makes it very hard to test.
"""
- def __init__(self, reactor: IReactorTime):
+ def __init__(
+ self, reactor: IReactorTime, request_factory: Type[Request], site: Site
+ ):
super().__init__()
self.reactor = reactor
+ self.requestFactory = request_factory
+ self.site = site
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
@@ -482,17 +483,20 @@ class _PushHTTPChannel(HTTPChannel):
self._pull_to_push_producer.stop()
def checkPersistence(self, request, version):
- """Check whether the connection can be re-used
- """
+ """Check whether the connection can be re-used"""
# We hijack this to always say no for ease of wiring stuff up in
# `handle_http_replication_attempt`.
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
+ def requestDone(self, request):
+ # Store the request for inspection.
+ self.request = request
+ super().requestDone(request)
+
class _PullToPushProducer:
- """A push producer that wraps a pull producer.
- """
+ """A push producer that wraps a pull producer."""
def __init__(
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
@@ -509,39 +513,33 @@ class _PullToPushProducer:
self._start_loop()
def _start_loop(self):
- """Start the looping call to
- """
+ """Start the looping call to"""
if not self._looping_call:
# Start a looping call which runs every tick.
self._looping_call = self._clock.looping_call(self._run_once, 0)
def stop(self):
- """Stops calling resumeProducing.
- """
+ """Stops calling resumeProducing."""
if self._looping_call:
self._looping_call.stop()
self._looping_call = None
def pauseProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self.stop()
def resumeProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self._start_loop()
def stopProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self.stop()
self._producer.stopProducing()
def _run_once(self):
- """Calls resumeProducing on producer once.
- """
+ """Calls resumeProducing on producer once."""
try:
self._producer.resumeProducing()
@@ -556,25 +554,21 @@ class _PullToPushProducer:
class FakeRedisPubSubServer:
- """A fake Redis server for pub/sub.
- """
+ """A fake Redis server for pub/sub."""
def __init__(self):
self._subscribers = set()
def add_subscriber(self, conn):
- """A connection has called SUBSCRIBE
- """
+ """A connection has called SUBSCRIBE"""
self._subscribers.add(conn)
def remove_subscriber(self, conn):
- """A connection has called UNSUBSCRIBE
- """
+ """A connection has called UNSUBSCRIBE"""
self._subscribers.discard(conn)
def publish(self, conn, channel, msg) -> int:
- """A connection want to publish a message to subscribers.
- """
+ """A connection want to publish a message to subscribers."""
for sub in self._subscribers:
sub.send(["message", channel, msg])
@@ -585,8 +579,9 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
- """A connection from a client talking to the fake Redis server.
- """
+ """A connection from a client talking to the fake Redis server."""
+
+ transport = None # type: Optional[FakeTransport]
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
@@ -610,8 +605,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:])
def handle_command(self, command, *args):
- """Received a Redis command from the client.
- """
+ """Received a Redis command from the client."""
# We currently only support pub/sub.
if command == b"PUBLISH":
@@ -622,12 +616,19 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
+
+ # Since we use SET/GET to cache things we can safely no-op them.
+ elif command == b"SET":
+ self.send("OK")
+ elif command == b"GET":
+ self.send(None)
else:
raise Exception("Unknown command")
def send(self, msg):
- """Send a message back to the client.
- """
+ """Send a message back to the client."""
+ assert self.transport is not None
+
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
@@ -643,6 +644,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
+ if obj is None:
+ return "$-1\r\n"
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
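The `encode` change above adds the RESP null bulk string so that the new no-op `GET` handler can signal a cache miss. A self-contained sketch of the framing the fake server emits (this mirrors `FakeRedisPubSubProtocol.encode`, not a full Redis implementation):

```python
# Sketch of RESP (REdis Serialization Protocol) framing as used by the fake server.
from typing import Any


def resp_encode(obj: Any) -> str:
    """Encode a reply the way the fake Redis server frames it."""
    if obj is None:
        return "$-1\r\n"  # null bulk string, i.e. "no value" (used for the GET no-op)
    if isinstance(obj, str):
        return "$%d\r\n%s\r\n" % (len(obj), obj)  # bulk string
    if isinstance(obj, int):
        return ":%d\r\n" % obj  # integer reply
    if isinstance(obj, list):
        return "*%d\r\n%s" % (len(obj), "".join(resp_encode(o) for o in obj))  # array
    raise TypeError("unsupported type: %r" % (obj,))


assert resp_encode(None) == "$-1\r\n"
assert resp_encode("OK") == "$2\r\nOK\r\n"
assert resp_encode(["message", "chan", "hi"]) == "*3\r\n$7\r\nmessage\r\n$4\r\nchan\r\n$2\r\nhi\r\n"
```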
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index c0ee1cfbd6..0ceb0f935c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -66,7 +66,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.get_success(
self.master_store.store_room(
- ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
+ ROOM_ID,
+ USER_ID,
+ is_public=False,
+ room_version=RoomVersions.V1,
)
)
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 6a5116dd2a..153634d4ee 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -23,8 +23,7 @@ from tests.replication._base import BaseStreamTestCase
class AccountDataStreamTestCase(BaseStreamTestCase):
def test_update_function_room_account_data_limit(self):
- """Test replication with many room account data updates
- """
+ """Test replication with many room account data updates"""
store = self.hs.get_datastore()
# generate lots of account data updates
@@ -70,8 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
def test_update_function_global_account_data_limit(self):
- """Test replication with many global account data updates
- """
+ """Test replication with many global account data updates"""
store = self.hs.get_datastore()
# generate lots of account data updates
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index bad0df08cf..77856fc304 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -129,7 +129,10 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
pls["users"][OTHER_USER] = 50
self.helper.send_state(
- self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+ self.room_id,
+ EventTypes.PowerLevels,
+ pls,
+ tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
@@ -255,8 +258,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self):
- """Test replication with many state events over several stream ids.
- """
+ """Test replication with many state events over several stream ids."""
# we want to generate lots of state changes, but for this test, we want to
# spread out the state changes over a few stream IDs.
@@ -282,7 +284,10 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
pls["users"].update({u: 50 for u in user_ids})
self.helper.send_state(
- self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+ self.room_id,
+ EventTypes.PowerLevels,
+ pls,
+ tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index 2babea4e3e..aa4bf1c7e3 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -24,7 +24,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
# enable federation sending on the worker
config = super()._get_worker_hs_config()
# TODO: make it so we don't need both of these
- config["send_federation"] = True
+ config["send_federation"] = False
config["worker_app"] = "synapse.app.federation_sender"
return config
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 5acfb3e53e..ca49d4dd3a 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -69,6 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing")
# The from token should be the token from the last RDATA we got.
+ assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once()
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index d1c15caeb0..1fe9d5b4d0 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -28,8 +28,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
self.factory = ReplicationStreamProtocolFactory(hs)
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
- """Create a new direct TCP replication connection
- """
+ """Create a new direct TCP replication connection"""
proto = self.factory.buildProtocol(("127.0.0.1", 0))
transport = StringTransport()
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..f8fd8a843c
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+ """Test the authentication of HTTP calls between workers."""
+
+ servlets = [register.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ # This isn't a real configuration option; it is used to give the main
+ # homeserver and the worker homeserver different options.
+ main_replication_secret = config.pop("main_replication_secret", None)
+ if main_replication_secret:
+ config["worker_replication_secret"] = main_replication_secret
+ return self.setup_test_homeserver(config=config)
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+
+ return config
+
+ def _test_register(self) -> FakeChannel:
+ """Run the actual test:
+
+ 1. Create a worker homeserver.
+ 2. Start registration by providing a user/password.
+ 3. Complete registration by providing dummy auth (this hits the main synapse).
+ 4. Return the final request.
+
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
+
+ channel_1 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.assertEqual(channel_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ return make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
+ )
+
+ def test_no_auth(self):
+ """With no authentication the request should finish."""
+ channel = self._test_register()
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ @override_config({"main_replication_secret": "my-secret"})
+ def test_missing_auth(self):
+ """If the main process expects a secret that is not provided, an error results."""
+ channel = self._test_register()
+ self.assertEqual(channel.code, 500)
+
+ @override_config(
+ {
+ "main_replication_secret": "my-secret",
+ "worker_replication_secret": "wrong-secret",
+ }
+ )
+ def test_unauthorized(self):
+ """If the main process receives the wrong secret, an error results."""
+ channel = self._test_register()
+ self.assertEqual(channel.code, 500)
+
+ @override_config({"worker_replication_secret": "my-secret"})
+ def test_authorized(self):
+ """The request should finish when the worker provides the authentication header."""
+ channel = self._test_register()
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..5da1d5dc4d 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,19 @@
# limitations under the License.
import logging
-from synapse.api.constants import LoginType
-from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel, make_request
+from tests.server import make_request
logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
- """Base class for tests of the replication streams"""
+ """Test using one or more client readers for registration."""
servlets = [register.register_servlets]
- def prepare(self, reactor, clock, hs):
- self.recaptcha_checker = DummyRecaptchaChecker(hs)
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
@@ -43,65 +35,63 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
return config
def test_register_single_worker(self):
- """Test that registration works when using a single client reader worker.
- """
+ """Test that registration works when using a single client reader worker."""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
- request_1, channel_1 = make_request(
+ channel_1 = make_request(
self.reactor,
site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_1.code, 401)
+ )
+ self.assertEqual(channel_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = make_request(
+ channel_2 = make_request(
self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_2.code, 200)
+ )
+ self.assertEqual(channel_2.code, 200)
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
def test_register_multi_worker(self):
- """Test that registration works when using multiple client reader workers.
- """
+ """Test that registration works when using multiple client reader workers."""
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
site_1 = self._hs_to_site[worker_hs_1]
- request_1, channel_1 = make_request(
+ channel_1 = make_request(
self.reactor,
site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_1.code, 401)
+ )
+ self.assertEqual(channel_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
site_2 = self._hs_to_site[worker_hs_2]
- request_2, channel_2 = make_request(
+ channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_2.code, 200)
+ )
+ self.assertEqual(channel_2.code, 200)
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
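Both tests above walk through the user-interactive auth flow for registration: the first `/register` attempt returns a 401 carrying a session ID and the allowed flows, and the second request completes a stage against that session (here the dummy stage), possibly on a different worker. A tiny sketch of the two request bodies:

```python
# Sketch of the two-step registration the tests perform; the session ID comes from
# the 401 response body of the first request.
first_body = {"username": "user", "type": "m.login.password", "password": "bar"}
# -> 401 response, json_body includes {"session": "<session id>", "flows": [...]}

second_body = {"auth": {"session": "<session id from first response>", "type": "m.login.dummy"}}
# -> 200 response with the registered user's ID once the flow is complete
```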
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 1853667558..0d9e3bb11d 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -17,7 +17,7 @@ import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from tests.unittest import HomeserverTestCase
@@ -27,7 +27,7 @@ class FederationAckTestCase(HomeserverTestCase):
def default_config(self) -> dict:
config = super().default_config()
config["worker_app"] = "synapse.app.federation_sender"
- config["send_federation"] = True
+ config["send_federation"] = False
return config
def make_homeserver(self, reactor, clock):
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
"""
rch = self.hs.get_tcp_replication()
- # wire up the ReplicationCommandHandler to a mock connection
- mock_connection = mock.Mock(spec=AbstractConnection)
+ # wire up the ReplicationCommandHandler to a mock connection, which needs
+ # to implement IReplicationConnection. (Note that Mock doesn't understand
+ # interfaces, but casting an interface to a list gives the attributes.)
+ mock_connection = mock.Mock(spec=list(IReplicationConnection))
rch.new_connection(mock_connection)
# tell it it received an RDATA row
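A short illustration of the mocking trick used above, assuming `IReplicationConnection` is a `zope.interface` `Interface`: iterating an Interface yields its attribute names, and `mock.Mock` accepts a list of names as its `spec`. `IExampleConnection` below is a hypothetical stand-in, not the real interface:

```python
# Sketch: building a Mock constrained to a zope.interface Interface's attributes.
from unittest import mock

from zope.interface import Interface


class IExampleConnection(Interface):  # hypothetical stand-in for IReplicationConnection
    def send_command(cmd):
        """Send a command over the connection."""


conn = mock.Mock(spec=list(IExampleConnection))
conn.send_command("PING")  # allowed: the name is in the spec
# conn.close()             # would raise AttributeError: not defined on the interface
```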
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index fffdb742c8..2f2d117858 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -49,7 +49,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.federation_sender",
- {"send_federation": True},
+ {"send_federation": False},
federation_http_client=mock_client,
)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 48b574ccbe..b0800f9840 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -15,7 +15,7 @@
import logging
import os
from binascii import unhexlify
-from typing import Tuple
+from typing import Optional, Tuple
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
@@ -32,12 +32,11 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__)
-test_server_connection_factory = None
+test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
- """Checks running multiple media repos work correctly.
- """
+ """Checks running multiple media repos work correctly."""
servlets = [
admin.register_servlets_for_client_rest_resource,
@@ -48,7 +47,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
- self.reactor.lookups["example.com"] = "127.0.0.2"
+ self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self):
conf = super().default_config()
@@ -68,7 +67,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
the media which the caller should respond to.
"""
resource = hs.get_media_repository_resource().children[b"download"]
- _, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
@@ -124,8 +123,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return channel, request
def test_basic(self):
- """Test basic fetching of remote media from a single worker.
- """
+ """Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
@@ -223,16 +221,14 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
def _count_remote_media(self) -> int:
- """Count the number of files in our remote media directory.
- """
+ """Count the number of files in our remote media directory."""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_content"
)
return sum(len(files) for _, _, files in os.walk(path))
def _count_remote_thumbnails(self) -> int:
- """Count the number of files in our remote thumbnails directory.
- """
+ """Count the number of files in our remote thumbnails directory."""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index f894bcd6e7..ab2988a6ba 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -27,8 +27,7 @@ logger = logging.getLogger(__name__)
class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
- """Checks pusher sharding works
- """
+ """Checks pusher sharding works"""
servlets = [
admin.register_servlets_for_client_rest_resource,
@@ -67,7 +66,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "https://push.example.com/push"},
+ data={"url": "https://push.example.com/_matrix/push/v1/notify"},
)
)
@@ -88,16 +87,15 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
return event_id
def test_send_push_single_worker(self):
- """Test that registration works when using a pusher worker.
- """
+ """Test that registration works when using a pusher worker."""
http_client_mock = Mock(spec_set=["post_json_get_json"])
- http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
- {}
+ http_client_mock.post_json_get_json.side_effect = (
+ lambda *_, **__: defer.succeed({})
)
self.make_worker_hs(
"synapse.app.pusher",
- {"start_pushers": True},
+ {"start_pushers": False},
proxied_blacklisted_http_client=http_client_mock,
)
@@ -109,7 +107,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -119,11 +117,10 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
def test_send_push_multiple_workers(self):
- """Test that registration works when using sharded pusher workers.
- """
+ """Test that registration works when using sharded pusher workers."""
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
- http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
- {}
+ http_client_mock1.post_json_get_json.side_effect = (
+ lambda *_, **__: defer.succeed({})
)
self.make_worker_hs(
@@ -137,8 +134,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
- http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
- {}
+ http_client_mock2.post_json_get_json.side_effect = (
+ lambda *_, **__: defer.succeed({})
)
self.make_worker_hs(
@@ -161,7 +158,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_not_called()
self.assertEqual(
http_client_mock1.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -183,7 +180,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock2.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
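The URL updates in these pusher tests point the pusher's `data.url` at the push gateway's notify endpoint, `/_matrix/push/v1/notify`, which this change set appears to require of HTTP pushers instead of an arbitrary URL. A small sketch (only the host name is an example value):

```python
# Sketch: pusher data pointing at the push gateway notify endpoint.
NOTIFY_PATH = "/_matrix/push/v1/notify"

pusher_data = {"url": "https://push.example.com" + NOTIFY_PATH}

# The bare "https://push.example.com/push" URLs used previously no longer match the
# expected endpoint, hence the updates above.
assert pusher_data["url"].endswith(NOTIFY_PATH)
```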
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 77fc3856d5..c9b773fbd2 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -29,8 +29,7 @@ logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
- """Checks event persisting sharding works
- """
+ """Checks event persisting sharding works"""
# Event persister sharding requires postgres (due to needing
# `MultiWriterIdGenerator`).
@@ -63,8 +62,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
return conf
def _create_room(self, room_id: str, user_id: str, tok: str):
- """Create a room with given room_id
- """
+ """Create a room with given room_id"""
# We control the room ID generation by patching out the
# `_generate_room_id` method
@@ -91,11 +89,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""
self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "worker1"},
+ "synapse.app.generic_worker",
+ {"worker_name": "worker1"},
)
self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "worker2"},
+ "synapse.app.generic_worker",
+ {"worker_name": "worker2"},
)
persisted_on_1 = False
@@ -139,15 +139,18 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""
self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "worker1"},
+ "synapse.app.generic_worker",
+ {"worker_name": "worker1"},
)
worker_hs2 = self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "worker2"},
+ "synapse.app.generic_worker",
+ {"worker_name": "worker2"},
)
sync_hs = self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "sync"},
+ "synapse.app.generic_worker",
+ {"worker_name": "sync"},
)
sync_hs_site = self._hs_to_site[sync_hs]
@@ -180,7 +183,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Do an initial sync so that we're up to date.
- request, channel = make_request(
+ channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
next_batch = channel.json_body["next_batch"]
@@ -206,7 +209,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -236,7 +239,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -261,7 +264,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -279,7 +282,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -292,7 +295,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -307,7 +310,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Paginating forwards should give the same results
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -318,12 +321,14 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertListEqual([], channel.json_body["chunk"])
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
- room_id2, vector_clock_token, prev_batch2,
+ room_id2,
+ vector_clock_token,
+ prev_batch2,
),
access_token=access_token,
)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 67d8878395..057e27372e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -42,7 +42,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
return resource
def test_version_string(self):
- request, channel = self.make_request("GET", self.url, shorthand=False)
+ channel = self.make_request("GET", self.url, shorthand=False)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -68,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
def test_delete_group(self):
# Create a new group
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/create_group".encode("ascii"),
access_token=self.admin_user_tok,
@@ -84,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
# Invite/join another user
url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
url = "/groups/%s/self/accept_invite" % (group_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -101,7 +99,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
# Now delete the group
url = "/_synapse/admin/v1/delete_group/" + group_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
access_token=self.admin_user_tok,
@@ -123,7 +121,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"""
url = "/groups/%s/profile" % (group_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
@@ -132,9 +130,8 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
)
def _get_groups_user_is_in(self, access_token):
- """Returns the list of groups the user is in (given their access token)
- """
- request, channel = self.make_request(
+ """Returns the list of groups the user is in (given their access token)"""
+ channel = self.make_request(
"GET", "/joined_groups".encode("ascii"), access_token=access_token
)
@@ -144,8 +141,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
- """Test /quarantine_media admin API.
- """
+ """Test /quarantine_media admin API."""
servlets = [
synapse.rest.admin.register_servlets,
@@ -155,9 +151,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
- self.hs = hs
-
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
@@ -216,7 +209,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
@@ -241,8 +234,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt quarantine media APIs as non-admin
url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
- request, channel = self.make_request(
- "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ access_token=non_admin_user_tok,
)
# Expect a forbidden error
@@ -254,8 +249,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# And the roomID/userID endpoint
url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
- request, channel = self.make_request(
- "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ access_token=non_admin_user_tok,
)
# Expect a forbidden error
@@ -282,7 +279,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
@@ -299,7 +296,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
urllib.parse.quote(server_name),
urllib.parse.quote(media_id),
)
- request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=admin_user_tok,
+ )
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
@@ -351,7 +352,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
room_id
)
- request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=admin_user_tok,
+ )
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
self.assertEqual(
@@ -395,8 +400,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
- request, channel = self.make_request(
- "POST", url.encode("ascii"), access_token=admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ access_token=admin_user_tok,
)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -431,14 +438,20 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Mark the second item as safe from quarantine.
_, media_id_2 = server_and_media_id_2.split("/")
- self.get_success(self.store.mark_local_media_as_safe(media_id_2))
+ # Protect the media from quarantine via the admin API
+ url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
+ channel = self.make_request("POST", url, access_token=admin_user_tok)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
- request, channel = self.make_request(
- "POST", url.encode("ascii"), access_token=admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ access_token=admin_user_tok,
)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -453,7 +466,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index cf3a007598..2a1bcf1760 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -50,17 +50,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
Try to get a device of a user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- request, channel = self.make_request("PUT", self.url, b"{}")
+ channel = self.make_request("PUT", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- request, channel = self.make_request("DELETE", self.url, b"{}")
+ channel = self.make_request("DELETE", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -69,22 +69,28 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.other_user_token,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- request, channel = self.make_request(
- "PUT", self.url, access_token=self.other_user_token,
+ channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- request, channel = self.make_request(
- "DELETE", self.url, access_token=self.other_user_token,
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -99,22 +105,28 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
% self.other_user_device_id
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
@@ -129,22 +141,28 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
% self.other_user_device_id
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -158,21 +176,27 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
)
# Deleting an unknown device returns status 200
@@ -197,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
}
body = json.dumps(update)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
@@ -208,8 +232,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -227,15 +253,19 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
- "PUT", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -247,7 +277,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
# Set new display_name
body = json.dumps({"display_name": "new displayname"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
@@ -257,8 +287,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -268,8 +300,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
         Tests that a normal lookup for a device is successful
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -291,8 +325,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(1, number_devices)
# Delete device
- request, channel = self.make_request(
- "DELETE", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -323,7 +359,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
         Try to list devices of a user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -334,8 +370,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -346,8 +384,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
@@ -359,8 +399,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -373,8 +415,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
# Get devices
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -391,8 +435,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.login("user", "pass")
# Get devices
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -431,7 +477,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
         Try to delete devices of a user without authentication.
"""
- request, channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -442,8 +488,10 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "POST", self.url, access_token=other_user_token,
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -454,8 +502,10 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
- request, channel = self.make_request(
- "POST", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
@@ -467,8 +517,10 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
- request, channel = self.make_request(
- "POST", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -479,7 +531,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that removing a device that does not exist returns 200.
"""
body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
@@ -510,7 +562,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
# Delete devices
body = json.dumps({"devices": device_ids})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 11b72c10f7..e30ffe4fa0 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -53,19 +51,23 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# Two rooms and two users. Every user sends and reports every room event
for i in range(5):
self._create_event_and_report(
- room_id=self.room_id1, user_tok=self.other_user_tok,
+ room_id=self.room_id1,
+ user_tok=self.other_user_tok,
)
for i in range(5):
self._create_event_and_report(
- room_id=self.room_id2, user_tok=self.other_user_tok,
+ room_id=self.room_id2,
+ user_tok=self.other_user_tok,
)
for i in range(5):
self._create_event_and_report(
- room_id=self.room_id1, user_tok=self.admin_user_tok,
+ room_id=self.room_id1,
+ user_tok=self.admin_user_tok,
)
for i in range(5):
self._create_event_and_report(
- room_id=self.room_id2, user_tok=self.admin_user_tok,
+ room_id=self.room_id2,
+ user_tok=self.admin_user_tok,
)
self.url = "/_synapse/admin/v1/event_reports"
@@ -74,7 +76,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
Try to get an event report without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -84,8 +86,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.other_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -96,8 +100,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -111,8 +117,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with limit
"""
- request, channel = self.make_request(
- "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -126,8 +134,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a defined starting point (from)
"""
- request, channel = self.make_request(
- "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -141,8 +151,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a defined starting point and limit
"""
- request, channel = self.make_request(
- "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -156,7 +168,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of room
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?room_id=%s" % self.room_id1,
access_token=self.admin_user_tok,
@@ -176,7 +188,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of user
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?user_id=%s" % self.other_user,
access_token=self.admin_user_tok,
@@ -196,7 +208,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of user and room
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
access_token=self.admin_user_tok,
@@ -218,8 +230,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
# fetch the most recent first, largest timestamp
- request, channel = self.make_request(
- "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=b",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -234,8 +248,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
report += 1
# fetch the oldest first, smallest timestamp
- request, channel = self.make_request(
- "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=f",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -254,8 +270,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing that an invalid search order returns a 400
"""
- request, channel = self.make_request(
- "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -267,8 +285,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing that a negative limit parameter returns a 400
"""
- request, channel = self.make_request(
- "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -279,8 +299,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing that a negative from parameter returns a 400
"""
- request, channel = self.make_request(
- "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -293,8 +315,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=20",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -304,8 +328,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=21",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -315,8 +341,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -327,8 +355,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# Check
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
- request, channel = self.make_request(
- "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -337,12 +367,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body)
def _create_event_and_report(self, room_id, user_tok):
- """Create and report events
- """
+ """Create and report events"""
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
json.dumps({"score": -100, "reason": "this makes me sad"}),
@@ -351,8 +380,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
- """Checks that all attributes are present in an event report
- """
+ """Checks that all attributes are present in an event report"""
for c in content:
self.assertIn("id", c)
self.assertIn("received_ts", c)
@@ -375,8 +403,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -389,7 +415,8 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
self._create_event_and_report(
- room_id=self.room_id1, user_tok=self.other_user_tok,
+ room_id=self.room_id1,
+ user_tok=self.other_user_tok,
)
# first created event report gets `id`=2
@@ -399,7 +426,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
         Try to get an event report without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -409,8 +436,10 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.other_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -421,8 +450,10 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         Testing retrieval of a reported event
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -434,7 +465,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
# `report_id` is negative
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v1/event_reports/-123",
access_token=self.admin_user_tok,
@@ -448,7 +479,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
)
# `report_id` is a non-numerical string
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v1/event_reports/abcdef",
access_token=self.admin_user_tok,
@@ -462,7 +493,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
)
# `report_id` is undefined
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v1/event_reports/",
access_token=self.admin_user_tok,
@@ -480,7 +511,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         Testing that a non-existent `report_id` returns a 404.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v1/event_reports/123",
access_token=self.admin_user_tok,
@@ -491,12 +522,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertEqual("Event report not found", channel.json_body["error"])
def _create_event_and_report(self, room_id, user_tok):
- """Create and report events
- """
+ """Create and report events"""
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
json.dumps({"score": -100, "reason": "this makes me sad"}),
@@ -505,8 +535,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
- """Checks that all attributes are present in a event report
- """
+        """Checks that all attributes are present in an event report"""
self.assertIn("id", content)
self.assertIn("received_ts", content)
self.assertIn("room_id", content)
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index dadf9db660..31db472cd3 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.handler = hs.get_device_handler()
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
@@ -50,7 +49,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
- request, channel = self.make_request("DELETE", url, b"{}")
+ channel = self.make_request("DELETE", url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -64,8 +63,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
- request, channel = self.make_request(
- "DELETE", url, access_token=self.other_user_token,
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -77,8 +78,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
@@ -90,8 +93,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -121,7 +126,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name)
# Attempt to access media
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
@@ -146,18 +151,21 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
# Delete media
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- media_id, channel.json_body["deleted_media"][0],
+ media_id,
+ channel.json_body["deleted_media"][0],
)
# Attempt to access media
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
@@ -189,7 +197,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.handler = hs.get_device_handler()
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
@@ -204,7 +211,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
Try to delete media without authentication.
"""
- request, channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -216,8 +223,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.other_user = self.register_user("user", "pass")
self.other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "POST", self.url, access_token=self.other_user_token,
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -229,8 +238,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
- request, channel = self.make_request(
- "POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url + "?before_ts=1234",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -240,8 +251,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
If the parameter `before_ts` is missing, an error is returned.
"""
- request, channel = self.make_request(
- "POST", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -254,8 +267,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
If parameters are invalid, an error is returned.
"""
- request, channel = self.make_request(
- "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=-1234",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -265,7 +280,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=1234&size_gt=-1234",
access_token=self.admin_user_tok,
@@ -278,7 +293,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=1234&keep_profiles=not_bool",
access_token=self.admin_user_tok,
@@ -308,7 +323,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# timestamp after upload/create
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
@@ -316,7 +331,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- media_id, channel.json_body["deleted_media"][0],
+ media_id,
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -332,7 +348,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
@@ -344,7 +360,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# timestamp after upload
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
@@ -352,7 +368,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ server_and_media_id.split("/")[1],
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -367,7 +384,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok,
@@ -378,7 +395,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok,
@@ -386,7 +403,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ server_and_media_id.split("/")[1],
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -401,7 +419,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
# set media as avatar
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.admin_user,),
content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
@@ -410,7 +428,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
@@ -421,7 +439,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
@@ -429,7 +447,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ server_and_media_id.split("/")[1],
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -445,7 +464,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# set media as room avatar
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/state/m.room.avatar" % (room_id,),
content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
@@ -454,7 +473,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
@@ -465,7 +484,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
@@ -473,7 +492,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ server_and_media_id.split("/")[1],
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -512,7 +532,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id)
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46933a0493..b55160b70a 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -20,6 +20,7 @@ from typing import List, Optional
from mock import Mock
import synapse.rest.admin
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
from synapse.rest.client.v1 import directory, events, login, room
@@ -79,7 +80,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/shutdown_room/" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
@@ -103,7 +104,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
# Enable world readable
url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
json.dumps({"history_visibility": "world_readable"}),
@@ -113,7 +114,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/shutdown_room/" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
@@ -126,11 +127,10 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
self._assert_peek(room_id, expect_code=403)
def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
+ """Assert that the admin user can (or cannot) peek into the room."""
url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.assertEqual(
@@ -138,7 +138,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
)
url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.assertEqual(
@@ -184,8 +184,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
- "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+ channel = self.make_request(
+ "POST",
+ self.url,
+ json.dumps({}),
+ access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -197,8 +200,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
- request, channel = self.make_request(
- "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url,
+ json.dumps({}),
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
@@ -210,13 +216,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/rooms/invalidroom/delete"
- request, channel = self.make_request(
- "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "POST",
+ url,
+ json.dumps({}),
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- "invalidroom is not a legal room ID", channel.json_body["error"],
+ "invalidroom is not a legal room ID",
+ channel.json_body["error"],
)
def test_new_room_user_does_not_exist(self):
@@ -225,7 +235,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"new_room_user_id": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -244,7 +254,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"new_room_user_id": "@not:exist.bla"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -253,7 +263,8 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- "User must be our own: @not:exist.bla", channel.json_body["error"],
+ "User must be our own: @not:exist.bla",
+ channel.json_body["error"],
)
def test_block_is_not_bool(self):
@@ -262,7 +273,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"block": "NotBool"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -278,7 +289,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"purge": "NotBool"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -304,7 +315,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": True, "purge": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
@@ -337,7 +348,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": False, "purge": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
@@ -371,7 +382,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": False, "purge": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
@@ -418,7 +429,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
@@ -448,7 +459,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Enable world readable
url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
json.dumps({"history_visibility": "world_readable"}),
@@ -465,7 +476,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
@@ -490,8 +501,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id, expect=True):
- """Assert that the room is blocked or not
- """
+ """Assert that the room is blocked or not"""
d = self.store.is_room_blocked(room_id)
if expect:
self.assertTrue(self.get_success(d))
@@ -499,20 +509,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertIsNone(self.get_success(d))
def _has_no_members(self, room_id):
- """Assert there is now no longer anyone in the room
- """
+        """Assert there is no longer anyone in the room"""
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([], users_in_room)
def _is_member(self, room_id, user_id):
- """Test that user is member of the room
- """
+ """Test that user is member of the room"""
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertIn(user_id, users_in_room)
def _is_purged(self, room_id):
- """Test that the following tables have been purged of all rows related to the room.
- """
+ """Test that the following tables have been purged of all rows related to the room."""
for table in PURGE_TABLES:
count = self.get_success(
self.store.db_pool.simple_select_one_onecol(
@@ -526,11 +533,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
+ """Assert that the admin user can (or cannot) peek into the room."""
url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.assertEqual(
@@ -538,7 +544,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.assertEqual(
@@ -547,8 +553,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API.
- """
+ """Test /purge_room admin API."""
servlets = [
synapse.rest.admin.register_servlets,
@@ -569,7 +574,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
{"room_id": room_id},
@@ -593,8 +598,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
class RoomTestCase(unittest.HomeserverTestCase):
- """Test /room admin API.
- """
+ """Test /room admin API."""
servlets = [
synapse.rest.admin.register_servlets,
@@ -604,8 +608,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -623,8 +625,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
# Check request completed successfully
@@ -686,7 +690,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set the name of the rooms so we get a consistent returned ordering
for idx, room_id in enumerate(room_ids):
self.helper.send_state(
- room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ room_id,
+ "m.room.name",
+ {"name": str(idx)},
+ tok=self.admin_user_tok,
)
# Request the list of rooms
@@ -704,8 +711,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
limit,
"name",
)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"]
@@ -744,8 +753,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_ids, returned_room_ids)
url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -764,7 +775,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Create a new alias to this room
url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
{"room_id": room_id},
@@ -789,13 +800,18 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set a name for the room
self.helper.send_state(
- room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ room_id,
+ "m.room.name",
+ {"name": test_room_name},
+ tok=self.admin_user_tok,
)
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -835,7 +851,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/directory/room/%s" % (
urllib.parse.quote(test_alias),
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
{"room_id": room_id},
@@ -861,7 +877,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
def _order_test(
- order_type: str, expected_room_list: List[str], reverse: bool = False,
+ order_type: str,
+ expected_room_list: List[str],
+ reverse: bool = False,
):
"""Request the list of rooms in a certain order. Assert that order is what
we expect
@@ -875,8 +893,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
if reverse:
url += "&dir=b"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -908,13 +928,22 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
self.helper.send_state(
- room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ room_id_1,
+ "m.room.name",
+ {"name": "A"},
+ tok=self.admin_user_tok,
)
self.helper.send_state(
- room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ room_id_2,
+ "m.room.name",
+ {"name": "B"},
+ tok=self.admin_user_tok,
)
self.helper.send_state(
- room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ room_id_3,
+ "m.room.name",
+ {"name": "C"},
+ tok=self.admin_user_tok,
)
# Set room canonical room aliases
@@ -991,10 +1020,16 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set the name for each room
self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ room_id_1,
+ "m.room.name",
+ {"name": room_name_1},
+ tok=self.admin_user_tok,
)
self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ room_id_2,
+ "m.room.name",
+ {"name": room_name_2},
+ tok=self.admin_user_tok,
)
def _search_test(
@@ -1011,8 +1046,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
expected_http_code: The expected http code for the request
"""
url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
@@ -1050,6 +1087,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(room_id_2, "else")
_search_test(room_id_2, "se")
+ # Test case insensitive
+ _search_test(room_id_1, "SOMETHING")
+ _search_test(room_id_1, "THING")
+
+ _search_test(room_id_2, "ELSE")
+ _search_test(room_id_2, "SE")
+
_search_test(None, "foo")
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
@@ -1065,15 +1109,23 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set the name for each room
self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ room_id_1,
+ "m.room.name",
+ {"name": room_name_1},
+ tok=self.admin_user_tok,
)
self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ room_id_2,
+ "m.room.name",
+ {"name": room_name_2},
+ tok=self.admin_user_tok,
)
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1084,6 +1136,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("joined_local_devices", channel.json_body)
self.assertIn("version", channel.json_body)
self.assertIn("creator", channel.json_body)
self.assertIn("encryption", channel.json_body)
@@ -1096,6 +1149,45 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ def test_single_room_devices(self):
+ """Test that `joined_local_devices` can be requested correctly"""
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
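+        # Only the admin user has joined so far, with a single logged-in device.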
+ self.assertEqual(1, channel.json_body["joined_local_devices"])
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(2, channel.json_body["joined_local_devices"])
+
+ # leave room
+ self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok)
+ self.helper.leave(room_id_1, user_1, tok=user_tok_1)
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["joined_local_devices"])
+
def test_room_members(self):
"""Test that room members can be requested correctly"""
# Create two test rooms
@@ -1119,8 +1211,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id_2, user_3, tok=user_tok_3)
url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1130,8 +1224,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 3)
url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1140,6 +1236,23 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.json_body["total"], 3)
+ def test_room_state(self):
+ """Test that room state can be requested correctly"""
+        # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("state", channel.json_body)
+        # Testing that the state events match is painful and not done here. We assume
+        # that create_room already does the right thing, so there is no need to verify
+        # the state events it created.
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
@@ -1170,7 +1283,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1186,7 +1299,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"unknown_parameter": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1202,7 +1315,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1218,7 +1331,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": "@not:exist.bla"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1238,7 +1351,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/!unknown:test"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1255,7 +1368,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/invalidroom"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1274,7 +1387,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1286,8 +1399,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/joined_rooms",
+ access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
@@ -1303,7 +1418,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1333,8 +1448,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if server admin is a member of the room
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/joined_rooms",
+ access_token=self.admin_user_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1344,7 +1461,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1355,8 +1472,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/joined_rooms",
+ access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1372,7 +1491,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1384,12 +1503,233 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/joined_rooms",
+ access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+ def test_context_as_non_admin(self):
+ """
+ Test that, without being admin, one cannot use the context admin API
+ """
+ # Create a room.
+ user_id = self.register_user("test", "test")
+ user_tok = self.login("test", "test")
+
+ self.register_user("test_2", "test")
+ user_tok_2 = self.login("test_2", "test")
+
+ room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+ # Populate the room with events.
+ events = []
+ for i in range(30):
+ events.append(
+ self.helper.send_event(
+ room_id, "com.example.test", content={"index": i}, tok=user_tok
+ )
+ )
+
+ # Now attempt to find the context using the admin API without being admin.
+ midway = (len(events) - 1) // 2
+ for tok in [user_tok, user_tok_2]:
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/context/%s"
+ % (room_id, events[midway]["event_id"]),
+ access_token=tok,
+ )
+ self.assertEquals(
+ 403, int(channel.result["code"]), msg=channel.result["body"]
+ )
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_context_as_admin(self):
+ """
+ Test that, as admin, we can find the context of an event without having joined the room.
+ """
+
+ # Create a room. We're not part of it.
+ user_id = self.register_user("test", "test")
+ user_tok = self.login("test", "test")
+ room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+ # Populate the room with events.
+ events = []
+ for i in range(30):
+ events.append(
+ self.helper.send_event(
+ room_id, "com.example.test", content={"index": i}, tok=user_tok
+ )
+ )
+
+ # Now let's fetch the context for this room.
+ midway = (len(events) - 1) // 2
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/context/%s"
+ % (room_id, events[midway]["event_id"]),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(
+ channel.json_body["event"]["event_id"], events[midway]["event_id"]
+ )
+
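+ # Each event returned in events_before/events_after must be one of the events
+ # we posted, on the correct side of the anchor event; the for/else branch
+ # fails the test if a returned event was never posted.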
+ for i, found_event in enumerate(channel.json_body["events_before"]):
+ for j, posted_event in enumerate(events):
+ if found_event["event_id"] == posted_event["event_id"]:
+ self.assertTrue(j < midway)
+ break
+ else:
+ self.fail("Event %s from events_before not found" % found_event["event_id"])
+
+ for i, found_event in enumerate(channel.json_body["events_after"]):
+ for j, posted_event in enumerate(events):
+ if found_event["event_id"] == posted_event["event_id"]:
+ self.assertTrue(j > midway)
+ break
+ else:
+ self.fail("Event %s from events_after not found" % found_event["event_id"])
+
+
+class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format(
+ self.public_room_id
+ )
+
+ def test_public_room(self):
+ """Test that getting admin in a public room works."""
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room and ban a user.
+ self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+ self.helper.change_membership(
+ room_id,
+ self.admin_user,
+ "@test:test",
+ Membership.BAN,
+ tok=self.admin_user_tok,
+ )
+
+ def test_private_room(self):
+ """Test that getting admin in a private room works and we get invited."""
+ room_id = self.helper.create_room_as(
+ self.creator,
+ tok=self.creator_tok,
+ is_public=False,
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room (we should have received an
+ # invite) and can ban a user.
+ self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+ self.helper.change_membership(
+ room_id,
+ self.admin_user,
+ "@test:test",
+ Membership.BAN,
+ tok=self.admin_user_tok,
+ )
+
+ def test_other_user(self):
+ """Test that granting admin in a public room to a non-admin user works."""
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={"user_id": self.second_user_id},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room and ban a user.
+ self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
+ self.helper.change_membership(
+ room_id,
+ self.second_user_id,
+ "@test:test",
+ Membership.BAN,
+ tok=self.second_tok,
+ )
+
+ def test_not_enough_power(self):
+ """Test that we get a sensible error if there are no local room admins."""
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ # The creator drops admin rights in the room.
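+ # This leaves no local user with enough power to update power levels, so
+ # the make_room_admin request below is expected to fail.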
+ pl = self.helper.get_state(
+ room_id, EventTypes.PowerLevels, tok=self.creator_tok
+ )
+ pl["users"][self.creator] = 0
+ self.helper.send_state(
+ room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ # We expect this to fail with a 400 as there are no room admins.
+ #
+ # (Note we assert the error message to ensure that it's not denied for
+ # some other reason)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ channel.json_body["error"],
+ "No local admin user in room with power to update power levels.",
+ )
+
PURGE_TABLES = [
"current_state_events",
@@ -1419,7 +1759,6 @@ PURGE_TABLES = [
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
- "local_invites",
"room_account_data",
"room_tags",
# "state_groups", # Current impl leaves orphaned state groups around.
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 907b49f889..1f1d11f527 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -46,7 +45,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
Try to list users without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -55,8 +54,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
- "GET", self.url, json.dumps({}), access_token=self.other_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ json.dumps({}),
+ access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -67,47 +69,57 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
If parameters are invalid, an error is returned.
"""
# unknown order_by
- request, channel = self.make_request(
- "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?order_by=bar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
- request, channel = self.make_request(
- "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
- request, channel = self.make_request(
- "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from_ts
- request, channel = self.make_request(
- "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from_ts=-1234",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative until_ts
- request, channel = self.make_request(
- "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?until_ts=-1234",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# until_ts smaller from_ts
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?from_ts=10&until_ts=5",
access_token=self.admin_user_tok,
@@ -117,16 +129,20 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# empty search term
- request, channel = self.make_request(
- "GET", self.url + "?search_term=", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?search_term=",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
- request, channel = self.make_request(
- "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -138,8 +154,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
self._create_users_with_media(10, 2)
- request, channel = self.make_request(
- "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -154,8 +172,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
self._create_users_with_media(20, 2)
- request, channel = self.make_request(
- "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -170,8 +190,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
self._create_users_with_media(20, 2)
- request, channel = self.make_request(
- "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -190,8 +212,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=20",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -201,8 +225,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=21",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -212,8 +238,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -223,8 +251,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# Set `from` to the value of `next_token` to request the remaining entries
# Check `next_token` does not appear
- request, channel = self.make_request(
- "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -238,8 +268,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
if users have no media created
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -267,10 +299,14 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# order by user_id
self._order_test("user_id", ["@user_a:test", "@user_b:test", "@user_c:test"])
self._order_test(
- "user_id", ["@user_a:test", "@user_b:test", "@user_c:test"], "f",
+ "user_id",
+ ["@user_a:test", "@user_b:test", "@user_c:test"],
+ "f",
)
self._order_test(
- "user_id", ["@user_c:test", "@user_b:test", "@user_a:test"], "b",
+ "user_id",
+ ["@user_c:test", "@user_b:test", "@user_a:test"],
+ "b",
)
# order by displayname
@@ -278,32 +314,46 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"displayname", ["@user_c:test", "@user_b:test", "@user_a:test"]
)
self._order_test(
- "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"], "f",
+ "displayname",
+ ["@user_c:test", "@user_b:test", "@user_a:test"],
+ "f",
)
self._order_test(
- "displayname", ["@user_a:test", "@user_b:test", "@user_c:test"], "b",
+ "displayname",
+ ["@user_a:test", "@user_b:test", "@user_c:test"],
+ "b",
)
# order by media_length
self._order_test(
- "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"],
+ "media_length",
+ ["@user_a:test", "@user_c:test", "@user_b:test"],
)
self._order_test(
- "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
+ "media_length",
+ ["@user_a:test", "@user_c:test", "@user_b:test"],
+ "f",
)
self._order_test(
- "media_length", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
+ "media_length",
+ ["@user_b:test", "@user_c:test", "@user_a:test"],
+ "b",
)
# order by media_count
self._order_test(
- "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"],
+ "media_count",
+ ["@user_a:test", "@user_c:test", "@user_b:test"],
)
self._order_test(
- "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
+ "media_count",
+ ["@user_a:test", "@user_c:test", "@user_b:test"],
+ "f",
)
self._order_test(
- "media_count", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
+ "media_count",
+ ["@user_b:test", "@user_c:test", "@user_a:test"],
+ "b",
)
def test_from_until_ts(self):
@@ -316,16 +366,20 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
ts1 = self.clock.time_msec()
# list all media when filter is not set
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media
# result is 0
- request, channel = self.make_request(
- "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from_ts=%s" % (ts1,),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 0)
@@ -337,7 +391,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self._create_media(self.other_user_tok, 3)
# filter media between `ts1` and `ts2`
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok,
@@ -346,8 +400,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media until `ts2` and earlier
- request, channel = self.make_request(
- "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?until_ts=%s" % (ts2,),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
@@ -356,14 +412,16 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self._create_users_with_media(20, 1)
# check without filter get all users
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
# filter user 1 and 10-19 by `user_id`
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok,
@@ -372,7 +430,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 11)
# filter on this user in `displayname`
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok,
@@ -382,8 +440,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 1)
# filter and get empty result
- request, channel = self.make_request(
- "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?search_term=foobar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 0)
@@ -447,8 +507,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
url = self.url + "?order_by=%s" % (order_type,)
if dir is not None and dir in ("b", "f"):
url += "&dir=%s" % (dir,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 54d46f4bd3..79a05b519b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,16 +18,20 @@ import hmac
import json
import urllib.parse
from binascii import unhexlify
+from typing import List, Optional
from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
from tests import unittest
+from tests.server import FakeSite, make_request
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -70,7 +74,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
self.hs.config.registration_shared_secret = None
- request, channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -87,7 +91,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.hs.get_secrets = Mock(return_value=secrets)
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
@@ -96,14 +100,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
# 59 seconds
self.reactor.advance(59)
body = json.dumps({"nonce": nonce})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("username must be specified", channel.json_body["error"])
@@ -111,7 +115,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# 61 seconds
self.reactor.advance(2)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -120,7 +124,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -136,7 +140,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -146,7 +150,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -165,7 +169,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -174,7 +178,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
A valid unrecognised nonce.
"""
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -190,13 +194,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -209,7 +213,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
def nonce():
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
return channel.json_body["nonce"]
#
@@ -218,7 +222,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -229,28 +233,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({"nonce": nonce()})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
@@ -261,28 +265,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
@@ -300,7 +304,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"user_type": "invalid",
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid user type", channel.json_body["error"])
@@ -311,7 +315,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
# set no displayname
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -321,17 +325,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps(
{"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob1:test", channel.json_body["user_id"])
- request, channel = self.make_request("GET", "/profile/@bob1:test/displayname")
+ channel = self.make_request("GET", "/profile/@bob1:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -347,17 +351,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob2:test", channel.json_body["user_id"])
- request, channel = self.make_request("GET", "/profile/@bob2:test/displayname")
+ channel = self.make_request("GET", "/profile/@bob2:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -373,16 +377,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob3:test", channel.json_body["user_id"])
- request, channel = self.make_request("GET", "/profile/@bob3:test/displayname")
+ channel = self.make_request("GET", "/profile/@bob3:test/displayname")
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
# set displayname
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -398,12 +402,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob4:test", channel.json_body["user_id"])
- request, channel = self.make_request("GET", "/profile/@bob4:test/displayname")
+ channel = self.make_request("GET", "/profile/@bob4:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@@ -429,7 +433,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
# Register new user with admin API
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -448,7 +452,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -466,23 +470,34 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.register_user("user1", "pass1", admin=False)
- self.register_user("user2", "pass2", admin=False)
-
def test_no_auth(self):
"""
Try to list users without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self._create_users(1)
+ other_user_token = self.login("user1", "pass1")
+
+ channel = self.make_request("GET", self.url, access_token=other_user_token)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self):
"""
List all users, including deactivated users.
"""
- request, channel = self.make_request(
+ self._create_users(2)
+
+ channel = self.make_request(
"GET",
self.url + "?deactivated=true",
b"{}",
@@ -493,13 +508,292 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
+ # Check that all fields are available
+ self._check_fields(channel.json_body["users"])
+
+ def test_search_term(self):
+ """Test that searching for users works correctly"""
+
+ def _search_test(
+ expected_user_id: Optional[str],
+ search_term: str,
+ search_field: Optional[str] = "name",
+ expected_http_code: Optional[int] = 200,
+ ):
+ """Search for a user and check that the returned user's id is a match
+
+ Args:
+ expected_user_id: The user_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for user names with
+ search_field: Field to search in: `name` or `user_id`
+ expected_http_code: The expected http code for the request
+ """
+ url = self.url + "?%s=%s" % (
+ search_field,
+ search_term,
+ )
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
-class UserRestTestCase(unittest.HomeserverTestCase):
+ if expected_http_code != 200:
+ return
+
+ # Check that users were returned
+ self.assertTrue("users" in channel.json_body)
+ self._check_fields(channel.json_body["users"])
+ users = channel.json_body["users"]
+
+ # Check that the expected number of users were returned
+ expected_user_count = 1 if expected_user_id else 0
+ self.assertEqual(len(users), expected_user_count)
+ self.assertEqual(channel.json_body["total"], expected_user_count)
+
+ if expected_user_id:
+ # Check that the first returned user id is correct
+ u = users[0]
+ self.assertEqual(expected_user_id, u["name"])
+
+ self._create_users(2)
+
+ user1 = "@user1:test"
+ user2 = "@user2:test"
+
+ # Perform search tests
+ _search_test(user1, "er1")
+ _search_test(user1, "me 1")
+
+ _search_test(user2, "er2")
+ _search_test(user2, "me 2")
+
+ _search_test(user1, "er1", "user_id")
+ _search_test(user2, "er2", "user_id")
+
+ # Test case insensitive
+ _search_test(user1, "ER1")
+ _search_test(user1, "NAME 1")
+
+ _search_test(user2, "ER2")
+ _search_test(user2, "NAME 2")
+
+ _search_test(user1, "ER1", "user_id")
+ _search_test(user2, "ER2", "user_id")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+
+ _search_test(None, "foo", "user_id")
+ _search_test(None, "bar", "user_id")
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+
+ # negative limit
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative from
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid guests
+ channel = self.make_request(
+ "GET",
+ self.url + "?guests=not_bool",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid deactivated
+ channel = self.make_request(
+ "GET",
+ self.url + "?deactivated=not_bool",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ def test_limit(self):
+ """
+ Testing list of users with limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 5)
+ self.assertEqual(channel.json_body["next_token"], "5")
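+ # `next_token` is returned as a string and is the offset to pass as `from`
+ # on the next request.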
+ self._check_fields(channel.json_body["users"])
+
+ def test_from(self):
+ """
+ Testing list of users with a defined starting point (from)
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["users"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of users with a defined starting point and limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(channel.json_body["next_token"], "15")
+ self.assertEqual(len(channel.json_body["users"]), 10)
+ self._check_fields(channel.json_body["users"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=20",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=21",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=19",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 19)
+ self.assertEqual(channel.json_body["next_token"], "19")
+
+ # Check
+ # Set `from` to the value of `next_token` to request the remaining entries
+ # `next_token` does not appear
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=19",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _check_fields(self, content: JsonDict):
+ """Checks that the expected user attributes are present in content
+ Args:
+ content: List of users to check
+ """
+ for u in content:
+ self.assertIn("name", u)
+ self.assertIn("is_guest", u)
+ self.assertIn("admin", u)
+ self.assertIn("user_type", u)
+ self.assertIn("deactivated", u)
+ self.assertIn("shadow_banned", u)
+ self.assertIn("displayname", u)
+ self.assertIn("avatar_url", u)
+
+ def _create_users(self, number_users: int):
+ """
+ Create a number of users
+ Args:
+ number_users: Number of users to be created
+ """
+ for i in range(1, number_users + 1):
+ self.register_user(
+ "user%d" % i,
+ "pass%d" % i,
+ admin=False,
+ displayname="Name %d" % i,
+ )
+
+
+class DeactivateAccountTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
- sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
@@ -508,11 +802,232 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.other_user = self.register_user("user", "pass")
+ self.other_user = self.register_user("user", "pass", displayname="User1")
self.other_user_token = self.login("user", "pass")
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user
)
+ self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote(
+ self.other_user
+ )
+
+ # set attributes for user
+ self.get_success(
+ self.store.set_profile_avatar_url("user", "mxc://servername/mediaid", 1)
+ )
+ self.get_success(
+ self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+ )
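+ # These attributes let the deactivation tests below check that profile data
+ # and threepids are cleared when the account is deactivated.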
+
+ def test_no_auth(self):
+ """
+ Try to deactivate users without authentication.
+ """
+ channel = self.make_request("POST", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ url = "/_synapse/admin/v1/deactivate/@bob:test"
+
+ channel = self.make_request("POST", url, access_token=self.other_user_token)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.other_user_token,
+ content=b"{}",
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that deactivation for a user that does not exist returns a 404
+ """
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/deactivate/@unknown_person:test",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_erase_is_not_bool(self):
+ """
+ If the parameter `erase` is not a boolean, an error is returned
+ """
+ body = json.dumps({"erase": "False"})
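+ # "False" is a string here, not a boolean, so the request should be rejected.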
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that deactivating a user that is not local returns a 400
+ """
+ url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
+
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only deactivate local users", channel.json_body["error"])
+
+ def test_deactivate_user_erase_true(self):
+ """
+ Test deactivating a user with `erase` set to `true`
+ """
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User1", channel.json_body["displayname"])
+
+ # Deactivate user
+ body = json.dumps({"erase": True})
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
+
+ self._is_erased("@user:test", True)
+
+ def test_deactivate_user_erase_false(self):
+ """
+ Test deactivating a user with `erase` set to `false`
+ """
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User1", channel.json_body["displayname"])
+
+ # Deactivate user
+ body = json.dumps({"erase": False})
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+
+ # On DINUM's deployment we clear the profile information during a deactivation regardless,
+ # whereas on mainline we decided to only do this if the deactivation was performed with erase: True.
+ # The discrepancy is due to profile replication.
+ # See synapse.storage.databases.main.profile.ProfileWorkerStore.set_profiles_active
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
+
+ self._is_erased("@user:test", False)
+
+ def _is_erased(self, user_id: str, expect: bool) -> None:
+ """Assert that the user is erased or not"""
+ d = self.store.is_user_erased(user_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertFalse(self.get_success(d))
+
+
+class UserRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.auth_handler = hs.get_auth_handler()
+
+ # create users and get access tokens
+ # regardless of whether password login or SSO is allowed
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ self.admin_user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ self.other_user = self.register_user("user", "pass", displayname="User")
+ self.other_user_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ self.other_user, device_id=None, valid_until_ms=None
+ )
+ )
+ self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
+ self.other_user
+ )
def test_requester_is_no_admin(self):
"""
@@ -520,15 +1035,20 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@bob:test"
- request, channel = self.make_request(
- "GET", url, access_token=self.other_user_token,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("You are not a server admin", channel.json_body["error"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.other_user_token, content=b"{}",
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.other_user_token,
+ content=b"{}",
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -539,7 +1059,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v2/users/@unknown_person:test",
access_token=self.admin_user_tok,
@@ -561,11 +1081,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- "avatar_url": None,
+ "avatar_url": "mxc://fibble/wibble",
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -577,11 +1097,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -589,9 +1112,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(True, channel.json_body["admin"])
- self.assertEqual(False, channel.json_body["is_guest"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["admin"])
+ self.assertFalse(channel.json_body["is_guest"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self):
"""
@@ -606,10 +1130,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": False,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "avatar_url": "mxc://fibble/wibble",
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -621,11 +1146,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -633,9 +1161,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(False, channel.json_body["admin"])
- self.assertEqual(False, channel.json_body["is_guest"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["admin"])
+ self.assertFalse(channel.json_body["is_guest"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["shadow_banned"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -651,9 +1181,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Sync to set admin user to active
# before limit of monthly active users is reached
- request, channel = self.make_request(
- "GET", "/sync", access_token=self.admin_user_tok
- )
+ channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
if channel.code != 200:
raise HttpResponseException(
@@ -676,7 +1204,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Create user
body = json.dumps({"password": "abc123", "admin": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -685,7 +1213,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -715,7 +1243,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Create user
body = json.dumps({"password": "abc123", "admin": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -725,7 +1253,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Admin user is not blocked by mau anymore
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
@override_config(
{
@@ -752,7 +1280,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -769,7 +1297,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual("@bob:test", pushers[0]["user_name"])
+ self.assertEqual("@bob:test", pushers[0].user_name)
@override_config(
{
@@ -796,7 +1324,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -822,7 +1350,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Change password
body = json.dumps({"password": "hahaha"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -839,7 +1367,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Modify user
body = json.dumps({"displayname": "foobar"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -851,8 +1379,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("foobar", channel.json_body["displayname"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -869,7 +1399,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
{"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -882,8 +1412,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -896,29 +1428,113 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Test deactivating another user.
"""
- # Deactivate user
- body = json.dumps({"deactivated": True})
+ # set attributes for user
+ self.get_success(
+ self.store.set_profile_avatar_url("user", "mxc://servername/mediaid", 1)
+ )
+ self.get_success(
+ self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+ )
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
- request, channel = self.make_request(
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+
+ # On DINUM's deployment we clear the profile information during a deactivation regardless,
+ # whereas on mainline we decided to only do this if the deactivation was performed with erase: True.
+ # The discrepancy is due to profile replication.
+ # See synapse.storage.databases.main.profile.ProfileWorkerStore.set_profiles_active
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User", channel.json_body["displayname"])
+
+ # Deactivate user
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+
+ # On DINUM's deployment we clear the profile information during a deactivation regardless,
+ # whereas on mainline we decided to only do this if the deactivation was performed with erase: True.
+ # The discrepancy is due to profile replication.
+ # See synapse.storage.databases.main.profile.ProfileWorkerStore.set_profiles_active
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
# the user is deactivated, the threepid will be deleted
# Get user
- request, channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
+
+ @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+ def test_change_name_deactivate_user_user_directory(self):
+ """
+        Test changing the profile information of a deactivated user and
+        check that the user does not appear in the user directory
+ """
+
+ # is in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertTrue(profile["display_name"] == "User")
+
+ # Deactivate user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": True},
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertTrue(channel.json_body["deactivated"])
+
+ # is not in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertIsNone(profile)
+
+        # Set a new displayname for the user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"displayname": "Foobar"},
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertEqual("Foobar", channel.json_body["displayname"])
+
+ # is not in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertIsNone(profile)
def test_reactivate_user(self):
"""
@@ -926,46 +1542,92 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Deactivate the user.
- request, channel = self.make_request(
+ self._deactivate_user("@user:test")
+
+ # Attempt to reactivate the user (without a password).
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False},
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Reactivate the user.
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+ content={"deactivated": False, "password": "foo"},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
- d = self.store.mark_user_erased("@user:test")
- self.assertIsNone(self.get_success(d))
- self._is_erased("@user:test", True)
- # Attempt to reactivate the user (without a password).
- request, channel = self.make_request(
+ @override_config({"password_config": {"localdb_enabled": False}})
+ def test_reactivate_user_localdb_disabled(self):
+ """
+ Test reactivating another user when using SSO.
+ """
+
+ # Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Reactivate the user with a password
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+ content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- # Reactivate the user.
- request, channel = self.make_request(
+ # Reactivate the user without a password.
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": False, "password": "foo"}).encode(
- encoding="utf_8"
- ),
+ content={"deactivated": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self._is_erased("@user:test", False)
- # Get user
- request, channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ @override_config({"password_config": {"enabled": False}})
+ def test_reactivate_user_password_disabled(self):
+ """
+ Test reactivating another user when using SSO.
+ """
+
+ # Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Reactivate the user with a password
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False, "password": "foo"},
)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+ # Reactivate the user without a password.
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False},
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
@@ -974,27 +1636,27 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Set a user as an admin
- body = json.dumps({"admin": True})
-
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"admin": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
def test_accidental_deactivation_prevention(self):
"""
@@ -1004,13 +1666,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps({"password": "abc123"})
-
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123"},
)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -1018,8 +1678,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("bob", channel.json_body["displayname"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1028,20 +1690,20 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["deactivated"])
# Change password (and use a str for deactivate instead of a bool)
- body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
-
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123", "deactivated": "false"},
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
# Check user is not deactivated
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1051,15 +1713,32 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
- def _is_erased(self, user_id, expect):
- """Assert that the user is erased or not
- """
+ def _is_erased(self, user_id: str, expect: bool) -> None:
+ """Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
if expect:
self.assertTrue(self.get_success(d))
else:
self.assertFalse(self.get_success(d))
+ def _deactivate_user(self, user_id: str) -> None:
+ """Deactivate user and set as erased"""
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+ access_token=self.admin_user_tok,
+ content={"deactivated": True},
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self._is_erased(user_id, False)
+ d = self.store.mark_user_erased(user_id)
+ self.assertIsNone(self.get_success(d))
+ self._is_erased(user_id, True)
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
@@ -1070,8 +1749,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -1084,7 +1761,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
Try to list rooms of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1095,8 +1772,10 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -1104,28 +1783,34 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns an empty list
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["joined_rooms"]))
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a 400
+        Tests that a lookup for a user that is not local and participates in no conversation returns an empty list
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["joined_rooms"]))
def test_no_memberships(self):
"""
@@ -1133,8 +1818,10 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
if user has no memberships
"""
# Get rooms
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1152,14 +1839,63 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.helper.create_room_as(self.other_user, tok=other_user_tok)
# Get rooms
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
+ def test_get_rooms_with_nonlocal_user(self):
+ """
+ Tests that a normal lookup for rooms is successful with a non-local user
+ """
+
+ other_user_tok = self.login("user", "pass")
+ event_builder_factory = self.hs.get_event_builder_factory()
+ event_creation_handler = self.hs.get_event_creation_handler()
+ storage = self.hs.get_storage()
+
+ # Create two rooms, one with a local user only and one with both a local
+ # and remote user.
+ self.helper.create_room_as(self.other_user, tok=other_user_tok)
+ local_and_remote_room_id = self.helper.create_room_as(
+ self.other_user, tok=other_user_tok
+ )
+
+ # Add a remote user to the room.
+ builder = event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": "m.room.member",
+ "sender": "@joiner:remote_hs",
+ "state_key": "@joiner:remote_hs",
+ "room_id": local_and_remote_room_id,
+ "content": {"membership": "join"},
+ },
+ )
+
+ event, context = self.get_success(
+ event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(storage.persistence.persist_event(event, context))
+
+ # Now get rooms
+ url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
+
class PushersRestTestCase(unittest.HomeserverTestCase):
@@ -1183,7 +1919,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
Try to list pushers of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1194,8 +1930,10 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -1206,8 +1944,10 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
@@ -1219,8 +1959,10 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -1232,8 +1974,10 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
# Get pushers
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1256,13 +2000,15 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "https://example.com/_matrix/push/v1/notify"},
)
)
# Get pushers
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1302,7 +2048,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
Try to list media of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1313,8 +2059,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -1325,8 +2073,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
@@ -1338,8 +2088,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -1352,10 +2104,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 20
other_user_tok = self.login("user", "pass")
- self._create_media(other_user_tok, number_media)
+ self._create_media_for_user(other_user_tok, number_media)
- request, channel = self.make_request(
- "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1371,10 +2125,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 20
other_user_tok = self.login("user", "pass")
- self._create_media(other_user_tok, number_media)
+ self._create_media_for_user(other_user_tok, number_media)
- request, channel = self.make_request(
- "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1390,10 +2146,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 20
other_user_tok = self.login("user", "pass")
- self._create_media(other_user_tok, number_media)
+ self._create_media_for_user(other_user_tok, number_media)
- request, channel = self.make_request(
- "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1402,25 +2160,45 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["media"]), 10)
self._check_fields(channel.json_body["media"])
- def test_limit_is_negative(self):
+ def test_invalid_parameter(self):
"""
- Testing that a negative limit parameter returns a 400
+ If parameters are invalid, an error is returned.
"""
+        # unknown order_by
+ channel = self.make_request(
+ "GET",
+ self.url + "?order_by=bar",
+ access_token=self.admin_user_tok,
+ )
- request, channel = self.make_request(
- "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid search order
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
- def test_from_is_negative(self):
- """
- Testing that a negative from parameter returns a 400
- """
+ # negative limit
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- request, channel = self.make_request(
- "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ # negative from
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1433,12 +2211,14 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 20
other_user_tok = self.login("user", "pass")
- self._create_media(other_user_tok, number_media)
+ self._create_media_for_user(other_user_tok, number_media)
# `next_token` does not appear
# Number of results is the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=20",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1448,8 +2228,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=21",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1459,8 +2241,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
- request, channel = self.make_request(
- "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1471,8 +2255,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Check
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
- request, channel = self.make_request(
- "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1486,8 +2272,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
if user has no media created
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1501,10 +2289,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 5
other_user_tok = self.login("user", "pass")
- self._create_media(other_user_tok, number_media)
+ self._create_media_for_user(other_user_tok, number_media)
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1513,11 +2303,118 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["media"])
- def _create_media(self, user_token, number_media):
+ def test_order_by(self):
+ """
+        Testing the ordering of the media list with the parameter `order_by`
+ """
+
+ other_user_tok = self.login("user", "pass")
+
+ # Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
+ image_data1 = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+ # Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
+ image_data2 = unhexlify(
+ b"47494638376101000100800100000000"
+ b"ffffff2c00000000010001000002024c"
+ b"01003b"
+ )
+ # Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B
+ image_data3 = unhexlify(
+ b"424d3a0000000000000036000000280000000100000001000000"
+ b"0100180000000000040000000000000000000000000000000000"
+ b"0000"
+ )
+
+ # create media and make sure they do not have the same timestamp
+ media1 = self._create_media_and_access(other_user_tok, image_data1, "image.png")
+ self.pump(1.0)
+ media2 = self._create_media_and_access(other_user_tok, image_data2, "image.gif")
+ self.pump(1.0)
+ media3 = self._create_media_and_access(other_user_tok, image_data3, "image.bmp")
+ self.pump(1.0)
+
+ # Mark one media as safe from quarantine.
+ self.get_success(self.store.mark_local_media_as_safe(media2))
+ # Quarantine one media
+ self.get_success(
+ self.store.quarantine_media_by_id("test", media3, self.admin_user)
+ )
+
+ # order by default ("created_ts")
+ # default is backwards
+ self._order_test([media3, media2, media1], None)
+ self._order_test([media1, media2, media3], None, "f")
+ self._order_test([media3, media2, media1], None, "b")
+
+ # sort by media_id
+ sorted_media = sorted([media1, media2, media3], reverse=False)
+ sorted_media_reverse = sorted(sorted_media, reverse=True)
+
+ # order by media_id
+ self._order_test(sorted_media, "media_id")
+ self._order_test(sorted_media, "media_id", "f")
+ self._order_test(sorted_media_reverse, "media_id", "b")
+
+ # order by upload_name
+ self._order_test([media3, media2, media1], "upload_name")
+ self._order_test([media3, media2, media1], "upload_name", "f")
+ self._order_test([media1, media2, media3], "upload_name", "b")
+
+ # order by media_type
+ # result is ordered by media_id
+        # because the uploaded media_type is always 'application/json'
+ self._order_test(sorted_media, "media_type")
+ self._order_test(sorted_media, "media_type", "f")
+ self._order_test(sorted_media, "media_type", "b")
+
+ # order by media_length
+ self._order_test([media2, media3, media1], "media_length")
+ self._order_test([media2, media3, media1], "media_length", "f")
+ self._order_test([media1, media3, media2], "media_length", "b")
+
+ # order by created_ts
+ self._order_test([media1, media2, media3], "created_ts")
+ self._order_test([media1, media2, media3], "created_ts", "f")
+ self._order_test([media3, media2, media1], "created_ts", "b")
+
+ # order by last_access_ts
+ self._order_test([media1, media2, media3], "last_access_ts")
+ self._order_test([media1, media2, media3], "last_access_ts", "f")
+ self._order_test([media3, media2, media1], "last_access_ts", "b")
+
+ # order by quarantined_by
+ # one media is in quarantine, others are ordered by media_ids
+
+        # Different sort order between SQLite and PostgreSQL:
+        # If a media is not in quarantine, `quarantined_by` is NULL.
+        # SQLite considers NULL to be smaller than any other value.
+        # PostgreSQL considers NULL to be larger than any other value.
+
+ # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by")
+ # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by", "f")
+ # self._order_test([media3] + sorted([media1, media2]), "quarantined_by", "b")
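+        # (Illustrative only, not an option of the admin API: a backend-agnostic
+        # ordering could be expressed in SQL as
+        # "ORDER BY quarantined_by IS NULL, quarantined_by",
+        # which sorts NULLs last on both SQLite and PostgreSQL.)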
+
+ # order by safe_from_quarantine
+ # one media is safe from quarantine, others are ordered by media_ids
+ self._order_test(sorted([media1, media3]) + [media2], "safe_from_quarantine")
+ self._order_test(
+ sorted([media1, media3]) + [media2], "safe_from_quarantine", "f"
+ )
+ self._order_test(
+ [media2] + sorted([media1, media3]), "safe_from_quarantine", "b"
+ )
+
+ def _create_media_for_user(self, user_token: str, number_media: int):
"""
Create a number of media for a specific user
+ Args:
+ user_token: Access token of the user
+ number_media: Number of media to be created for the user
"""
- upload_resource = self.media_repo.children[b"upload"]
for i in range(number_media):
# file size is 67 Byte
image_data = unhexlify(
@@ -1526,13 +2423,59 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
b"0a2db40000000049454e44ae426082"
)
- # Upload some media into the room
- self.helper.upload_media(
- upload_resource, image_data, tok=user_token, expect_code=200
- )
+ self._create_media_and_access(user_token, image_data)
+
+ def _create_media_and_access(
+ self,
+ user_token: str,
+ image_data: bytes,
+ filename: str = "image1.png",
+ ) -> str:
+ """
+        Create one media item for a specific user, access it and return its `media_id`
+ Args:
+ user_token: Access token of the user
+ image_data: binary data of image
+ filename: The filename of the media to be uploaded
+ Returns:
+ The ID of the newly created media.
+ """
+ upload_resource = self.media_repo.children[b"upload"]
+ download_resource = self.media_repo.children[b"download"]
+
+ # Upload some media into the room
+ response = self.helper.upload_media(
+ upload_resource, image_data, user_token, filename, expect_code=200
+ )
+
+ # Extract media ID from the response
+ server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
+ media_id = server_and_media_id.split("/")[1]
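+        # For example (values are illustrative only): a content_uri of
+        # "mxc://test/abcde12345" yields server_and_media_id "test/abcde12345"
+        # and media_id "abcde12345".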
- def _check_fields(self, content):
- """Checks that all attributes are present in content
+        # Access the media once so that `last_access_ts` gets set
+ channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=user_token,
+ )
+
+ self.assertEqual(
+ 200,
+ channel.code,
+ msg=(
+ "Expected to receive a 200 on accessing media: %s" % server_and_media_id
+ ),
+ )
+
+ return media_id
+
+ def _check_fields(self, content: JsonDict):
+        """Checks that the expected media attributes are present in content
+        Args:
+            content: List of media entries to check
"""
for m in content:
self.assertIn("media_id", m)
@@ -1544,10 +2487,41 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertIn("quarantined_by", m)
self.assertIn("safe_from_quarantine", m)
+ def _order_test(
+ self,
+ expected_media_list: List[str],
+ order_by: Optional[str],
+ dir: Optional[str] = None,
+ ):
+ """Request the list of media in a certain order. Assert that order is what
+ we expect
+ Args:
+ expected_media_list: The list of media_ids in the order we expect to get
+ back from the server
+ order_by: The type of ordering to give the server
+ dir: The direction of ordering to give the server
+ """
+
+ url = self.url + "?"
+ if order_by is not None:
+ url += "order_by=%s&" % (order_by,)
+ if dir is not None and dir in ("b", "f"):
+ url += "dir=%s" % (dir,)
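+        # For example (illustrative values), order_by="media_id" with dir="b"
+        # requests "<self.url>?order_by=media_id&dir=b".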
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], len(expected_media_list))
+
+ returned_order = [row["media_id"] for row in channel.json_body["media"]]
+ self.assertEqual(expected_media_list, returned_order)
+ self._check_fields(channel.json_body["media"])
+
class UserTokenRestTestCase(unittest.HomeserverTestCase):
- """Test for /_synapse/admin/v1/users/<user>/login
- """
+ """Test for /_synapse/admin/v1/users/<user>/login"""
servlets = [
synapse.rest.admin.register_servlets,
@@ -1571,32 +2545,29 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
)
def _get_token(self) -> str:
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["access_token"]
def test_no_auth(self):
- """Try to login as a user without authentication.
- """
- request, channel = self.make_request("POST", self.url, b"{}")
+ """Try to login as a user without authentication."""
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self):
- """Try to login as a user as a non-admin user.
- """
- request, channel = self.make_request(
+ """Try to login as a user as a non-admin user."""
+ channel = self.make_request(
"POST", self.url, b"{}", access_token=self.other_user_tok
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
def test_send_event(self):
- """Test that sending event as a user works.
- """
+ """Test that sending event as a user works."""
# Create a room.
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
@@ -1610,13 +2581,12 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(event.sender, self.other_user)
def test_devices(self):
- """Tests that logging in as a user doesn't create a new device for them.
- """
+ """Tests that logging in as a user doesn't create a new device for them."""
# Login in as the user
self._get_token()
# Check that we don't see a new device in our devices list
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1625,31 +2595,24 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["devices"]), 1)
def test_logout(self):
- """Test that calling `/logout` with the token works.
- """
+ """Test that calling `/logout` with the token works."""
# Login in as the user
puppet_token = self._get_token()
# Test that we can successfully make a request
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout with the puppet token
- request, channel = self.make_request(
- "POST", "logout", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should no longer work
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens should still work
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1662,25 +2625,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
puppet_token = self._get_token()
# Test that we can successfully make a request
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout all with the real user token
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should still work
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens shouldn't
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
@@ -1693,25 +2652,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
puppet_token = self._get_token()
# Test that we can successfully make a request
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout all with the admin user token
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should no longer work
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens should still work
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1730,8 +2685,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
}
)
def test_consent(self):
- """Test that sending a message is not subject to the privacy policies.
- """
+ """Test that sending a message is not subject to the privacy policies."""
# Have the admin user accept the terms.
self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
@@ -1778,8 +2732,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -1793,11 +2745,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
Try to get information of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url1, b"{}")
+ channel = self.make_request("GET", self.url1, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- request, channel = self.make_request("GET", self.url2, b"{}")
+ channel = self.make_request("GET", self.url2, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1808,14 +2760,18 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.register_user("user2", "pass")
other_user2_token = self.login("user2", "pass")
- request, channel = self.make_request(
- "GET", self.url1, access_token=other_user2_token,
+ channel = self.make_request(
+ "GET",
+ self.url1,
+ access_token=other_user2_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- request, channel = self.make_request(
- "GET", self.url2, access_token=other_user2_token,
+ channel = self.make_request(
+ "GET",
+ self.url2,
+ access_token=other_user2_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1827,14 +2783,18 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
- request, channel = self.make_request(
- "GET", url1, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url1,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
- request, channel = self.make_request(
- "GET", url2, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ url2,
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
@@ -1843,15 +2803,19 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
The lookup should succeed for an admin.
"""
- request, channel = self.make_request(
- "GET", self.url1, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url1,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
- request, channel = self.make_request(
- "GET", self.url2, access_token=self.admin_user_tok,
+ channel = self.make_request(
+ "GET",
+ self.url2,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -1863,16 +2827,84 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url1, access_token=other_user_token,
+ channel = self.make_request(
+ "GET",
+ self.url1,
+ access_token=other_user_token,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
- request, channel = self.make_request(
- "GET", self.url2, access_token=other_user_token,
+ channel = self.make_request(
+ "GET",
+ self.url2,
+ access_token=other_user_token,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
+
+
+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+        Try to shadow-ban a user without authentication.
+ """
+ channel = self.make_request("POST", self.url)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request("POST", self.url, access_token=other_user_token)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+        Tests that shadow-banning a user that is not local returns a 400
+ """
+ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+
+ def test_success(self):
+ """
+ Shadow-banning should succeed for an admin.
+ """
+ # The user starts off as not shadow-banned.
+ other_user_token = self.login("user", "pass")
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertFalse(result.shadow_banned)
+
+ channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ # Ensure the user is shadow-banned (and the cache was cleared).
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertTrue(result.shadow_banned)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index e2e6a5e16d..c74693e9b2 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -61,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
- request, channel = make_request(
+ channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
)
self.assertEqual(channel.code, 200)
@@ -82,7 +82,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user"
)
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
@@ -97,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"POST",
@@ -109,7 +109,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Fetch the consent page, to get the consent version -- it should have
# changed
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index a1ccc4ee9a..56937dcd2e 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -93,7 +93,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
def get_event(self, room_id, event_id, expected_code=200):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
- request, channel = self.make_request("GET", url)
+ channel = self.make_request("GET", url)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index aa094c3543..61bdae0879 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-
from mock import Mock
from twisted.internet import defer
@@ -50,22 +48,17 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase):
self.tok = self.login("kermit", "monkey")
def test_3pid_invite_disabled(self):
- request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=self.tok
- )
+ channel = self.make_request(b"POST", "/createRoom", {}, access_token=self.tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
- params = {
+ data = {
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
}
- request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
- request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=self.tok
- )
+ channel = self.make_request(b"POST", request_url, data, access_token=self.tok)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_3pid_lookup_disabled(self):
@@ -73,7 +66,7 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account/3pid/lookup"
"?id_server=testis&medium=email&address=foo@bar.baz"
)
- request, channel = self.make_request("GET", url, access_token=self.tok)
+ channel = self.make_request("GET", url, access_token=self.tok)
self.assertEqual(channel.result["code"], b"403", channel.result)
def test_3pid_bulk_lookup_disabled(self):
@@ -82,10 +75,7 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase):
"id_server": "testis",
"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
}
- request_data = json.dumps(data)
- request, channel = self.make_request(
- "POST", url, request_data, access_token=self.tok
- )
+ channel = self.make_request("POST", url, data, access_token=self.tok)
self.assertEqual(channel.result["code"], b"403", channel.result)
@@ -125,22 +115,19 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase):
self.tok = self.login("kermit", "monkey")
def test_3pid_invite_enabled(self):
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
- params = {
+ data = {
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
}
- request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
- request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=self.tok
- )
+ channel = self.make_request(b"POST", request_url, data, access_token=self.tok)
get_json = self.hs.get_identity_handler().http_client.get_json
get_json.assert_called_once_with(
@@ -168,8 +155,7 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase):
"id_server": "testis",
"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
}
- request_data = json.dumps(data)
- self.make_request("POST", url, request_data, access_token=self.tok)
+ self.make_request("POST", url, data, access_token=self.tok)
post_json = self.hs.get_simple_http_client().post_json_get_json
post_json.assert_called_once_with(
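The test_identity.py changes above follow a pattern repeated throughout this patch: the manual `json.dumps(...)` step is dropped and a plain dict is passed straight to `make_request`, which now accepts dict bodies and serialises them itself. A minimal sketch of that serialisation behaviour, assuming a hypothetical `encode_request_body` helper (not Synapse's actual implementation):

```python
import json
from typing import Union


def encode_request_body(content: Union[bytes, str, dict, None]) -> bytes:
    """Illustrative helper: dict bodies are JSON-encoded, str bodies are
    UTF-8 encoded, bytes pass through unchanged, None becomes empty."""
    if content is None:
        return b""
    if isinstance(content, dict):
        return json.dumps(content).encode("utf-8")
    if isinstance(content, str):
        return content.encode("utf-8")
    return content


# With such behaviour in the test helper, call sites no longer need json.dumps:
#   channel = self.make_request("POST", request_url, data, access_token=self.tok)
assert encode_request_body({"medium": "email"}) == b'{"medium": "email"}'
```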
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 913ea3c98e..5256c11fe6 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -73,7 +73,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
# Mod the mod
room_power_levels = self.helper.get_state(
- self.room_id, "m.room.power_levels", tok=self.admin_access_token,
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
)
# Update existing power levels with mod at PL50
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index c1f516cc93..e0c74591b6 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -69,16 +69,12 @@ class RedactionsTestCase(HomeserverTestCase):
"""
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
- request, channel = self.make_request(
- "POST", path, content={}, access_token=access_token
- )
+ channel = self.make_request("POST", path, content={}, access_token=access_token)
self.assertEqual(int(channel.result["code"]), expect_code)
return channel.json_body
def _sync_room_timeline(self, access_token, room_id):
- request, channel = self.make_request(
- "GET", "sync", access_token=self.mod_access_token
- )
+ channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
self.assertEqual(channel.result["code"], b"200")
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
@@ -185,8 +181,7 @@ class RedactionsTestCase(HomeserverTestCase):
)
def test_redact_event_as_moderator_ratelimit(self):
- """Tests that the correct ratelimiting is applied to redactions
- """
+ """Tests that the correct ratelimiting is applied to redactions"""
message_ids = []
# as a regular user, send messages to redact
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 62b41ada42..b8285f3240 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -252,7 +252,8 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["backfill"])
self.hs = self.setup_test_homeserver(
- config=config, federation_client=mock_federation_client,
+ config=config,
+ federation_client=mock_federation_client,
)
return self.hs
@@ -327,7 +328,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def get_event(self, room_id, event_id, expected_code=200):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
- request, channel = self.make_request("GET", url, access_token=self.token)
+ channel = self.make_request("GET", url, access_token=self.token)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
index 89016b82ce..6344fcdf6a 100644
--- a/tests/rest/client/test_room_access_rules.py
+++ b/tests/rest/client/test_room_access_rules.py
@@ -86,7 +86,9 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["send_invite"])
mock_federation_client.send_invite.side_effect = send_invite
- mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
+ mock_http_client = Mock(
+ spec=["get_json", "post_json_get_json"],
+ )
# Mocking the response for /info on the IS API.
mock_http_client.get_json.side_effect = get_json
# Mocking the response for /store-invite on the IS API.
@@ -154,8 +156,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
self.create_room(rule=AccessRules.DIRECT, expected_code=400)
def test_create_room_direct_invalid_rule(self):
- """Tests that creating a direct room with an invalid rule will fail.
- """
+ """Tests that creating a direct room with an invalid rule will fail."""
self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400)
def test_create_room_default_power_level_rules(self):
@@ -234,7 +235,9 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
"""
# Creates a room with the default power levels
room_id = self.create_room(
- direct=True, rule=AccessRules.DIRECT, expected_code=200,
+ direct=True,
+ rule=AccessRules.DIRECT,
+ expected_code=200,
)
# Attempt to drop invite and state_default power levels after the fact
@@ -273,7 +276,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
)
# List preset_room_id in the public room list
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,),
{"visibility": "public"},
@@ -282,7 +285,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# List init_state_room_id in the public room list
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,),
{"visibility": "public"},
@@ -361,14 +364,14 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
data = {"visibility": "public"}
- request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
self.assertEqual(channel.code, 200, channel.result)
# We are allowed to remove the room from the public room list
url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
data = {"visibility": "private"}
- request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
self.assertEqual(channel.code, 200, channel.result)
def test_direct(self):
@@ -470,7 +473,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0]
data = {"visibility": "public"}
- request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
self.assertEqual(channel.code, 403, channel.result)
def test_unrestricted(self):
@@ -549,7 +552,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room
data = {"visibility": "public"}
- request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
self.assertEqual(channel.code, 403, channel.result)
def test_change_rules(self):
@@ -607,7 +610,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id
data = {"visibility": "public"}
- request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ channel = self.make_request("PUT", url, data, access_token=self.tok)
self.assertEqual(channel.code, 200, channel.result)
# Attempt to switch the room to "unrestricted"
@@ -854,22 +857,30 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
"""
def freeze_room_with_id_and_power_levels(
- room_id: str, custom_power_levels_content: Optional[JsonDict] = None,
+ room_id: str,
+ custom_power_levels_content: Optional[JsonDict] = None,
):
# Invite a user to the room, they join with PL 0
self.helper.invite(
- room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ room=room_id,
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
)
# Invitee joins the room
self.helper.join(
- room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ room=room_id,
+ user=self.invitee_id,
+ tok=self.invitee_tok,
)
if not custom_power_levels_content:
# Retrieve the room's current power levels event content
power_levels = self.helper.get_state(
- room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ room_id=room_id,
+ event_type="m.room.power_levels",
+ tok=self.tok,
)
else:
power_levels = custom_power_levels_content
@@ -884,12 +895,16 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
# Ensure that the invitee leaving the room does not change the power levels
self.helper.leave(
- room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ room=room_id,
+ user=self.invitee_id,
+ tok=self.invitee_tok,
)
# Retrieve the new power levels of the room
new_power_levels = self.helper.get_state(
- room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ room_id=room_id,
+ event_type="m.room.power_levels",
+ tok=self.tok,
)
# Ensure they have not changed
@@ -897,22 +912,31 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
# Invite the user back again
self.helper.invite(
- room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ room=room_id,
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
)
# Invitee joins the room
self.helper.join(
- room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ room=room_id,
+ user=self.invitee_id,
+ tok=self.invitee_tok,
)
# Now the admin leaves the room
self.helper.leave(
- room=room_id, user=self.user_id, tok=self.tok,
+ room=room_id,
+ user=self.user_id,
+ tok=self.tok,
)
# Check the power levels again
new_power_levels = self.helper.get_state(
- room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok,
+ room_id=room_id,
+ event_type="m.room.power_levels",
+ tok=self.invitee_tok,
)
# Ensure that the new power levels prevent anyone but admins from sending
@@ -1001,8 +1025,11 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
if power_levels_content_override:
content["power_levels_content_override"] = power_levels_content_override
- request, channel = self.make_request(
- "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok,
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ content,
+ access_token=self.tok,
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -1011,7 +1038,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
return channel.json_body["room_id"]
def current_rule_in_room(self, room_id):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
access_token=self.tok,
@@ -1022,7 +1049,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
def change_rule_in_room(self, room_id, new_rule, expected_code=200):
data = {"rule": new_rule}
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
json.dumps(data),
@@ -1033,7 +1060,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
data = {"join_rule": new_join_rule}
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
json.dumps(data),
@@ -1045,7 +1072,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
def send_threepid_invite(self, address, room_id, expected_code=200):
params = {"id_server": "testis", "medium": "email", "address": address}
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/invite" % room_id,
json.dumps(params),
@@ -1062,9 +1089,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
state_key,
)
- request, channel = self.make_request(
- "PUT", path, json.dumps(body), access_token=tok
- )
+ channel = self.make_request("PUT", path, json.dumps(body), access_token=tok)
self.assertEqual(channel.code, expect_code, channel.result)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 94dcfb9f7c..d2cce44032 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -18,6 +18,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.types import UserID
from tests import unittest
@@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
self.store = self.hs.get_datastore()
self.get_success(
- self.store.db_pool.simple_update(
- table="users",
- keyvalues={"name": self.banned_user_id},
- updatevalues={"shadow_banned": True},
- desc="shadow_ban",
- )
+ self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
)
self.other_user_id = self.register_user("otheruser", "pass")
@@ -89,7 +85,7 @@ class RoomTestCase(_ShadowBannedBase):
)
# Inviting the user completes successfully.
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
@@ -103,7 +99,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
@@ -158,7 +154,7 @@ class RoomTestCase(_ShadowBannedBase):
self.banned_user_id, tok=self.banned_access_token
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
{"new_version": "6"},
@@ -183,7 +179,7 @@ class RoomTestCase(_ShadowBannedBase):
self.banned_user_id, tok=self.banned_access_token
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
{"typing": True, "timeout": 30000},
@@ -198,7 +194,7 @@ class RoomTestCase(_ShadowBannedBase):
# The other user can join and send typing events.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (room_id, self.other_user_id),
{"typing": True, "timeout": 30000},
@@ -244,7 +240,7 @@ class ProfileTestCase(_ShadowBannedBase):
)
# The update should succeed.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
{"displayname": new_display_name},
@@ -254,7 +250,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.assertEqual(channel.json_body, {})
# The user's display name should be updated.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/profile/%s/displayname" % (self.banned_user_id,)
)
self.assertEqual(channel.code, 200, channel.result)
@@ -264,7 +260,10 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ self.banned_user_id,
+ room_id,
+ "m.room.member",
+ self.banned_user_id,
)
)
self.assertEqual(
@@ -282,7 +281,7 @@ class ProfileTestCase(_ShadowBannedBase):
)
# The update should succeed.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (room_id, self.banned_user_id),
@@ -296,7 +295,10 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ self.banned_user_id,
+ room_id,
+ "m.room.member",
+ self.banned_user_id,
)
)
self.assertEqual(
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0e96697f9b..bf39014277 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -86,7 +86,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
callback = Mock(spec=[], side_effect=check)
current_rules_module().check_event_allowed = callback
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
{},
@@ -104,7 +104,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.assertEqual(ev.type, k[0])
self.assertEqual(ev.state_key, k[1])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
{},
@@ -123,7 +123,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
current_rules_module().check_event_allowed = check
# now send the event
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
{"x": "x"},
@@ -142,7 +142,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
current_rules_module().check_event_allowed = check
# now send the event
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
{"x": "x"},
@@ -152,7 +152,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
event_id = channel.json_body["event_id"]
# ... and check that it got modified
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
@@ -161,6 +161,68 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
+ def test_message_edit(self):
+ """Ensure that the module doesn't cause issues with edited messages."""
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ d = ev.get_dict()
+ d["content"] = {
+ "msgtype": "m.text",
+ "body": d["content"]["body"].upper(),
+ }
+ return d
+
+ current_rules_module().check_event_allowed = check
+
+ # Send an event, then edit it.
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {
+ "msgtype": "m.text",
+ "body": "Original body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ orig_event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/m.room.message/2" % self.room_id,
+ {
+ "m.new_content": {"msgtype": "m.text", "body": "Edited body"},
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": orig_event_id,
+ },
+ "msgtype": "m.text",
+ "body": "Edited body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ edited_event_id = channel.json_body["event_id"]
+
+ # ... and check that they both got modified
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["body"], "EDITED BODY")
+
def test_send_event(self):
"""Tests that the module can send an event into a room via the module api"""
content = {
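The new test_message_edit above exercises a `check_event_allowed` callback that returns a modified event dict, then verifies that both the original message and its edit come back with the upper-cased body. A standalone sketch of that transformation, using plain dicts rather than Synapse's `EventBase` (the function name here is illustrative):

```python
import asyncio
from typing import Any, Dict


async def uppercase_body(event: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the event whose content is replaced by an
    upper-cased m.text body, mirroring what the test's callback does."""
    modified = dict(event)
    modified["content"] = {
        "msgtype": "m.text",
        "body": event["content"]["body"].upper(),
    }
    return modified


edited = asyncio.run(
    uppercase_body({"type": "m.room.message", "content": {"body": "Edited body"}})
)
assert edited["content"]["body"] == "EDITED BODY"
```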
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 7a2c653df8..edd1d184f8 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -91,7 +91,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# that we can make sure that the check is done on the whole alias.
data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, 400, channel.result)
@@ -104,7 +104,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# as cautious as possible here.
data = {"room_alias_name": random_string(5)}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, 200, channel.result)
@@ -118,7 +118,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
data = {"aliases": [self.random_alias(alias_length)]}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -128,7 +128,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 12a93f5687..2ae896db1e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -63,13 +63,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
# implementation is now part of the r0 implementation, the newer
# behaviour is used instead to be consistent with the r0 spec.
# see issue #2602
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s" % ("invalid" + self.token,)
)
self.assertEquals(channel.code, 401, msg=channel.result)
# valid token, expect content
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEquals(channel.code, 200, msg=channel.result)
@@ -87,7 +87,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
)
# valid token, expect content
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEquals(channel.code, 200, msg=channel.result)
@@ -149,7 +149,9 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(self.room_id, tok=self.token)
event_id = resp["event_id"]
- request, channel = self.make_request(
- "GET", "/events/" + event_id, access_token=self.token,
+ channel = self.make_request(
+ "GET",
+ "/events/" + event_id,
+ access_token=self.token,
)
self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index ee75f4076f..988821b16f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,23 +1,89 @@
-import json
+# -*- coding: utf-8 -*-
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import time
import urllib.parse
+from typing import Any, Dict, List, Optional, Union
+from urllib.parse import urlencode
from mock import Mock
-import jwt
+import pymacaroons
+
+from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import create_requester
from tests import unittest
-from tests.unittest import override_config
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.handlers.test_saml import has_saml2
+from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.test_utils.html_parsers import TestHtmlParser
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+
+try:
+ import jwt
+
+ HAS_JWT = True
+except ImportError:
+ HAS_JWT = False
+
+
+# synapse server name: used to populate public_baseurl in some tests
+SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
+
+# public_baseurl for some tests. It uses an http:// scheme because
+# FakeChannel.isSecure() returns False, so Synapse sees the requested uri as
+# http://...; using an http public_baseurl therefore stops Synapse from trying to
+# redirect to https://....
+BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+ <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+ <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+ </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+ "SAML_SERVER": SAML_SERVER,
+}
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"
+# a (valid) url with some annoying characters in it: %3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -63,7 +129,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -82,7 +148,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -108,7 +174,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -127,7 +193,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -153,7 +219,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -172,7 +238,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"403", channel.result)
@@ -181,7 +247,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.register_user("kermit", "monkey")
# we shouldn't be able to make requests without an access token
- request, channel = self.make_request(b"GET", TEST_URL)
+ channel = self.make_request(b"GET", TEST_URL)
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
@@ -191,25 +257,21 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -223,9 +285,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -233,16 +293,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# ... but if we delete that device, it will be a proper logout
self._delete_device(access_token_2, "kermit", "monkey", device_id)
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], False)
def _delete_device(self, access_token, user_id, password, device_id):
"""Perform the UI-Auth to delete a device"""
- request, channel = self.make_request(
+ channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
self.assertEquals(channel.code, 401, channel.result)
@@ -262,7 +320,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"session": channel.json_body["session"],
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"DELETE",
"devices/" + device_id,
access_token=access_token,
@@ -278,26 +336,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token = self.login("kermit", "monkey")
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# Now try to hard logout this session
- request, channel = self.make_request(
- b"POST", "/logout", access_token=access_token
- )
+ channel = self.make_request(b"POST", "/logout", access_token=access_token)
self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"session_lifetime": "24h"})
@@ -308,29 +360,350 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token = self.login("kermit", "monkey")
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions
- request, channel = self.make_request(
- b"POST", "/logout/all", access_token=access_token
- )
+ channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
self.assertEquals(channel.result["code"], b"200", channel.result)
+@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+class MultiSSOTestCase(unittest.HomeserverTestCase):
+ """Tests for homeservers with multiple SSO providers enabled"""
+
+ servlets = [
+ login.register_servlets,
+ ]
+
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+
+ config["public_baseurl"] = BASE_URL
+
+ config["cas_config"] = {
+ "enabled": True,
+ "server_url": CAS_SERVER,
+ "service_url": "https://matrix.goodserver.com:8448",
+ }
+
+ config["saml2_config"] = {
+ "sp_config": {
+ "metadata": {"inline": [TEST_SAML_METADATA]},
+ # use the XMLSecurity backend to avoid relying on xmlsec1
+ "crypto_backend": "XMLSecurity",
+ },
+ }
+
+ # default OIDC provider
+ config["oidc_config"] = TEST_OIDC_CONFIG
+
+ # additional OIDC providers
+ config["oidc_providers"] = [
+ {
+ "idp_id": "idp1",
+ "idp_name": "IDP1",
+ "discover": False,
+ "issuer": "https://issuer1",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": "https://issuer1/auth",
+ "token_endpoint": "https://issuer1/token",
+ "userinfo_endpoint": "https://issuer1/userinfo",
+ "user_mapping_provider": {
+ "config": {"localpart_template": "{{ user.sub }}"}
+ },
+ }
+ ]
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d.update(build_synapse_client_resource_tree(self.hs))
+ return d
+
+ def test_get_login_flows(self):
+ """GET /login should return password and SSO flows"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ expected_flow_types = [
+ "m.login.cas",
+ "m.login.sso",
+ "m.login.token",
+ "m.login.password",
+ ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
+
+ self.assertCountEqual(
+ [f["type"] for f in channel.json_body["flows"]], expected_flow_types
+ )
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_get_msc2858_login_flows(self):
+ """The SSO flow should include IdP info if MSC2858 is enabled"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # stick the flows results in a dict by type
+ flow_results = {} # type: Dict[str, Any]
+ for f in channel.json_body["flows"]:
+ flow_type = f["type"]
+ self.assertNotIn(
+ flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+ )
+ flow_results[flow_type] = f
+
+ self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+ sso_flow = flow_results.pop("m.login.sso")
+ # we should have a set of IdPs
+ self.assertCountEqual(
+ sso_flow["org.matrix.msc2858.identity_providers"],
+ [
+ {"id": "cas", "name": "CAS"},
+ {"id": "saml", "name": "SAML"},
+ {"id": "oidc-idp1", "name": "IDP1"},
+ {"id": "oidc", "name": "OIDC"},
+ ],
+ )
+
+ # the rest of the flows are simple
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(flow_results.values(), expected_flows)
+
+ def test_multi_sso_redirect(self):
+ """/login/sso/redirect should redirect to an identity picker"""
+ # first hit the redirect url, which should redirect to our idp picker
+ channel = self._make_sso_redirect_request(False, None)
+ self.assertEqual(channel.code, 302, channel.result)
+ uri = channel.headers.getRawHeaders("Location")[0]
+
+ # hitting that picker should give us some HTML
+ channel = self.make_request("GET", uri)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # parse the form to check it has fields assumed elsewhere in this class
+ html = channel.result["body"].decode("utf-8")
+ p = TestHtmlParser()
+ p.feed(html)
+ p.close()
+
+ # there should be a link for each href
+ returned_idps = [] # type: List[str]
+ for link in p.links:
+ path, query = link.split("?", 1)
+ self.assertEqual(path, "pick_idp")
+ params = urllib.parse.parse_qs(query)
+ self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+ returned_idps.append(params["idp"][0])
+
+ self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
+
+ def test_multi_sso_redirect_to_cas(self):
+ """If CAS is chosen, should redirect to the CAS server"""
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=cas",
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 302, channel.result)
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ cas_uri = location_headers[0]
+ cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
+
+ # it should redirect us to the login page of the cas server
+ self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
+
+ # check that the redirectUrl is correctly encoded in the service param - ie, the
+ # place that CAS will redirect to
+ cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
+ service_uri = cas_uri_params["service"][0]
+ _, service_uri_query = service_uri.split("?", 1)
+ service_uri_params = urllib.parse.parse_qs(service_uri_query)
+ self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
+
+ def test_multi_sso_redirect_to_saml(self):
+ """If SAML is chosen, should redirect to the SAML server"""
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=saml",
+ )
+ self.assertEqual(channel.code, 302, channel.result)
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ saml_uri = location_headers[0]
+ saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
+
+ # it should redirect us to the login page of the SAML server
+ self.assertEqual(saml_uri_path, SAML_SERVER)
+
+ # the RelayState is used to carry the client redirect url
+ saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
+ relay_state_param = saml_uri_params["RelayState"][0]
+ self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
+
+ def test_login_via_oidc(self):
+ """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
+
+ # pick the default OIDC provider
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=oidc",
+ )
+ self.assertEqual(channel.code, 302, channel.result)
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ oidc_uri = location_headers[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+ # ... and should have set a cookie including the redirect url
+ cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+ assert cookie_headers
+ cookies = {} # type: Dict[str, str]
+ for h in cookie_headers:
+ key, value = h.split(";")[0].split("=", maxsplit=1)
+ cookies[key] = value
+
+ oidc_session_cookie = cookies["oidc_session"]
+ macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
+ self.assertEqual(
+ self._get_value_from_macaroon(macaroon, "client_redirect_url"),
+ TEST_CLIENT_REDIRECT_URL,
+ )
+
+ channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+
+ # that should serve a confirmation page
+ self.assertEqual(channel.code, 200, channel.result)
+ content_type_headers = channel.headers.getRawHeaders("Content-Type")
+ assert content_type_headers
+ self.assertTrue(content_type_headers[-1].startswith("text/html"))
+ p = TestHtmlParser()
+ p.feed(channel.text_body)
+ p.close()
+
+ # ... which should contain our redirect link
+ self.assertEqual(len(p.links), 1)
+ path, query = p.links[0].split("?", 1)
+ self.assertEqual(path, "https://x")
+
+ # it will have url-encoded the params properly, so we'll have to parse them
+ params = urllib.parse.parse_qsl(
+ query, keep_blank_values=True, strict_parsing=True, errors="strict"
+ )
+ self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+ self.assertEqual(params[2][0], "loginToken")
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ login_token = params[2][1]
+ chan = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], "@user1:test")
+
+ def test_multi_sso_redirect_to_unknown(self):
+ """An unknown IdP should cause a 400"""
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def test_client_idp_redirect_to_unknown(self):
+ """If the client tries to pick an unknown IdP, return a 404"""
+ channel = self._make_sso_redirect_request(False, "xxx")
+ self.assertEqual(channel.code, 404, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+ def test_client_idp_redirect_to_oidc(self):
+ """If the client picks a known IdP, redirect to it"""
+ channel = self._make_sso_redirect_request(False, "oidc")
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_msc2858_redirect_to_oidc(self):
+ """Test the unstable API"""
+ channel = self._make_sso_redirect_request(True, "oidc")
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+ def test_client_idp_redirect_msc2858_disabled(self):
+ """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
+ channel = self._make_sso_redirect_request(True, "oidc")
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+ def _make_sso_redirect_request(
+ self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
+ ):
+ """Send a request to /_matrix/client/r0/login/sso/redirect
+
+ ... or the unstable equivalent
+
+ ... possibly specifying an IdP
+ """
+ endpoint = (
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect"
+ if unstable_endpoint
+ else "/_matrix/client/r0/login/sso/redirect"
+ )
+ if idp_prov is not None:
+ endpoint += "/" + idp_prov
+ endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+
+ return self.make_request(
+ "GET",
+ endpoint,
+ custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
+ )
+
+ @staticmethod
+ def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+ prefix = key + " = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(prefix):
+ return caveat.caveat_id[len(prefix) :]
+ raise ValueError("No %s caveat in macaroon" % (key,))
+
+
class CASTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -342,10 +715,12 @@ class CASTestCase(unittest.HomeserverTestCase):
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
config = self.default_config()
+ config["public_baseurl"] = (
+ config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
+ )
config["cas_config"] = {
"enabled": True,
- "server_url": "https://fake.test",
- "service_url": "https://matrix.goodserver.com:8448",
+ "server_url": CAS_SERVER,
}
cas_user_id = "username"
@@ -379,7 +754,8 @@ class CASTestCase(unittest.HomeserverTestCase):
mocked_http_client.get_raw.side_effect = get_raw
self.hs = self.setup_test_homeserver(
- config=config, proxied_http_client=mocked_http_client,
+ config=config,
+ proxied_http_client=mocked_http_client,
)
return self.hs
@@ -402,10 +778,10 @@ class CASTestCase(unittest.HomeserverTestCase):
cas_ticket_url = urllib.parse.urlunparse(url_parts)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
+ channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
@@ -430,8 +806,7 @@ class CASTestCase(unittest.HomeserverTestCase):
}
)
def test_cas_redirect_whitelisted(self):
- """Tests that the SSO login flow serves a redirect to a whitelisted url
- """
+ """Tests that the SSO login flow serves a redirect to a whitelisted url"""
self._test_redirect("https://legit-site.com/")
@override_config({"public_baseurl": "https://example.com"})
@@ -446,10 +821,11 @@ class CASTestCase(unittest.HomeserverTestCase):
)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
+ channel = self.make_request("GET", cas_ticket_url)
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
@@ -462,7 +838,9 @@ class CASTestCase(unittest.HomeserverTestCase):
# Deactivate the account.
self.get_success(
- self.deactivate_account_handler.deactivate_account(self.user_id, False)
+ self.deactivate_account_handler.deactivate_account(
+ self.user_id, False, create_requester(self.user_id)
+ )
)
# Request the CAS ticket.
@@ -472,13 +850,14 @@ class CASTestCase(unittest.HomeserverTestCase):
)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
+ channel = self.make_request("GET", cas_ticket_url)
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
self.assertIn(b"SSO account deactivated", channel.result["body"])
+@skip_unless(HAS_JWT, "requires jwt")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -495,18 +874,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
- def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
+ def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(token, secret, self.jwt_algorithm)
+ result = jwt.encode(
+ payload, secret, self.jwt_algorithm
+ ) # type: Union[str, bytes]
if isinstance(result, bytes):
return result.decode("ascii")
return result
def jwt_login(self, *args):
- params = json.dumps(
- {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
- )
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ channel = self.make_request(b"POST", LOGIN_URL, params)
return channel
def test_login_jwt_valid_registered(self):
@@ -637,8 +1016,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
)
def test_login_no_token(self):
- params = json.dumps({"type": "org.matrix.login.jwt"})
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ params = {"type": "org.matrix.login.jwt"}
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -647,6 +1026,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
+@skip_unless(HAS_JWT, "requires jwt")
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
@@ -704,18 +1084,16 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = "RS256"
return self.hs
- def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
+ def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(token, secret, "RS256")
+ result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes, str]
if isinstance(result, bytes):
return result.decode("ascii")
return result
def jwt_login(self, *args):
- params = json.dumps(
- {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
- )
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ channel = self.make_request(b"POST", LOGIN_URL, params)
return channel
def test_login_jwt_valid(self):
@@ -743,7 +1121,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
]
def register_as_user(self, username):
- request, channel = self.make_request(
+ self.make_request(
b"POST",
"/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
{"username": username},
@@ -784,60 +1162,56 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs
def test_login_appservice_user(self):
- """Test that an appservice user can use /login
- """
+ """Test that an appservice user can use /login"""
self.register_as_user(AS_USER)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_user_bot(self):
- """Test that the appservice bot can use /login
- """
+ """Test that the appservice bot can use /login"""
self.register_as_user(AS_USER)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": self.service.sender},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_wrong_user(self):
- """Test that non-as users cannot login with the as token
- """
+ """Test that non-as users cannot login with the as token"""
self.register_as_user(AS_USER)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": "fibble_wibble"},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_wrong_as(self):
- """Test that as users cannot login with wrong as token
- """
+ """Test that as users cannot login with wrong as token"""
self.register_as_user(AS_USER)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
@@ -845,7 +1219,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_no_token(self):
"""Test that users must provide a token when using the appservice
- login method
+ login method
"""
self.register_as_user(AS_USER)
@@ -853,6 +1227,124 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"401", channel.result)
+
+
+@skip_unless(HAS_OIDC, "requires OIDC")
+class UsernamePickerTestCase(HomeserverTestCase):
+ """Tests for the username picker flow of SSO login"""
+
+ servlets = [login.register_servlets]
+
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+
+ config["oidc_config"] = {}
+ config["oidc_config"].update(TEST_OIDC_CONFIG)
+ config["oidc_config"]["user_mapping_provider"] = {
+ "config": {"display_name_template": "{{ user.displayname }}"}
+ }
+
+ # whitelist this client URI so we redirect straight to it rather than
+ # serving a confirmation page
+ config["sso"] = {"client_whitelist": ["https://x"]}
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d.update(build_synapse_client_resource_tree(self.hs))
+ return d
+
+ def test_username_picker(self):
+ """Test the happy path of a username picker flow."""
+
+ # do the start of the login flow
+ channel = self.helper.auth_via_oidc(
+ {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+ )
+
+ # that should redirect to the username picker
+ self.assertEqual(channel.code, 302, channel.result)
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ picker_url = location_headers[0]
+ self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
+
+ # ... with a username_mapping_session cookie
+ cookies = {}  # type: Dict[str, str]
+ channel.extract_cookies(cookies)
+ self.assertIn("username_mapping_session", cookies)
+ session_id = cookies["username_mapping_session"]
+
+ # introspect the sso handler a bit to check that the username mapping session
+ # looks ok.
+ username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+ self.assertIn(
+ session_id,
+ username_mapping_sessions,
+ "session id not found in map",
+ )
+ session = username_mapping_sessions[session_id]
+ self.assertEqual(session.remote_user_id, "tester")
+ self.assertEqual(session.display_name, "Jonny")
+ self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
+
+ # the expiry time should be about 15 minutes away
+ expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+ self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+ # Now, submit a username to the username picker, which should serve a redirect
+ # to the completion page
+ content = urlencode({b"username": b"bobby"}).encode("utf8")
+ chan = self.make_request(
+ "POST",
+ path=picker_url,
+ content=content,
+ content_is_form=True,
+ custom_headers=[
+ ("Cookie", "username_mapping_session=" + session_id),
+ # old versions of Twisted don't do form-parsing without a valid
+ # Content-Length header.
+ ("Content-Length", str(len(content))),
+ ],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+ assert location_headers
+
+ # send a request to the completion page, which should 302 to the client redirectUrl
+ chan = self.make_request(
+ "GET",
+ path=location_headers[0],
+ custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+ assert location_headers
+
+ # ensure that the returned location matches the requested redirect URL
+ path, query = location_headers[0].split("?", 1)
+ self.assertEqual(path, "https://x")
+
+ # it will have url-encoded the params properly, so we'll have to parse them
+ params = urllib.parse.parse_qsl(
+ query, keep_blank_values=True, strict_parsing=True, errors="strict"
+ )
+ self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+ self.assertEqual(params[2][0], "loginToken")
+
+ # fish the login token out of the returned redirect uri
+ login_token = params[2][1]
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ chan = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 11cd8efe21..94a5154834 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -53,7 +53,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.hs.config.use_presence = True
body = {"presence": "here", "status_msg": "beep boop"}
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
@@ -68,7 +68,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.hs.config.use_presence = False
body = {"presence": "here", "status_msg": "beep boop"}
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 2a3b483eaf..f3448c94dd 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,163 +14,11 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
-import json
-
-from mock import Mock
-
-from twisted.internet import defer
-
-import synapse.types
-from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
-from ....utils import MockHttpResource, setup_test_homeserver
-
-myid = "@1234ABCD:test"
-PATH_PREFIX = "/_matrix/client/r0"
-
-
-class MockHandlerProfileTestCase(unittest.TestCase):
- """ Tests rest layer of profile management.
-
- Todo: move these into ProfileTestCase
- """
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.mock_handler = Mock(
- spec=[
- "get_displayname",
- "set_displayname",
- "get_avatar_url",
- "set_avatar_url",
- "check_profile_query_allowed",
- ]
- )
-
- self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
- self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
- self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
- self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
- self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
- Mock()
- )
-
- hs = yield setup_test_homeserver(
- self.addCleanup,
- "test",
- federation_http_client=None,
- resource_for_client=self.mock_resource,
- federation=Mock(),
- federation_client=Mock(),
- profile_handler=self.mock_handler,
- )
-
- async def _get_user_by_req(request=None, allow_guest=False):
- return synapse.types.create_requester(myid)
-
- hs.get_auth().get_user_by_req = _get_user_by_req
-
- profile.register_servlets(hs, self.mock_resource)
-
- @defer.inlineCallbacks
- def test_get_my_name(self):
- mocked_get = self.mock_handler.get_displayname
- mocked_get.return_value = defer.succeed("Frank")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/displayname" % (myid), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"displayname": "Frank"}, response)
- self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
- @defer.inlineCallbacks
- def test_set_my_name(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.return_value = defer.succeed(())
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
- )
-
- self.assertEquals(200, code)
- self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
-
- @defer.inlineCallbacks
- def test_set_my_name_noauth(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.side_effect = AuthError(400, "message")
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/displayname" % ("@4567:test"),
- b'{"displayname": "Frank Jr."}',
- )
-
- self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
-
- @defer.inlineCallbacks
- def test_get_other_name(self):
- mocked_get = self.mock_handler.get_displayname
- mocked_get.return_value = defer.succeed("Bob")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"displayname": "Bob"}, response)
-
- @defer.inlineCallbacks
- def test_set_other_name(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.side_effect = SynapseError(400, "message")
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/displayname" % ("@opaque:elsewhere"),
- b'{"displayname":"bob"}',
- )
-
- self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
-
- @defer.inlineCallbacks
- def test_get_my_avatar(self):
- mocked_get = self.mock_handler.get_avatar_url
- mocked_get.return_value = defer.succeed("http://my.server/me.png")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/avatar_url" % (myid), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"avatar_url": "http://my.server/me.png"}, response)
- self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
- @defer.inlineCallbacks
- def test_set_my_avatar(self):
- mocked_set = self.mock_handler.set_avatar_url
- mocked_set.return_value = defer.succeed(())
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/avatar_url" % (myid),
- b'{"avatar_url": "http://my.server/pic.gif"}',
- )
-
- self.assertEquals(200, code)
- self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
-
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -187,39 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
+ self.other = self.register_user("other", "pass", displayname="Bob")
+
+ def test_get_displayname(self):
+ res = self._get_displayname()
+ self.assertEqual(res, "owner")
def test_set_displayname(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
- content=json.dumps({"displayname": "test"}),
+ content={"displayname": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
- res = self.get_displayname()
+ res = self._get_displayname()
self.assertEqual(res, "test")
+ def test_set_displayname_noauth(self):
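+ """Attempts to set a displayname without authentication should get a 401"""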
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner,),
+ content={"displayname": "test"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+
def test_set_displayname_too_long(self):
"""Attempts to set a stupid displayname should get a 400"""
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
- content=json.dumps({"displayname": "test" * 100}),
+ content={"displayname": "test" * 100},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
- res = self.get_displayname()
+ res = self._get_displayname()
self.assertEqual(res, "owner")
- def get_displayname(self):
- request, channel = self.make_request(
- "GET", "/profile/%s/displayname" % (self.owner,)
+ def test_get_displayname_other(self):
+ res = self._get_displayname(self.other)
+ self.assertEqual(res, "Bob")
+
+ def test_set_displayname_other(self):
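+ """Setting another user's displayname should be rejected with a 400"""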
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.other,),
+ content={"displayname": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def test_get_avatar_url(self):
+ res = self._get_avatar_url()
+ self.assertIsNone(res)
+
+ def test_set_avatar_url(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res = self._get_avatar_url()
+ self.assertEqual(res, "http://my.server/pic.gif")
+
+ def test_set_avatar_url_noauth(self):
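+ """Attempts to set an avatar_url without authentication should get a 401"""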
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_set_avatar_url_too_long(self):
+ """Attempts to set a stupid avatar_url should get a 400"""
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif" * 100},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ res = self._get_avatar_url()
+ self.assertIsNone(res)
+
+ def test_get_avatar_url_other(self):
+ res = self._get_avatar_url(self.other)
+ self.assertIsNone(res)
+
+ def test_set_avatar_url_other(self):
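+ """Setting another user's avatar_url should be rejected with a 400"""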
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.other,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def _get_displayname(self, name=None):
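+ """Fetch the displayname of the given user (defaulting to the test owner)."""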
+ channel = self.make_request(
+ "GET", "/profile/%s/displayname" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
+ def _get_avatar_url(self, name=None):
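+ """Fetch the avatar_url of the given user, or None if unset (defaulting to the test owner)."""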
+ channel = self.make_request(
+ "GET", "/profile/%s/avatar_url" % (name or self.owner,)
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body.get("avatar_url")
+
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
@@ -278,7 +209,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
)
def request_profile(self, expected_code, url_suffix="", access_token=None):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.profile_url + url_suffix, access_token=access_token
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -320,19 +251,19 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
"""Tests that a user can lookup their own profile without having to be in a room
if 'require_auth_for_profile_requests' is set to true in the server's config.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/profile/" + self.requester, access_token=self.requester_tok
)
self.assertEqual(channel.code, 200, channel.result)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/profile/" + self.requester + "/displayname",
access_token=self.requester_tok,
)
self.assertEqual(channel.code, 200, channel.result)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/profile/" + self.requester + "/avatar_url",
access_token=self.requester_tok,
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
index 7add5523c8..2bc512d75e 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -45,13 +45,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -74,13 +74,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# disable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": False},
@@ -89,26 +89,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# check rule disabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], False)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
self.assertEqual(channel.code, 200)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -130,13 +130,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# disable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": False},
@@ -145,14 +145,14 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# check rule disabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], False)
# re-enable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": True},
@@ -161,7 +161,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# check rule enabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -182,32 +182,32 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
self.assertEqual(channel.code, 200)
# check 404 for deleted rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
@@ -221,7 +221,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
@@ -235,7 +235,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": True},
@@ -252,7 +252,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/.m.muahahah/enabled",
{"enabled": True},
@@ -276,13 +276,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# GET actions for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/actions", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -305,13 +305,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# change the rule actions
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/actions",
{"actions": ["dont_notify"]},
@@ -320,7 +320,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# GET actions for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/actions", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -341,26 +341,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
self.assertEqual(channel.code, 200)
# check 404 for deleted rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
@@ -374,7 +374,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
)
self.assertEqual(channel.code, 404)
@@ -388,7 +388,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/actions",
{"actions": ["dont_notify"]},
@@ -405,7 +405,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/.m.muahahah/actions",
{"actions": ["dont_notify"]},
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e67de41c18..ed65f645fc 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -26,9 +26,10 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
+from synapse.rest import admin
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -45,7 +46,9 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red", federation_http_client=None, federation_client=Mock(),
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
)
self.hs.get_federation_handler = Mock()
@@ -83,13 +86,13 @@ class RoomPermissionsTestCase(RoomBase):
self.created_rmid_msg_path = (
"rooms/%s/send/m.room.message/a1" % (self.created_rmid)
).encode("ascii")
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
self.assertEquals(200, channel.code, channel.result)
# set topic for public room
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
@@ -111,7 +114,7 @@ class RoomPermissionsTestCase(RoomBase):
)
# send message in uncreated room, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
@@ -119,24 +122,24 @@ class RoomPermissionsTestCase(RoomBase):
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self):
@@ -144,30 +147,30 @@ class RoomPermissionsTestCase(RoomBase):
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
# set/get topic in uncreated room, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
- request, channel = self.make_request("PUT", topic_path, topic_content)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", topic_path)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- request, channel = self.make_request("PUT", topic_path, topic_content)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
- request, channel = self.make_request("GET", topic_path)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
@@ -175,29 +178,29 @@ class RoomPermissionsTestCase(RoomBase):
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
- request, channel = self.make_request("PUT", topic_path, topic_content)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
- request, channel = self.make_request("GET", topic_path)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", topic_path, topic_content)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", topic_path)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
@@ -207,7 +210,7 @@ class RoomPermissionsTestCase(RoomBase):
def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
- request, channel = self.make_request("GET", path)
+ channel = self.make_request("GET", path)
self.assertEquals(expect_code, channel.code)
def test_membership_basic_room_perms(self):
@@ -379,16 +382,16 @@ class RoomsMemberListTestCase(RoomBase):
def test_get_member_list(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+ channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEquals(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self):
- request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
+ channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self):
room_id = self.helper.create_room_as("@some_other_guy:red")
- request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+ channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self):
@@ -397,17 +400,17 @@ class RoomsMemberListTestCase(RoomBase):
room_path = "/rooms/%s/members" % room_id
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
- request, channel = self.make_request("GET", room_path)
+ channel = self.make_request("GET", room_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
- request, channel = self.make_request("GET", room_path)
+ channel = self.make_request("GET", room_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
- request, channel = self.make_request("GET", room_path)
+ channel = self.make_request("GET", room_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -418,30 +421,26 @@ class RoomsCreateTestCase(RoomBase):
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id
- request, channel = self.make_request("POST", "/createRoom", "{}")
+ channel = self.make_request("POST", "/createRoom", "{}")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_visibility_key(self):
# POST with visibility config key, expect new room id
- request, channel = self.make_request(
- "POST", "/createRoom", b'{"visibility":"private"}'
- )
+ channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self):
# POST with custom config keys, expect new room id
- request, channel = self.make_request(
- "POST", "/createRoom", b'{"custom":"stuff"}'
- )
+ channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self):
# POST with custom + known config keys, expect new room id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
self.assertEquals(200, channel.code)
@@ -449,16 +448,16 @@ class RoomsCreateTestCase(RoomBase):
def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400
- request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
+ channel = self.make_request("POST", "/createRoom", b'{"visibili')
self.assertEquals(400, channel.code)
- request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
+ channel = self.make_request("POST", "/createRoom", b'["hello"]')
self.assertEquals(400, channel.code)
def test_post_room_invitees_invalid_mxid(self):
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
# Note the trailing space in the MXID here!
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
)
self.assertEquals(400, channel.code)
@@ -476,54 +475,54 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self):
# missing keys or invalid json
- request, channel = self.make_request("PUT", self.path, "{}")
+ channel = self.make_request("PUT", self.path, "{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
+ channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, '{"nao')
+ channel = self.make_request("PUT", self.path, '{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, "text only")
+ channel = self.make_request("PUT", self.path, "text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, "")
+ channel = self.make_request("PUT", self.path, "")
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid key, wrong type
content = '{"topic":["Topic name"]}'
- request, channel = self.make_request("PUT", self.path, content)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self):
# nothing should be there
- request, channel = self.make_request("GET", self.path)
+ channel = self.make_request("GET", self.path)
self.assertEquals(404, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
- request, channel = self.make_request("PUT", self.path, content)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
- request, channel = self.make_request("GET", self.path)
+ channel = self.make_request("GET", self.path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self):
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
- request, channel = self.make_request("PUT", self.path, content)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
- request, channel = self.make_request("GET", self.path)
+ channel = self.make_request("GET", self.path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -539,24 +538,22 @@ class RoomMemberStateTestCase(RoomBase):
def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
- request, channel = self.make_request("PUT", path, "{}")
+ channel = self.make_request("PUT", path, "{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
+ channel = self.make_request("PUT", path, '{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, '{"nao')
+ channel = self.make_request("PUT", path, '{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
- "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
- )
+ channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, "text only")
+ channel = self.make_request("PUT", path, "text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, "")
+ channel = self.make_request("PUT", path, "")
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid keys, wrong types
@@ -565,7 +562,7 @@ class RoomMemberStateTestCase(RoomBase):
Membership.JOIN,
Membership.LEAVE,
)
- request, channel = self.make_request("PUT", path, content.encode("ascii"))
+ channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self):
@@ -576,10 +573,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
- request, channel = self.make_request("PUT", path, content.encode("ascii"))
+ channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
@@ -594,10 +591,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
@@ -613,18 +610,54 @@ class RoomMemberStateTestCase(RoomBase):
Membership.INVITE,
"Join us!",
)
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
+class RoomInviteRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_rooms_ratelimit(self):
+ """Tests that invites in a room are actually rate-limited."""
+ room_id = self.helper.create_room_as(self.user_id)
+
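+ # per_room burst_count is 3, so three invites succeed and the fourth is rejected with 429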
+ for i in range(3):
+ self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+ self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_users_ratelimit(self):
+ """Tests that invites to a specific user are actually rate-limited."""
+
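+ # per_user burst_count is 3, so the first three invites to the same user should succeed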
+ for i in range(3):
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
servlets = [
+ admin.register_servlets,
profile.register_servlets,
room.register_servlets,
]
@@ -666,7 +699,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
- request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+ channel = self.make_request("PUT", path, {"displayname": "John Doe"})
self.assertEquals(channel.code, 200, channel.json_body)
# Check that all the rooms have been sent a profile update into.
@@ -676,7 +709,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
self.user_id,
)
- request, channel = self.make_request("GET", path)
+ channel = self.make_request("GET", path)
self.assertEquals(channel.code, 200)
self.assertIn("displayname", channel.json_body)
@@ -700,9 +733,23 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
for i in range(4):
- request, channel = self.make_request("POST", path % room_id, {})
+ channel = self.make_request("POST", path % room_id, {})
self.assertEquals(channel.code, 200)
+ @unittest.override_config(
+ {
+ "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}},
+ "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"],
+ "autocreate_auto_join_rooms": True,
+ },
+ )
+ def test_autojoin_rooms(self):
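+ """Newly registered users should automatically join the configured auto_join_rooms."""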
+ user_id = self.register_user("testuser", "password")
+
+ # Check that the new user successfully joined the four rooms
+ rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 4)
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -715,42 +762,40 @@ class RoomMessagesTestCase(RoomBase):
def test_invalid_puts(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
- request, channel = self.make_request("PUT", path, b"{}")
+ channel = self.make_request("PUT", path, b"{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
+ channel = self.make_request("PUT", path, b'{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b'{"nao')
+ channel = self.make_request("PUT", path, b'{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
- "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
- )
+ channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b"text only")
+ channel = self.make_request("PUT", path, b"text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b"")
+ channel = self.make_request("PUT", path, b"")
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# custom message types
content = b'{"body":"test","msgtype":"test.custom.text"}'
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = b'{"body":"test2","msgtype":"m.text"}'
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -764,9 +809,7 @@ class RoomInitialSyncTestCase(RoomBase):
self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self):
- request, channel = self.make_request(
- "GET", "/rooms/%s/initialSync" % self.room_id
- )
+ channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
self.assertEquals(200, channel.code)
self.assertEquals(self.room_id, channel.json_body["room_id"])
@@ -807,7 +850,7 @@ class RoomMessageListTestCase(RoomBase):
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
self.assertEquals(200, channel.code)
@@ -818,7 +861,7 @@ class RoomMessageListTestCase(RoomBase):
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
self.assertEquals(200, channel.code)
@@ -851,7 +894,7 @@ class RoomMessageListTestCase(RoomBase):
self.helper.send(self.room_id, "message 3")
# Check that we get the first and second message when querying /messages.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -879,7 +922,7 @@ class RoomMessageListTestCase(RoomBase):
# Check that we only get the second message through /message now that the first
# has been purged.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -896,7 +939,7 @@ class RoomMessageListTestCase(RoomBase):
# Check that we get no event, but also no error, when querying /messages with
# the token that was pointing at the first event, because we don't have it
# anymore.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -955,7 +998,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
self.helper.send(self.room, body="There!", tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/search?access_token=%s" % (self.access_token,),
{
@@ -984,7 +1027,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
self.helper.send(self.room, body="There!", tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/search?access_token=%s" % (self.access_token,),
{
@@ -1032,14 +1075,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
return self.hs
def test_restricted_no_auth(self):
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401, channel.result)
def test_restricted_auth(self):
self.register_user("user", "pass")
tok = self.login("user", "pass")
- request, channel = self.make_request("GET", self.url, access_token=tok)
+ channel = self.make_request("GET", self.url, access_token=tok)
self.assertEqual(channel.code, 200, channel.result)
@@ -1067,7 +1110,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.displayname = "test user"
data = {"displayname": self.displayname}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
request_data,
@@ -1080,7 +1123,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
def test_per_room_profile_forbidden(self):
data = {"membership": "join", "displayname": "other test user"}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (self.room_id, self.user_id),
@@ -1090,7 +1133,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
@@ -1123,7 +1166,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
def test_join_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/join".format(self.room_id),
content={"reason": reason},
@@ -1137,7 +1180,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
content={"reason": reason},
@@ -1151,7 +1194,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
@@ -1165,7 +1208,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
@@ -1177,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
def test_unban_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
@@ -1189,7 +1232,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
def test_invite_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
@@ -1208,7 +1251,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
content={"reason": reason},
@@ -1219,7 +1262,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def _check_for_reason(self, reason):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
self.room_id, self.second_user_id
@@ -1268,7 +1311,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
@@ -1298,7 +1341,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by the absence of a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
@@ -1333,7 +1376,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
@@ -1361,7 +1404,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
@@ -1378,7 +1421,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
@@ -1401,7 +1444,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (
@@ -1432,14 +1475,16 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 2, [result["result"]["content"] for result in results],
+ len(results),
+ 2,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1467,14 +1512,16 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 4, [result["result"]["content"] for result in results],
+ len(results),
+ 4,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1514,14 +1561,16 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 1, [result["result"]["content"] for result in results],
+ len(results),
+ 1,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1635,7 +1684,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
# Check that we can still see the messages before the erasure request.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
'/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
% (self.room_id, event_id),
@@ -1681,7 +1730,9 @@ class ContextTestCase(unittest.HomeserverTestCase):
deactivate_account_handler = self.hs.get_deactivate_account_handler()
self.get_success(
- deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
+ deactivate_account_handler.deactivate_account(
+ self.user_id, True, create_requester(self.user_id)
+ )
)
# Invite another user in the room. This is needed because messages will be
@@ -1699,7 +1750,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
# Check that a user that joined the room after the erasure request can't see
# the messages anymore.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
'/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
% (self.room_id, event_id),
@@ -1789,7 +1840,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases"
% (self.room_id,),
@@ -1810,7 +1861,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -1840,14 +1891,14 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
access_token=self.room_owner_tok,
@@ -1859,7 +1910,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
json.dumps(content),
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index ae0207366b..329dbd06de 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -18,8 +18,6 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.rest.client.v1 import room
from synapse.types import UserID
@@ -39,7 +37,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red", federation_http_client=None, federation_client=Mock(),
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
@@ -60,32 +60,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_datastore().insert_client_ip = _insert_client_ip
- def get_room_members(room_id):
- if room_id == self.room_id:
- return defer.succeed([self.user])
- else:
- return defer.succeed([])
-
- @defer.inlineCallbacks
- def fetch_room_distributions_into(
- room_id, localusers=None, remotedomains=None, ignore_user=None
- ):
- members = yield get_room_members(room_id)
- for member in members:
- if ignore_user is not None and member == ignore_user:
- continue
-
- if hs.is_mine(member):
- if localusers is not None:
- localusers.add(member)
- else:
- if remotedomains is not None:
- remotedomains.add(member.domain)
-
- hs.get_room_member_handler().fetch_room_distributions_into = (
- fetch_room_distributions_into
- )
-
return hs
def prepare(self, reactor, clock, hs):
@@ -94,7 +68,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user="@jim:red")
def test_set_typing(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
@@ -117,7 +91,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
)
def test_set_not_typing(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": false}',
@@ -125,7 +99,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
def test_typing_timeout(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
@@ -138,7 +112,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 2)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..946740aa5d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +17,12 @@
# limitations under the License.
import json
+import re
import time
-from typing import Any, Dict, Optional
+import urllib.parse
+from typing import Any, Dict, Mapping, MutableMapping, Optional
+
+from mock import patch
import attr
@@ -26,8 +30,11 @@ from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.types import JsonDict
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
+from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
@attr.s
@@ -75,7 +82,7 @@ class RestHelper:
if tok:
path = path + "?access_token=%s" % tok
- _, channel = make_request(
+ channel = make_request(
self.hs.get_reactor(),
self.site,
"POST",
@@ -151,7 +158,7 @@ class RestHelper:
data = {"membership": membership}
data.update(extra_data)
- _, channel = make_request(
+ channel = make_request(
self.hs.get_reactor(),
self.site,
"PUT",
@@ -159,9 +166,12 @@ class RestHelper:
json.dumps(data).encode("utf8"),
)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
self.auth_user_id = temp_id
@@ -186,7 +196,7 @@ class RestHelper:
if tok:
path = path + "?access_token=%s" % tok
- _, channel = make_request(
+ channel = make_request(
self.hs.get_reactor(),
self.site,
"PUT",
@@ -194,9 +204,12 @@ class RestHelper:
json.dumps(content).encode("utf8"),
)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
return channel.json_body
@@ -242,13 +255,14 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- _, channel = make_request(
- self.hs.get_reactor(), self.site, method, path, content
- )
+ channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
return channel.json_body
@@ -327,7 +341,7 @@ class RestHelper:
"""
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
- _, channel = make_request(
+ channel = make_request(
self.hs.get_reactor(),
FakeSite(resource),
"POST",
@@ -344,3 +358,268 @@ class RestHelper:
)
return channel.json_body
+
+ def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+ """Log in (as a new user) via OIDC
+
+ Returns the result of the final token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_baseurl".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+ client_redirect_url = "https://x"
+ channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+
+ # expect a confirmation page
+ assert channel.code == 200, channel.result
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.text_body,
+ )
+ assert m, channel.text_body
+ login_token = m.group(1)
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token and device id.
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ assert channel.code == 200
+ return channel.json_body
+
+ def auth_via_oidc(
+ self,
+ user_info_dict: JsonDict,
+ client_redirect_url: Optional[str] = None,
+ ui_auth_session_id: Optional[str] = None,
+ ) -> FakeChannel:
+ """Perform an OIDC authentication flow via a mock OIDC provider.
+
+ This can be used for either login or user-interactive auth.
+
+ Starts by making a request to the relevant synapse redirect endpoint, which is
+ expected to serve a 302 to the OIDC provider. We then make a request to the
+ OIDC callback endpoint, intercepting the HTTP requests that will get sent back
+ to the OIDC provider.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_baseurl".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+
+ Args:
+ user_info_dict: the remote userinfo that the OIDC provider should present.
+ Typically this should be '{"sub": "<remote user id>"}'.
+ client_redirect_url: for a login flow, the client redirect URL to pass to
+ the login redirect endpoint
+ ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
+ of the UI auth.
+
+ Returns:
+ A FakeChannel containing the result of calling the OIDC callback endpoint.
+ Note that the response code may be a 200, 302 or 400 depending on how things
+ went.
+ """
+
+ cookies = {}
+
+ # if we're doing a ui auth, hit the ui auth redirect endpoint
+ if ui_auth_session_id:
+ # can't set the client redirect url for UI Auth
+ assert client_redirect_url is None
+ oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+ else:
+ # otherwise, hit the login redirect endpoint
+ oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+
+ # we now have a URI for the OIDC IdP, but we skip that and go straight
+ # back to synapse's OIDC callback resource. However, we do need the "state"
+ # param that synapse passes to the IdP via query params, as well as the cookie
+ # that synapse passes to the client.
+
+ oauth_uri_path, _ = oauth_uri.split("?", 1)
+ assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+ "unexpected SSO URI " + oauth_uri_path
+ )
+ return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+ def complete_oidc_auth(
+ self,
+ oauth_uri: str,
+ cookies: Mapping[str, str],
+ user_info_dict: JsonDict,
+ ) -> FakeChannel:
+ """Mock out an OIDC authentication flow
+
+ Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+ initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+ Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+ sent back to the OIDC provider.
+
+ Requires the OIDC callback resource to be mounted at the normal place.
+
+ Args:
+ oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+ from initiate_sso_login or initiate_sso_ui_auth).
+ cookies: the cookies set by synapse's redirect endpoint, which will be
+ sent back to the callback endpoint.
+ user_info_dict: the remote userinfo that the OIDC provider should present.
+ Typically this should be '{"sub": "<remote user id>"}'.
+
+ Returns:
+ A FakeChannel containing the result of calling the OIDC callback endpoint.
+ """
+ _, oauth_uri_qs = oauth_uri.split("?", 1)
+ params = urllib.parse.parse_qs(oauth_uri_qs)
+ callback_uri = "%s?%s" % (
+ urllib.parse.urlparse(params["redirect_uri"][0]).path,
+ urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+ )
+
+ # before we hit the callback uri, stub out some methods in the http client so
+ # that we don't have to handle full HTTPS requests.
+ # (expected url, json response) pairs, in the order we expect them.
+ expected_requests = [
+ # first we get a hit to the token endpoint, which we tell to return
+ # a dummy OIDC access token
+ (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
+ # and then one to the user_info endpoint, which returns our remote user id.
+ (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
+ ]
+
+ async def mock_req(method: str, uri: str, data=None, headers=None):
+ (expected_uri, resp_obj) = expected_requests.pop(0)
+ assert uri == expected_uri
+ resp = FakeResponse(
+ code=200,
+ phrase=b"OK",
+ body=json.dumps(resp_obj).encode("utf-8"),
+ )
+ return resp
+
+ with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ # now hit the callback URI with the right params and a made-up code
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ callback_uri,
+ custom_headers=[
+ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+ ],
+ )
+ return channel
+
+ def initiate_sso_login(
+ self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+ ) -> str:
+ """Make a request to the login-via-sso redirect endpoint, and return the target
+
+ Assumes that exactly one SSO provider has been configured. Requires the login
+ servlet to be mounted.
+
+ Args:
+ client_redirect_url: the client redirect URL to pass to the login redirect
+ endpoint
+ cookies: any cookies returned will be added to this dict
+
+ Returns:
+ the URI that the client gets redirected to (ie, the SSO server)
+ """
+ params = {}
+ if client_redirect_url:
+ params["redirectUrl"] = client_redirect_url
+
+ # hit the redirect url (which should redirect back to the redirect url); this
+ # is the easiest way of figuring out what the Host header ought to be set to
+ # in order to keep Synapse happy.
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+ )
+ assert channel.code == 302
+
+ # hit the redirect url again with the right Host header, which should now issue
+ # a cookie and redirect to the SSO provider.
+ location = channel.headers.getRawHeaders("Location")[0]
+ parts = urllib.parse.urlsplit(location)
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ urllib.parse.urlunsplit(("", "") + parts[2:]),
+ custom_headers=[
+ ("Host", parts[1]),
+ ],
+ )
+
+ assert channel.code == 302
+ channel.extract_cookies(cookies)
+ return channel.headers.getRawHeaders("Location")[0]
+
+ def initiate_sso_ui_auth(
+ self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
+ ) -> str:
+ """Make a request to the ui-auth-via-sso endpoint, and return the target
+
+ Assumes that exactly one SSO provider has been configured. Requires the
+ AuthRestServlet to be mounted.
+
+ Args:
+ ui_auth_session_id: the session id of the UI auth
+ cookies: any cookies returned will be added to this dict
+
+ Returns:
+ the URI that the client gets linked to (ie, the SSO server)
+ """
+ sso_redirect_endpoint = (
+ "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
+ + urllib.parse.urlencode({"session": ui_auth_session_id})
+ )
+ # hit the redirect url (which will issue a cookie and state)
+ channel = make_request(
+ self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
+ )
+ # that should serve a confirmation page
+ assert channel.code == 200, channel.text_body
+ channel.extract_cookies(cookies)
+
+ # parse the confirmation page to fish out the link.
+ p = TestHtmlParser()
+ p.feed(channel.text_body)
+ p.close()
+ assert len(p.links) == 1, "not exactly one link in confirmation page"
+ oauth_uri = p.links[0]
+ return oauth_uri
+
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
+TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
+TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "discover": False,
+ "issuer": "https://issuer.test",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
+ "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
+ "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
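The helpers above are designed to be driven from a test case: login_via_oidc runs the whole mocked flow end to end, while auth_via_oidc and the initiate_sso_* helpers expose the individual steps for UI-auth tests. As a minimal sketch of the intended usage (the class and test names are invented for illustration; it assumes the HomeserverTestCase harness and the OIDC test dependencies already used elsewhere in this patch):

    from synapse.rest.client.v1 import login
    from synapse.rest.synapse.client import build_synapse_client_resource_tree

    from tests import unittest
    from tests.handlers.test_oidc import HAS_OIDC
    from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
    from tests.unittest import override_config, skip_unless


    class ExampleOidcLoginTest(unittest.HomeserverTestCase):
        servlets = [login.register_servlets]

        def default_config(self):
            config = super().default_config()
            # login_via_oidc needs a public_baseurl; an http:// scheme avoids
            # Synapse redirecting the fake (non-TLS) test channel to https.
            config["public_baseurl"] = "http://synapse.test"
            return config

        def create_resource_dict(self):
            # mount the OIDC callback resource at its normal place
            d = super().create_resource_dict()
            d.update(build_synapse_client_resource_tree(self.hs))
            return d

        @skip_unless(HAS_OIDC, "requires OIDC")
        @override_config({"oidc_config": TEST_OIDC_CONFIG})
        def test_login_via_oidc(self):
            # drive the mocked OIDC flow; the helper returns the /login response
            login_resp = self.helper.login_via_oidc("someuser")
            self.assertIn("user_id", login_resp)
            self.assertIn("access_token", login_resp)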
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 2ac1ecb7d3..e72b61963d 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -19,13 +19,12 @@ import os
import re
from email.parser import Parser
from typing import Optional
-from urllib.parse import urlencode
import pkg_resources
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -76,8 +75,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self):
- """Test basic password reset flow
- """
+ """Test basic password reset flow"""
old_password = "monkey"
new_password = "kangeroo"
@@ -113,6 +111,55 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_email(self):
+ """Test that we ratelimit /requestToken for the same email."""
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test1@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ def reset(ip):
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret, ip)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ self.email_attempts.clear()
+
+ # We expect to be able to make three requests before getting rate
+ # limited.
+ #
+ # We change IPs to ensure that we're not being ratelimited due to the
+ # same IP.
+ reset("127.0.0.1")
+ reset("127.0.0.2")
+ reset("127.0.0.3")
+
+ with self.assertRaises(HttpResponseException) as cm:
+ reset("127.0.0.4")
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow
Request password reset with different spelling
@@ -154,8 +201,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", old_password)
def test_cant_reset_password_without_clicking_link(self):
- """Test that we do actually need to click the link in the email
- """
+ """Test that we do actually need to click the link in the email"""
old_password = "monkey"
new_password = "kangeroo"
@@ -240,13 +286,20 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
- def _request_token(self, email, client_secret):
- request, channel = self.make_request(
+ def _request_token(self, email, client_secret, ip="127.0.0.1"):
+ channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
+ client_ip=ip,
)
- self.assertEquals(200, channel.code, channel.result)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code,
+ channel.result["reason"],
+ channel.result["body"],
+ )
return channel.json_body["sid"]
@@ -255,7 +308,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
# Load the password reset confirmation page
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"GET",
@@ -268,20 +321,13 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
- # Send arguments as url-encoded form data, matching the template's behaviour
- form_args = []
- for key, value_list in request.args.items():
- for value in value_list:
- arg = (key, value)
- form_args.append(arg)
-
# Confirm the password reset
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"POST",
path,
- content=urlencode(form_args).encode("utf8"),
+ content=b"",
shorthand=False,
content_is_form=True,
)
@@ -310,7 +356,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def _reset_password(
self, new_password, session_id, client_secret, expected_code=200
):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/password",
{
@@ -352,8 +398,8 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
# Check that this access token has been invalidated.
- request, channel = self.make_request("GET", "account/whoami")
- self.assertEqual(request.code, 401)
+ channel = self.make_request("GET", "account/whoami")
+ self.assertEqual(channel.code, 401)
def test_pending_invites(self):
"""Tests that deactivating a user rejects every pending invite for them."""
@@ -407,10 +453,10 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
"erase": False,
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@@ -517,9 +563,22 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_ip(self):
+ """Tests that adding emails is ratelimited by IP"""
+
+ # We expect to be able to set three emails before getting ratelimited.
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+ self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+ self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+ with self.assertRaises(HttpResponseException) as cm:
+ self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_add_email_if_disabled(self):
- """Test adding email to profile when doing so is disallowed
- """
+ """Test adding email to profile when doing so is disallowed"""
self.hs.config.enable_3pid_changes = False
client_secret = "foobar"
@@ -530,7 +589,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self._validate_token(link)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -548,16 +607,17 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self):
- """Test deleting an email from profile
- """
+ """Test deleting an email from profile"""
# Add a threepid
self.get_success(
self.store.user_add_threepid(
@@ -569,7 +629,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/delete",
{"medium": "email", "address": self.email},
@@ -578,16 +638,17 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self):
- """Test deleting an email from profile when disallowed
- """
+ """Test deleting an email from profile when disallowed"""
self.hs.config.enable_3pid_changes = False
# Add a threepid
@@ -601,7 +662,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/delete",
{"medium": "email", "address": self.email},
@@ -612,8 +673,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -621,15 +684,14 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
def test_cant_add_email_without_clicking_link(self):
- """Test that we do actually need to click the link in the email
- """
+ """Test that we do actually need to click the link in the email"""
client_secret = "foobar"
session_id = self._request_token(self.email, client_secret)
self.assertEquals(len(self.email_attempts), 1)
# Attempt to add email without clicking the link
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -647,8 +709,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -662,7 +726,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
session_id = "weasle"
# Attempt to add email without even requesting an email
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -680,8 +744,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -726,7 +792,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Ensure not providing a next_link parameter still works
self._request_token(
- "something@example.com", "some_secret", next_link=None, expect_code=200,
+ "something@example.com",
+ "some_secret",
+ next_link=None,
+ expect_code=200,
)
self._request_token(
@@ -784,17 +853,29 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
if next_link:
body["next_link"] = next_link
- request, channel = self.make_request(
- "POST", b"account/3pid/email/requestToken", body,
+ channel = self.make_request(
+ "POST",
+ b"account/3pid/email/requestToken",
+ body,
)
- self.assertEquals(expect_code, channel.code, channel.result)
+
+ if channel.code != expect_code:
+ raise HttpResponseException(
+ channel.code,
+ channel.result["reason"],
+ channel.result["body"],
+ )
return channel.json_body.get("sid")
def _request_token_invalid_email(
- self, email, expected_errcode, expected_error, client_secret="foobar",
+ self,
+ email,
+ expected_errcode,
+ expected_error,
+ client_secret="foobar",
):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
@@ -807,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")
- request, channel = self.make_request("GET", path, shorthand=False)
+ channel = self.make_request("GET", path, shorthand=False)
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -831,17 +912,18 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
return match.group(0)
def _add_email(self, request_email, expected_email):
- """Test adding an email to profile
- """
+ """Test adding an email to profile"""
+ previous_email_attempts = len(self.email_attempts)
+
client_secret = "foobar"
session_id = self._request_token(request_email, client_secret)
- self.assertEquals(len(self.email_attempts), 1)
+ self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email()
self._validate_token(link)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -859,10 +941,14 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
- request, channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ channel = self.make_request(
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+ threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+ self.assertIn(expected_email, threepids)
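The _request_token changes above follow one pattern worth calling out: rather than asserting a 200 inline, failures are surfaced as HttpResponseException, which is what lets the new rate-limit tests catch them with assertRaises and check for a 429. A minimal standalone sketch of that pattern (the helper name is invented for illustration):

    from synapse.api.errors import HttpResponseException

    def json_body_or_raise(channel):
        # return the JSON body on success; otherwise raise so callers can use
        # assertRaises(HttpResponseException) and inspect cm.exception.code
        # (e.g. a 429 from the 3pid validation ratelimiter)
        if channel.code != 200:
            raise HttpResponseException(
                channel.code, channel.result["reason"], channel.result["body"]
            )
        return channel.json_body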
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..9734a2159a 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,20 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Union
+from typing import Union
from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import JsonDict, UserID
from tests import unittest
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
+from tests.unittest import override_config, skip_unless
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
@@ -64,11 +68,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
def register(self, expected_response: int, body: JsonDict) -> FakeChannel:
"""Make a register request."""
- request, channel = self.make_request(
- "POST", "register", body
- ) # type: SynapseRequest, FakeChannel
+ channel = self.make_request("POST", "register", body)
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
def recaptcha(
@@ -78,18 +80,18 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
if post_session is None:
post_session = session
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request.code, 200)
+ )
+ self.assertEqual(channel.code, 200)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"auth/m.login.recaptcha/fallback/web?session="
+ post_session
+ "&g-recaptcha-response=a",
)
- self.assertEqual(request.code, expected_post_response)
+ self.assertEqual(channel.code, expected_post_response)
# The recaptcha handler is called with the response given
attempts = self.recaptcha_checker.recaptcha_attempts
@@ -100,7 +102,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
channel = self.register(
- 401, {"username": "user", "type": "m.login.password", "password": "bar"},
+ 401,
+ {"username": "user", "type": "m.login.password", "password": "bar"},
)
# Grab the session
@@ -156,31 +159,51 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]
+ def default_config(self):
+ config = super().default_config()
+
+ # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+ # False, meaning synapse sees the requested uri as http://...; using http in
+ # the public_baseurl stops Synapse trying to redirect to https.
+ config["public_baseurl"] = "http://synapse.test"
+
+ if HAS_OIDC:
+ # we enable OIDC as a way of testing SSO flows
+ oidc_config = {}
+ oidc_config.update(TEST_OIDC_CONFIG)
+ oidc_config["allow_existing_users"] = True
+ config["oidc_config"] = oidc_config
+
+ return config
+
+ def create_resource_dict(self):
+ resource_dict = super().create_resource_dict()
+ resource_dict.update(build_synapse_client_resource_tree(self.hs))
+ return resource_dict
+
def prepare(self, reactor, clock, hs):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
- self.user_tok = self.login("test", self.user_pass)
-
- def get_device_ids(self) -> List[str]:
- # Get the list of devices so one can be deleted.
- request, channel = self.make_request(
- "GET", "devices", access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
-
- # Get the ID of the device.
- self.assertEqual(request.code, 200)
- return [d["device_id"] for d in channel.json_body["devices"]]
+ self.device_id = "dev1"
+ self.user_tok = self.login("test", self.user_pass, self.device_id)
def delete_device(
- self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+ self,
+ access_token: str,
+ device: str,
+ expected_response: int,
+ body: Union[bytes, JsonDict] = b"",
) -> FakeChannel:
"""Delete an individual device."""
- request, channel = self.make_request(
- "DELETE", "devices/" + device, body, access_token=self.user_tok
- ) # type: SynapseRequest, FakeChannel
+ channel = self.make_request(
+ "DELETE",
+ "devices/" + device,
+ body,
+ access_token=access_token,
+ )
# Ensure the response is sane.
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
@@ -188,12 +211,15 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""Delete 1 or more devices."""
# Note that this uses the delete_devices endpoint so that we can modify
# the payload half-way through some tests.
- request, channel = self.make_request(
- "POST", "delete_devices", body, access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
+ channel = self.make_request(
+ "POST",
+ "delete_devices",
+ body,
+ access_token=self.user_tok,
+ )
# Ensure the response is sane.
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
@@ -201,11 +227,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""
Test user interactive authentication outside of registration.
"""
- device_id = self.get_device_ids()[0]
-
# Attempt to delete this device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_id, 401)
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -214,7 +238,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow.
self.delete_device(
- device_id,
+ self.user_tok,
+ self.device_id,
200,
{
"auth": {
@@ -233,13 +258,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
UIA - check that still works.
"""
- device_id = self.get_device_ids()[0]
- channel = self.delete_device(device_id, 401)
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
session = channel.json_body["session"]
# Make another request providing the UI auth flow.
self.delete_device(
- device_id,
+ self.user_tok,
+ self.device_id,
200,
{
"auth": {
@@ -262,14 +287,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
session ID should be rejected.
"""
# Create a second login.
- self.login("test", self.user_pass)
-
- device_ids = self.get_device_ids()
- self.assertEqual(len(device_ids), 2)
+ self.login("test", self.user_pass, "dev2")
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_devices(401, {"devices": [device_ids[0]]})
+ channel = self.delete_devices(401, {"devices": [self.device_id]})
# Grab the session
session = channel.json_body["session"]
@@ -281,7 +303,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
self.delete_devices(
200,
{
- "devices": [device_ids[1]],
+ "devices": ["dev2"],
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": self.user},
@@ -296,14 +318,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
The initial requested URI cannot be modified during the user interactive authentication session.
"""
# Create a second login.
- self.login("test", self.user_pass)
-
- device_ids = self.get_device_ids()
- self.assertEqual(len(device_ids), 2)
+ self.login("test", self.user_pass, "dev2")
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_ids[0], 401)
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -312,8 +331,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error.
+ #
+ # This makes use of the fact that the device ID is embedded into the URL.
self.delete_device(
- device_ids[1],
+ self.user_tok,
+ "dev2",
403,
{
"auth": {
@@ -324,3 +346,153 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
},
)
+
+ @unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
+ def test_can_reuse_session(self):
+ """
+ The session can be reused if configured.
+
+ Compare to test_cannot_change_uri.
+ """
+ # Create a second and third login.
+ self.login("test", self.user_pass, "dev2")
+ self.login("test", self.user_pass, "dev3")
+
+ # Attempt to delete a device. This works since the user just logged in.
+ self.delete_device(self.user_tok, "dev2", 200)
+
+ # Move the clock forward past the validation timeout.
+ self.reactor.advance(6)
+
+ # Deleting another device throws the user into UI auth.
+ channel = self.delete_device(self.user_tok, "dev3", 401)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ self.user_tok,
+ "dev3",
+ 200,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
+ # Make another request, but try to delete the first device. This works
+ # due to re-using the previous session.
+ #
+ # Note that *no auth* information is provided, not even a session ID!
+ self.delete_device(self.user_tok, self.device_id, 200)
+
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config({"oidc_config": TEST_OIDC_CONFIG})
+ def test_ui_auth_via_sso(self):
+ """Test a successful UI Auth flow via SSO
+
+ This includes:
+ * hitting the UIA SSO redirect endpoint
+ * checking it serves a confirmation page which links to the OIDC provider
+ * calling back to the synapse oidc callback
+ * checking that the original operation succeeds
+ """
+
+ # log the user in
+ remote_user_id = UserID.from_string(self.user).localpart
+ login_resp = self.helper.login_via_oidc(remote_user_id)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ # initiate a UI Auth process by attempting to delete the device
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ # check that SSO is offered
+ flows = channel.json_body["flows"]
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+
+ # run the UIA-via-SSO flow
+ session_id = channel.json_body["session"]
+ channel = self.helper.auth_via_oidc(
+ {"sub": remote_user_id}, ui_auth_session_id=session_id
+ )
+
+ # that should serve a confirmation page
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # and now the delete request should succeed.
+ self.delete_device(
+ self.user_tok,
+ self.device_id,
+ 200,
+ body={"auth": {"session": session_id}},
+ )
+
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config({"oidc_config": TEST_OIDC_CONFIG})
+ def test_does_not_offer_password_for_sso_user(self):
+ login_resp = self.helper.login_via_oidc("username")
+ user_tok = login_resp["access_token"]
+ device_id = login_resp["device_id"]
+
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ channel = self.delete_device(user_tok, device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+ def test_does_not_offer_sso_for_password_user(self):
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config({"oidc_config": TEST_OIDC_CONFIG})
+ def test_offers_both_flows_for_upgraded_user(self):
+ """A user that had a password and then logged in with SSO should get both flows"""
+ login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ flows = channel.json_body["flows"]
+ # we have no particular expectations of ordering here
+ self.assertIn({"stages": ["m.login.password"]}, flows)
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+ self.assertEqual(len(flows), 2)
+
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config({"oidc_config": TEST_OIDC_CONFIG})
+ def test_ui_auth_fails_for_incorrect_sso_user(self):
+ """If the user tries to authenticate with the wrong SSO user, they get an error"""
+ # log the user in
+ login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ # start a UI Auth flow by attempting to delete a device
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+ session_id = channel.json_body["session"]
+
+ # do the OIDC auth, but auth as the wrong user
+ channel = self.helper.auth_via_oidc(
+ {"sub": "wrong_user"}, ui_auth_session_id=session_id
+ )
+
+ # that should return a failure message
+ self.assertSubstring("We were unable to validate", channel.text_body)
+
+ # ... and the delete op should now fail with a 403
+ self.delete_device(
+ self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+ )
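Most of the UIAuthTests above share the same shape: the first attempt at a protected operation is rejected with a 401 carrying the advertised flows and a session id, and the test then repeats the operation with an "auth" dict that references that session. A condensed sketch of the password-flow variant (method names mirror the helpers defined in the test class above; the surrounding harness is assumed):

    # first attempt: rejected with 401, but we learn the flows and session id
    channel = self.delete_device(self.user_tok, self.device_id, 401)
    self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
    session = channel.json_body["session"]

    # second attempt: same operation, now carrying the completed auth stage
    self.delete_device(
        self.user_tok,
        self.device_id,
        200,
        {
            "auth": {
                "type": "m.login.password",
                "identifier": {"type": "m.id.user", "user": self.user},
                "password": self.user_pass,
                "session": session,
            },
        },
    )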
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index 767e126875..287a1a485c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
from tests import unittest
+from tests.unittest import override_config
class CapabilitiesTestCase(unittest.HomeserverTestCase):
@@ -33,10 +34,11 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
self.config = hs.config
+ self.auth_handler = hs.get_auth_handler()
return hs
def test_check_auth_required(self):
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401)
@@ -44,7 +46,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.register_user("user", "pass")
access_token = self.login("user", "pass")
- request, channel = self.make_request("GET", self.url, access_token=access_token)
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
@@ -56,21 +58,47 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities["m.room_versions"]["default"],
)
- def test_get_change_password_capabilities(self):
+ def test_get_change_password_capabilities_password_login(self):
localpart = "user"
password = "pass"
user = self.register_user(localpart, password)
access_token = self.login(user, password)
- request, channel = self.make_request("GET", self.url, access_token=access_token)
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
-
- # Test case where password is handled outside of Synapse
self.assertTrue(capabilities["m.change_password"]["enabled"])
- self.get_success(self.store.user_set_password_hash(user, None))
- request, channel = self.make_request("GET", self.url, access_token=access_token)
+
+ @override_config({"password_config": {"localdb_enabled": False}})
+ def test_get_change_password_capabilities_localdb_disabled(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["m.change_password"]["enabled"])
+
+ @override_config({"password_config": {"enabled": False}})
+ def test_get_change_password_capabilities_password_disabled(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 231d5aefea..f761c44936 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -36,7 +36,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastore()
def test_add_filter(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
@@ -49,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
def test_add_filter_for_other_user(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
@@ -61,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
@@ -79,7 +79,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.reactor.advance(1)
filter_id = filter_id.result
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
@@ -87,7 +87,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
@@ -97,7 +97,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
# Currently invalid params do not have an appropriate errcode
# in errors.py
def test_get_filter_invalid_id(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
@@ -105,7 +105,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
# No ID also returns an invalid_id error
def test_get_filter_no_id(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index ee86b94917..5ebc5707a5 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -70,9 +70,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_get_policy(self):
"""Tests if the /password_policy endpoint returns the configured policy."""
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/password_policy"
- )
+ channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(
@@ -89,52 +87,62 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_too_short(self):
request_data = json.dumps({"username": "kermit", "password": "shorty"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_TOO_SHORT,
+ channel.result,
)
def test_password_no_digit(self):
request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_DIGIT,
+ channel.result,
)
def test_password_no_symbol(self):
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_SYMBOL,
+ channel.result,
)
def test_password_no_uppercase(self):
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_UPPERCASE,
+ channel.result,
)
def test_password_no_lowercase(self):
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_LOWERCASE,
+ channel.result,
)
def test_password_compliant(self):
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has
# responded with a list of registration flows.
@@ -160,7 +168,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
},
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/account/password",
request_data,
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 610b263577..2d4ce871eb 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -67,7 +67,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.get_datastore().services_cache.append(appservice)
request_data = json.dumps({"username": "as_user_kermit"})
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
@@ -78,7 +78,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists
request_data = json.dumps({"username": "kermit"})
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
@@ -86,7 +86,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
@@ -101,7 +101,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
det_data = {
"user_id": user_id,
@@ -116,16 +116,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+ self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -134,7 +135,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
@@ -143,7 +144,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_guest(self):
for i in range(0, 6):
url = self.url + b"?kind=guest"
- request, channel = self.make_request(b"POST", url, b"{}")
+ channel = self.make_request(b"POST", url, b"{}")
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -153,7 +154,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -167,7 +168,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -177,12 +178,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_advertised_flows(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -205,7 +206,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -237,7 +238,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
)
def test_advertised_flows_no_msisdn_email_required(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -277,7 +278,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"register/email/requestToken",
{"client_secret": "foobar", "email": email, "send_attempt": 1},
@@ -407,13 +408,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
@@ -435,14 +436,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/account_validity/validity"
params = {"user_id": user_id}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_manual_expire(self):
@@ -459,14 +458,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
@@ -486,20 +483,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Try to log the user out
- request, channel = self.make_request(b"POST", "/logout", access_token=tok)
+ channel = self.make_request(b"POST", "/logout", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions
- request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+ channel = self.make_request(b"POST", "/logout/all", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -550,9 +545,7 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Mock the homeserver's HTTP client
@@ -590,9 +583,7 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Wait for the background job to run which hides expired users in the directory
@@ -623,9 +614,7 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.pump(10)
@@ -719,7 +708,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# retrieve the token from the DB.
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
- request, channel = self.make_request(b"GET", url)
+ channel = self.make_request(b"GET", url)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
@@ -737,7 +726,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Move 1 day forward. Try to renew with the same token again.
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
- request, channel = self.make_request(b"GET", url)
+ channel = self.make_request(b"GET", url)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
@@ -746,8 +735,10 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Check that the HTML we're getting is the one we expect when reusing a
# token. The account expiration date should not have changed.
- expected_html = self.hs.config.account_validity_account_previously_renewed_template.render(
- expiration_ts=expiration_ts
+ expected_html = (
+ self.hs.config.account_validity_account_previously_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
)
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
@@ -757,14 +748,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# our access token should be denied from now, otherwise they should
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_renewal_invalid_token(self):
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
- request, channel = self.make_request(b"GET", url)
+ channel = self.make_request(b"GET", url)
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
@@ -782,7 +773,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.email_attempts = []
(user_id, tok) = self.create_user()
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST",
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
@@ -806,10 +797,10 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"erase": False,
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
@@ -857,7 +848,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.email_attempts = []
# Test that we're still able to manually trigger a mail to be sent.
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST",
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 6cd4eb6624..e7bb5583fc 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -39,6 +39,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# We need to enable msc1849 support for aggregations
config = self.default_config()
config["experimental_msc1849_support_enabled"] = True
+
+ # We enable frozen dicts as relations/edits change event contents, so we
+ # want to test that we don't modify the events in the caches.
+ config["use_frozen_dicts"] = True
+
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
@@ -60,7 +65,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, event_id),
access_token=self.user_token,
@@ -83,14 +88,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_deny_membership(self):
- """Test that we deny relations on membership events
- """
+ """Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
def test_deny_double_react(self):
- """Test that we deny relations on membership events
- """
+ """Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -98,8 +101,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(400, channel.code, channel.json_body)
def test_basic_paginate_relations(self):
- """Tests that calling pagination API correctly the latest relations.
- """
+ """Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
self.assertEquals(200, channel.code, channel.json_body)
@@ -107,7 +109,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
annotation_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
% (self.room, self.parent_id),
@@ -152,7 +154,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
@@ -174,8 +176,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(found_event_ids, expected_event_ids)
def test_aggregation_pagination_groups(self):
- """Test that we can paginate annotation groups correctly.
- """
+ """Test that we can paginate annotation groups correctly."""
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
@@ -210,7 +211,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
@@ -240,8 +241,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(sent_groups, found_groups)
def test_aggregation_pagination_within_group(self):
- """Test that we can paginate within an annotation group.
- """
+ """Test that we can paginate within an annotation group."""
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
@@ -279,7 +279,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s"
"/aggregations/%s/%s/m.reaction/%s?limit=1%s"
@@ -311,8 +311,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(found_event_ids, expected_event_ids)
def test_aggregation(self):
- """Test that annotations get correctly aggregated.
- """
+ """Test that annotations get correctly aggregated."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -325,7 +324,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
@@ -344,8 +343,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_aggregation_redactions(self):
- """Test that annotations get correctly aggregated after a redaction.
- """
+ """Test that annotations get correctly aggregated after a redaction."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -357,7 +355,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Now lets redact one of the 'a' reactions
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
access_token=self.user_token,
@@ -365,7 +363,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
@@ -379,10 +377,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_aggregation_must_be_annotation(self):
- """Test that aggregations must be annotations.
- """
+ """Test that aggregations must be annotations."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
% (self.room, self.parent_id, RelationTypes.REPLACE),
@@ -414,7 +411,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
reply_2 = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
@@ -437,8 +434,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_edit(self):
- """Test that a simple edit works.
- """
+ """Test that a simple edit works."""
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
@@ -450,7 +446,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
edit_event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
@@ -507,7 +503,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
@@ -527,6 +523,63 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
+ def test_edit_reply(self):
+ """Test that editing a reply works."""
+
+ # Create a reply to edit.
+ channel = self._send_relation(
+ RelationTypes.REFERENCE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "A reply!"},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply = channel.json_body["event_id"]
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ parent_id=reply,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, reply),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to see the new body in the dict, as well as the reference
+ # metadata still intact.
+ self.assertDictContainsSubset(new_body, channel.json_body["content"])
+ self.assertDictContainsSubset(
+ {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "key": None,
+ "rel_type": "m.reference",
+ }
+ },
+ channel.json_body["content"],
+ )
+
+ # We expect that the edit relation appears in the unsigned relations
+ # section.
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ )
+
def test_relations_redaction_redacts_edits(self):
"""Test that edits of an event are redacted when the original event
is redacted.
@@ -549,7 +602,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Check the relation is returned
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
% (self.room, original_event_id),
@@ -561,7 +614,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(len(channel.json_body["chunk"]), 1)
# Redact the original event
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/redact/%s/%s"
% (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
@@ -571,7 +624,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Try to check for remaining m.replace relations
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
% (self.room, original_event_id),
@@ -598,7 +651,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Redact the original
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/redact/%s/%s"
% (
@@ -612,7 +665,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Check that aggregations returns zero
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
% (self.room, original_event_id),
@@ -656,7 +709,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
original_id = parent_id if parent_id else self.parent_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 562a9c1ba4..dd83a1f8ff 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -17,6 +17,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import shared_rooms
from tests import unittest
+from tests.server import FakeChannel
class UserSharedRoomsTest(unittest.HomeserverTestCase):
@@ -40,75 +41,75 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
- def _get_shared_rooms(self, token, other_user):
- request, channel = self.make_request(
+ def _get_shared_rooms(self, token, other_user) -> FakeChannel:
+ return self.make_request(
"GET",
"/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
% other_user,
access_token=token,
)
- return request, channel
def test_shared_room_list_public(self):
"""
A room should show up in the shared list of rooms between two users
if it is public.
"""
- u1 = self.register_user("user1", "pass")
- u1_token = self.login(u1, "pass")
- u2 = self.register_user("user2", "pass")
- u2_token = self.login(u2, "pass")
-
- room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
- self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
- self.helper.join(room, user=u2, tok=u2_token)
-
- request, channel = self._get_shared_rooms(u1_token, u2)
- self.assertEquals(200, channel.code, channel.result)
- self.assertEquals(len(channel.json_body["joined"]), 1)
- self.assertEquals(channel.json_body["joined"][0], room)
+ self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
def test_shared_room_list_private(self):
"""
A room should show up in the shared list of rooms between two users
if it is private.
"""
- u1 = self.register_user("user1", "pass")
- u1_token = self.login(u1, "pass")
- u2 = self.register_user("user2", "pass")
- u2_token = self.login(u2, "pass")
-
- room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
- self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
- self.helper.join(room, user=u2, tok=u2_token)
-
- request, channel = self._get_shared_rooms(u1_token, u2)
- self.assertEquals(200, channel.code, channel.result)
- self.assertEquals(len(channel.json_body["joined"]), 1)
- self.assertEquals(channel.json_body["joined"][0], room)
+ self._check_shared_rooms_with(
+ room_one_is_public=False, room_two_is_public=False
+ )
def test_shared_room_list_mixed(self):
"""
The shared room list between two users should contain both public and private
rooms.
"""
+ self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+
+ def _check_shared_rooms_with(
+ self, room_one_is_public: bool, room_two_is_public: bool
+ ):
+ """Checks that shared public or private rooms between two users appear in
+ their shared room lists.
+ """
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
u2 = self.register_user("user2", "pass")
u2_token = self.login(u2, "pass")
- room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
- room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
- self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
- self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
- self.helper.join(room_public, user=u2, tok=u2_token)
- self.helper.join(room_private, user=u1, tok=u1_token)
+ # Create a room. user1 invites user2, who joins
+ room_id_one = self.helper.create_room_as(
+ u1, is_public=room_one_is_public, tok=u1_token
+ )
+ self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room_id_one, user=u2, tok=u2_token)
- request, channel = self._get_shared_rooms(u1_token, u2)
+ # Check shared rooms from user1's perspective.
+ # We should see the one room in common
+ channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room_id_one)
+
+ # Create another room and invite user2 to it
+ room_id_two = self.helper.create_room_as(
+ u1, is_public=room_two_is_public, tok=u1_token
+ )
+ self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room_id_two, user=u2, tok=u2_token)
+
+ # Check shared rooms again. We should now see both rooms.
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 2)
- self.assertTrue(room_public in channel.json_body["joined"])
- self.assertTrue(room_private in channel.json_body["joined"])
+ for room_id_id in channel.json_body["joined"]:
+ self.assertIn(room_id_id, [room_id_one, room_id_two])
def test_shared_room_list_after_leave(self):
"""
@@ -125,13 +126,19 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
self.helper.leave(room, user=u1, tok=u1_token)
- request, channel = self._get_shared_rooms(u2_token, u1)
+ # Check user1's view of shared rooms with user2
+ channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 0)
+
+ # Check user2's view of shared rooms with user1
+ channel = self._get_shared_rooms(u2_token, u1)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index f4438e1636..8755bfb38a 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
]
def test_sync_argless(self):
- request, channel = self.make_request("GET", "/sync")
+ channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200)
self.assertTrue(
@@ -65,7 +65,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"""
self.hs.config.use_presence = False
- request, channel = self.make_request("GET", "/sync")
+ channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200)
self.assertTrue(
@@ -204,7 +204,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
tok=tok,
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/sync?filter=%s" % sync_filter, access_token=tok
)
self.assertEqual(channel.code, 200, channel.result)
@@ -255,21 +255,19 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.helper.send(room, body="There!", tok=other_access_token)
# Start typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": true, "timeout": 30000}',
)
self.assertEquals(200, channel.code)
- request, channel = self.make_request(
- "GET", "/sync?access_token=%s" % (access_token,)
- )
+ channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
# Stop typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": false}',
@@ -277,7 +275,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
# Start typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": true, "timeout": 30000}',
@@ -285,9 +283,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
# Should return immediately
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -299,9 +295,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
# invalidate the stream token.
self.helper.send(room, body="There!", tok=other_access_token)
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -309,9 +303,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
# ahead, and therefore it's saying the typing (that we've actually
# already seen) is new, since it's got a token above our new, now-reset
# stream token.
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -357,8 +349,10 @@ class SyncKnockTestCase(
self.knocker_tok = self.login("knocker", "monkey")
# Perform an initial sync for the knocking user.
- _, channel = self.make_request(
- "GET", self.url % self.next_batch, access_token=self.tok,
+ channel = self.make_request(
+ "GET",
+ self.url % self.next_batch,
+ access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -374,7 +368,7 @@ class SyncKnockTestCase(
def test_knock_room_state(self):
"""Tests that /sync returns state from a room after knocking on it."""
# Knock on a room
- _, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/unstable/xyz.amorgan.knock/%s" % (self.room_id,),
b"{}",
@@ -389,8 +383,10 @@ class SyncKnockTestCase(
}
# Check that /sync includes stripped state from the room
- _, channel = self.make_request(
- "GET", self.url % self.next_batch, access_token=self.knocker_tok,
+ channel = self.make_request(
+ "GET",
+ self.url % self.next_batch,
+ access_token=self.knocker_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -476,7 +472,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event.
body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/rooms/%s/read_markers" % self.room_id,
body,
@@ -489,13 +485,19 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Check that room name changes increase the unread counter.
self.helper.send_state(
- self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ self.room_id,
+ "m.room.name",
+ {"name": "my super room"},
+ tok=self.tok2,
)
self._check_unread_count(1)
# Check that room topic changes increase the unread counter.
self.helper.send_state(
- self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ self.room_id,
+ "m.room.topic",
+ {"topic": "welcome!!!"},
+ tok=self.tok2,
)
self._check_unread_count(2)
@@ -505,7 +507,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Check that custom events with a body increase the unread counter.
self.helper.send_event(
- self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ self.room_id,
+ "org.matrix.custom_type",
+ {"body": "hello"},
+ tok=self.tok2,
)
self._check_unread_count(4)
@@ -543,15 +548,19 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
def _check_unread_count(self, expected_count: int):
"""Syncs and compares the unread count with the expected value."""
- request, channel = self.make_request(
- "GET", self.url % self.next_batch, access_token=self.tok,
+ channel = self.make_request(
+ "GET",
+ self.url % self.next_batch,
+ access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
room_entry = channel.json_body["rooms"]["join"][self.room_id]
self.assertEqual(
- room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ room_entry["org.matrix.msc2654.unread_count"],
+ expected_count,
+ room_entry,
)
# Store the next batch for the next request.
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py
new file mode 100644
index 0000000000..d890d11863
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_upgrade_room.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+from tests.server import FakeChannel
+
+
+class UpgradeRoomTest(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ self.creator = self.register_user("creator", "pass")
+ self.creator_token = self.login(self.creator, "pass")
+
+ self.other = self.register_user("user", "pass")
+ self.other_token = self.login(self.other, "pass")
+
+ self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token)
+ self.helper.join(self.room_id, self.other, tok=self.other_token)
+
+ def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel:
+ # We never want a cached response.
+ self.reactor.advance(5 * 60 + 1)
+
+ return self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id,
+ # This will upgrade a room to the same version, but that's fine.
+ content={"new_version": DEFAULT_ROOM_VERSION},
+ access_token=token or self.creator_token,
+ )
+
+ def test_upgrade(self):
+ """
+ Upgrading a room should work fine.
+ """
+ channel = self._upgrade_room()
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertIn("replacement_room", channel.json_body)
+
+ def test_not_in_room(self):
+ """
+ Upgrading a room should fail when the requester is not in the room.
+ """
+ # The user isn't in the room.
+ roomless = self.register_user("roomless", "pass")
+ roomless_token = self.login(roomless, "pass")
+
+ channel = self._upgrade_room(roomless_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ def test_power_levels(self):
+ """
+ Another user can upgrade the room if their power level is increased.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["users"][self.other] = 100
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def test_power_levels_user_default(self):
+ """
+ Another user can upgrade the room if the default power level for users is increased.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["users_default"] = 100
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def test_power_levels_tombstone(self):
+ """
+ Another user can upgrade the room if they can send the tombstone event.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["events"]["m.room.tombstone"] = 0
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ self.assertNotIn(self.other, power_levels["users"])
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 5e90d656f7..9d0d0ef414 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -180,7 +180,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
async def post_json(destination, path, data):
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
- path, "/_matrix/key/v2/query",
+ path,
+ "/_matrix/key/v2/query",
)
channel = FakeChannel(self.site, self.reactor)
@@ -188,7 +189,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(
- b"POST", path.encode("utf-8"), b"1.1",
+ b"POST",
+ path.encode("utf-8"),
+ b"1.1",
)
channel.await_result()
self.assertEqual(channel.code, 200)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 4c749f1a61..9f77125fd4 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.utils import default_config
class MediaStorageTests(unittest.HomeserverTestCase):
@@ -102,7 +105,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
-@attr.s
+@attr.s(slots=True, frozen=True)
class _TestImage:
"""An image for testing thumbnailing with the expected results
@@ -114,13 +117,15 @@ class _TestImage:
test should just check for success.
expected_scaled: The expected bytes from scaled thumbnailing, or None if
test should just check for a valid image returned.
+ expected_found: True if the file should exist on the server, or False if
+ a 404 is expected.
"""
data = attr.ib(type=bytes)
content_type = attr.ib(type=bytes)
extension = attr.ib(type=bytes)
- expected_cropped = attr.ib(type=Optional[bytes])
- expected_scaled = attr.ib(type=Optional[bytes])
+ expected_cropped = attr.ib(type=Optional[bytes], default=None)
+ expected_scaled = attr.ib(type=Optional[bytes], default=None)
expected_found = attr.ib(default=True, type=bool)
@@ -150,6 +155,21 @@ class _TestImage:
),
),
),
+ # small png with transparency.
+ (
+ _TestImage(
+ unhexlify(
+ b"89504e470d0a1a0a0000000d49484452000000010000000101000"
+ b"00000376ef9240000000274524e5300010194fdae0000000a4944"
+ b"4154789c636800000082008177cd72b60000000049454e44ae426"
+ b"082"
+ ),
+ b"image/png",
+ b".png",
+ # Note that we don't check the output since it varies across
+ # different versions of Pillow.
+ ),
+ ),
# small lossless webp
(
_TestImage(
@@ -159,12 +179,17 @@ class _TestImage:
),
b"image/webp",
b".webp",
- None,
- None,
),
),
# an empty file
- (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
+ (
+ _TestImage(
+ b"",
+ b"image/gif",
+ b".gif",
+ expected_found=False,
+ ),
+ ),
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
@@ -202,7 +227,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config = self.default_config()
config["media_store_path"] = self.media_store_path
- config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
@@ -220,15 +244,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
- self.media_repo = hs.get_media_repository_resource()
- self.download_resource = self.media_repo.children[b"download"]
- self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
+ media_resource = hs.get_media_repository_resource()
+ self.download_resource = media_resource.children[b"download"]
+ self.thumbnail_resource = media_resource.children[b"thumbnail"]
+ self.store = hs.get_datastore()
+ self.media_repo = hs.get_media_repository()
self.media_id = "example.com/12345"
def _req(self, content_disposition):
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
@@ -313,18 +339,103 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
+ """Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self):
+ """Test that a scaled remote thumbnail is available."""
+ self._test_thumbnail(
+ "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ )
+
+ def test_invalid_type(self):
+ """An invalid thumbnail type is never available."""
+ self._test_thumbnail("invalid", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+ )
+ def test_no_thumbnail_crop(self):
+ """
+ Override the config to generate only scaled thumbnails, but request a cropped one.
+ """
+ self._test_thumbnail("crop", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+ )
+ def test_no_thumbnail_scale(self):
+ """
+ Override the config to generate only cropped thumbnails, but request a scaled one.
+ """
+ self._test_thumbnail("scale", None, False)
+
+ def test_thumbnail_repeated_thumbnail(self):
+ """Test that fetching the same thumbnail works, and deleting the on disk
+ thumbnail regenerates it.
+ """
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
+ if not self.test_image.expected_found:
+ return
+
+ # Fetching again should work, without re-requesting the image from the
+ # remote.
+ params = "?width=32&height=32&method=scale"
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.thumbnail_resource),
+ "GET",
+ self.media_id + params,
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ if self.test_image.expected_scaled:
+ self.assertEqual(
+ channel.result["body"],
+ self.test_image.expected_scaled,
+ channel.result["body"],
+ )
+
+ # Deleting the thumbnail on disk then re-requesting it should work, as
+ # Synapse regenerates missing thumbnails.
+ origin, media_id = self.media_id.split("/")
+ info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+ file_id = info["filesystem_id"]
+
+ thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
+ origin, file_id
+ )
+ shutil.rmtree(thumbnail_dir, ignore_errors=True)
+
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.thumbnail_resource),
+ "GET",
+ self.media_id + params,
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ if self.test_image.expected_scaled:
+ self.assertEqual(
+ channel.result["body"],
+ self.test_image.expected_scaled,
+ channel.result["body"],
+ )
+
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.thumbnail_resource),
"GET",
@@ -362,3 +473,106 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"error": "Not found [b'example.com', b'12345']",
},
)
+
+ def test_x_robots_tag_header(self):
+ """
+ Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+ to not index, archive, or follow links in media.
+ """
+ channel = self._req(b"inline; filename=out" + self.test_image.extension)
+
+ headers = channel.headers
+ self.assertEqual(
+ headers.getRawHeaders(b"X-Robots-Tag"),
+ [b"noindex, nofollow, noarchive, noimageindex"],
+ )
+
+
+class TestSpamChecker:
+ """A spam checker module that rejects all media that includes the bytes
+ `evil`.
+ """
+
+ def __init__(self, config, api):
+ self.config = config
+ self.api = api
+
+ def parse_config(config):
+ return config
+
+ async def check_event_for_spam(self, foo):
+ return False # allow all events
+
+ async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ return True # allow all invites
+
+ async def user_may_create_room(self, userid):
+ return True # allow all room creations
+
+ async def user_may_create_room_alias(self, userid, room_alias):
+ return True # allow all room aliases
+
+ async def user_may_publish_room(self, userid, room_id):
+ return True # allow publishing of all rooms
+
+ async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+ buf = BytesIO()
+ await file_wrapper.write_chunks_to(buf.write)
+
+ return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+
+ def default_config(self):
+ config = default_config("test")
+
+ config.update(
+ {
+ "spam_checker": [
+ {
+ "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+ "config": {},
+ }
+ ]
+ }
+ )
+
+ return config
+
+ def test_upload_innocent(self):
+ """Attempt to upload some innocent data that should be allowed."""
+
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ self.helper.upload_media(
+ self.upload_resource, image_data, tok=self.tok, expect_code=200
+ )
+
+ def test_upload_ban(self):
+ """Attempt to upload some data that includes bytes "evil", which should
+ get rejected by the spam checker.
+ """
+
+ data = b"Some evil data"
+
+ self.helper.upload_media(
+ self.upload_resource, data, tok=self.tok, expect_code=400
+ )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index ccdc8c2ecf..6968502433 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,42 +18,23 @@ import re
from mock import patch
-import attr
-
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
from tests import unittest
from tests.server import FakeTransport
-
-@attr.s
-class FakeResponse:
- version = attr.ib()
- code = attr.ib()
- phrase = attr.ib()
- headers = attr.ib()
- body = attr.ib()
- absoluteURI = attr.ib()
-
- @property
- def request(self):
- @attr.s
- class FakeTransport:
- absoluteURI = self.absoluteURI
-
- return FakeTransport()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
+try:
+ import lxml
+except ImportError:
+ lxml = None
class URLPreviewTests(unittest.HomeserverTestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
hijack_auth = True
user_id = "@test:user"
@@ -139,7 +120,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def test_cache_returns_correct_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
@@ -164,7 +145,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
# Check the cache returns the correct response
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://matrix.org", shorthand=False
)
@@ -180,7 +161,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertNotIn("http://matrix.org", self.preview_url._cache)
# Check the database cache returns the correct response
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://matrix.org", shorthand=False
)
@@ -201,7 +182,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
@@ -236,7 +217,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
@@ -271,7 +252,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
@@ -304,7 +285,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://example.com",
shorthand=False,
@@ -334,7 +315,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
@@ -355,7 +336,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
@@ -372,7 +353,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Blacklisted IP addresses, accessed directly, are not spidered.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://192.168.1.1", shorthand=False
)
@@ -391,7 +372,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Blacklisted IP ranges, accessed directly, are not spidered.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://1.1.1.2", shorthand=False
)
@@ -411,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://example.com",
shorthand=False,
@@ -448,7 +429,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
(IPv4Address, "10.1.2.3"),
]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
self.assertEqual(channel.code, 502)
@@ -468,7 +449,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
@@ -489,7 +470,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
@@ -506,7 +487,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
OPTIONS returns the OPTIONS.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"OPTIONS", "preview_url?url=http://example.com", shorthand=False
)
self.assertEqual(channel.code, 200)
@@ -519,7 +500,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
# Build and make a request to the server
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://example.com",
shorthand=False,
@@ -593,7 +574,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
@@ -658,7 +639,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}
end_content = json.dumps(result).encode("utf-8")
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 02a46e5fda..32acd93dc1 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -25,7 +25,7 @@ class HealthCheckTests(unittest.HomeserverTestCase):
return HealthResource()
def test_health(self):
- request, channel = self.make_request("GET", "/health", shorthand=False)
+ channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 6a930f4148..14de0921be 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -28,11 +28,11 @@ class WellKnownTests(unittest.HomeserverTestCase):
self.hs.config.public_baseurl = "https://tesths"
self.hs.config.default_identity_server = "https://testis"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{
@@ -44,8 +44,8 @@ class WellKnownTests(unittest.HomeserverTestCase):
def test_well_known_no_public_baseurl(self):
self.hs.config.public_baseurl = None
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(request.code, 404)
+ self.assertEqual(channel.code, 404)
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
index 0937f304cf..3c70a9c8c0 100644
--- a/tests/rulecheck/test_domainrulecheck.py
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -22,7 +22,6 @@ from synapse.rest.client.v1 import login, room
from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
from tests import unittest
-from tests.server import make_request
class DomainRuleCheckerTestCase(unittest.TestCase):
@@ -285,8 +284,7 @@ class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
)
def test_cannot_3pid_invite(self):
- """Test that unbound 3pid invites get rejected.
- """
+ """Test that unbound 3pid invites get rejected."""
channel = self._create_room(self.admin_access_token)
assert channel.result["code"] == b"200", channel.result
@@ -311,7 +309,7 @@ class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
expect_code=403,
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"rooms/%s/invite" % (room_id),
{"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
@@ -322,9 +320,7 @@ class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
def _create_room(self, token, content={}):
path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
- request, channel = make_request(
- self.hs.get_reactor(),
- self.site,
+ channel = self.make_request(
"POST",
path,
content=json.dumps(content).encode("utf8"),
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..b535a5d886 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
import attr
from typing_extensions import Deque
@@ -13,9 +13,13 @@ from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
+ IHostnameResolver,
+ IProtocol,
+ IPullProducer,
+ IPushProducer,
IReactorPluggableNameResolver,
- IReactorTCP,
IResolverSimple,
+ ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@@ -44,16 +48,29 @@ class FakeChannel:
wire).
"""
- site = attr.ib(type=Site)
+ site = attr.ib(type=Union[Site, "FakeSite"])
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
- _producer = None
+ _ip = attr.ib(type=str, default="127.0.0.1")
+ _producer = None # type: Optional[Union[IPullProducer, IPushProducer]]
@property
def json_body(self):
- if not self.result:
- raise Exception("No result yet.")
- return json.loads(self.result["body"].decode("utf8"))
+ return json.loads(self.text_body)
+
+ @property
+ def text_body(self) -> str:
+ """The body of the result, utf-8-decoded.
+
+ Raises an exception if the request has not yet completed.
+ """
+ if not self.is_finished():
+ raise Exception("Request not yet completed")
+ return self.result["body"].decode("utf8")
+
+ def is_finished(self) -> bool:
+ """check if the response has been completely received"""
+ return self.result.get("done", False)
@property
def code(self):
@@ -62,7 +79,7 @@ class FakeChannel:
return int(self.result["code"])
@property
- def headers(self):
+ def headers(self) -> Headers:
if not self.result:
raise Exception("No result yet.")
h = Headers()
@@ -108,10 +125,14 @@ class FakeChannel:
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
- return address.IPv4Address("TCP", "127.0.0.1", 3423)
+ return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self):
- return None
+ # this is called by Request.__init__ to configure Request.host.
+ return address.IPv4Address("TCP", "127.0.0.1", 8888)
+
+ def isSecure(self):
+ return False
@property
def transport(self):
@@ -124,7 +145,7 @@ class FakeChannel:
self._reactor.run()
x = 0
- while not self.result.get("done"):
+ while not self.is_finished():
# If there's a producer, tell it to resume producing so we get content
if self._producer:
self._producer.resumeProducing()
@@ -136,6 +157,20 @@ class FakeChannel:
self._reactor.advance(0.1)
+ def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
+ """Process the contents of any Set-Cookie headers in the response
+
+ Any cookies found are added to the given dict
+ """
+ headers = self.headers.getRawHeaders("Set-Cookie")
+ if not headers:
+ return
+
+ for h in headers:
+ parts = h.split(";")
+ k, v = parts[0].split("=", maxsplit=1)
+ cookies[k] = v
+
class FakeSite:
"""
@@ -161,7 +196,7 @@ class FakeSite:
def make_request(
reactor,
- site: Site,
+ site: Union[Site, FakeSite],
method,
path,
content=b"",
@@ -174,11 +209,12 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
-):
+ client_ip: str = "127.0.0.1",
+) -> FakeChannel:
"""
Make a web request using the given method, path and content, and render it
- Returns the Request and the Channel underneath.
+ Returns the fake Channel object which records the response to the request.
Args:
site: The twisted Site to use to render the request
@@ -201,8 +237,11 @@ def make_request(
will pump the reactor until the renderer tells the channel the request
is finished.
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
- Tuple[synapse.http.site.SynapseRequest, channel]
+ channel
"""
if not isinstance(method, bytes):
method = method.encode("ascii")
@@ -216,8 +255,9 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
+ if path.startswith(b"/"):
+ path = path[1:]
path = b"/_matrix/client/r0/" + path
- path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path
@@ -227,7 +267,7 @@ def make_request(
if isinstance(content, str):
content = content.encode("utf8")
- channel = FakeChannel(site, reactor)
+ channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel)
req.content = BytesIO(content)
@@ -258,12 +298,13 @@ def make_request(
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
+ req.parseCookies()
req.requestReceived(method, path, b"1.1")
if await_result:
channel.await_result()
- return req, channel
+ return channel
@implementer(IReactorPluggableNameResolver)
@@ -277,8 +318,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks = {}
self._udp = []
- lookups = self.lookups = {}
- self._thread_callbacks = deque() # type: Deque[Callable[[], None]]()
+ lookups = self.lookups = {} # type: Dict[str, str]
+ self._thread_callbacks = deque() # type: Deque[Callable[[], None]]
@implementer(IResolverSimple)
class FakeResolver:
@@ -290,6 +331,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__()
+ def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
+ raise NotImplementedError()
+
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
@@ -318,8 +362,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks[(host, port)] = callback
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
- """Fake L{IReactorTCP.connectTCP}.
- """
+ """Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
host, port, factory, timeout=timeout, bindAddress=None
@@ -435,6 +478,7 @@ def get_clock():
return clock, hs_clock
+@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""
@@ -559,7 +603,7 @@ class FakeTransport:
if self.disconnected:
return
- if getattr(self.other, "transport") is None:
+ if not hasattr(self.other, "transport"):
# the other has no transport yet; reschedule
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
@@ -587,7 +631,9 @@ class FakeTransport:
self.disconnected = True
-def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
+def connect_client(
+ reactor: ThreadedMemoryReactorClock, client_id: int
+) -> Tuple[IProtocol, AccumulatingProtocol]:
"""
Connect a client to a fake TCP transport.
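
A minimal usage sketch of the reworked helper, assuming a `HomeserverTestCase` whose `make_request` wraps the function above (as the call sites elsewhere in this diff do); the endpoint and assertions mirror `tests/rest/test_health.py` and are illustrative only:

    def test_health_example(self):
        # make_request now returns only the FakeChannel; the underlying
        # SynapseRequest is no longer part of the return value.
        channel = self.make_request("GET", "/health", shorthand=False)
        self.assertEqual(channel.code, 200)

        # The body can be read as text (or as JSON via channel.json_body).
        self.assertEqual(channel.text_body, "OK")

        # Any Set-Cookie headers can be collected into a plain dict.
        cookies = {}  # type: dict
        channel.extract_cookies(cookies)
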
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index e0a9cd93ac..4dd5a36178 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -70,7 +70,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
the notice URL + an authentication code.
"""
# Initial sync, to get the user consent room invite
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=self.access_token
)
self.assertEqual(channel.code, 200)
@@ -79,7 +79,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
room_id = list(channel.json_body["rooms"]["invite"].keys())[0]
# Join the room
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/" + room_id + "/join",
access_token=self.access_token,
@@ -87,7 +87,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
# Sync again, to get the message in the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=self.access_token
)
self.assertEqual(channel.code, 200)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 9c8027a5b2..d40d65b06a 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -305,7 +305,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.register_user("user", "password")
tok = self.login("user", "password")
- request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
invites = channel.json_body["rooms"]["invite"]
self.assertEqual(len(invites), 0, invites)
@@ -318,7 +318,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
# Sync again to retrieve the events in the room, so we can check whether this
# room has a notice in it.
- request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
# Scan the events in the room to search for a message from the server notices
# user.
@@ -353,8 +353,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
tok = self.login(localpart, "password")
# Sync with the user's token to mark the user as active.
- request, channel = self.make_request(
- "GET", "/sync?timeout=0", access_token=tok,
+ channel = self.make_request(
+ "GET",
+ "/sync?timeout=0",
+ access_token=tok,
)
# Also retrieves the list of invites for this user. We don't care about that
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..66e3cafe8e 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+ _get_auth_chain_difference,
+ lexicographical_topological_sort,
+ resolve_events_with_store,
+)
from synapse.types import EventID
from tests import unittest
@@ -84,7 +88,7 @@ class FakeEvent:
event_dict = {
"auth_events": [(a, {}) for a in auth_events],
"prev_events": [(p, {}) for p in prev_events],
- "event_id": self.node_id,
+ "event_id": self.event_id,
"sender": self.sender,
"type": self.type,
"content": self.content,
@@ -377,6 +381,60 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
+ def test_mainline_sort(self):
+ """Tests that the mainline ordering works correctly."""
+
+ events = [
+ FakeEvent(
+ id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100, BOB: 50}},
+ ),
+ FakeEvent(
+ id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {ALICE: 100, BOB: 50},
+ "events": {EventTypes.PowerLevels: 100},
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100, BOB: 50}},
+ ),
+ FakeEvent(
+ id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ ]
+
+ edges = [
+ ["END", "T3", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "T4", "PB", "PA1"],
+ ]
+
+ # We expect T3 to be picked as the other topics are pointing at older
+ # power levels. Note that without mainline ordering we'd pick T4 due to
+ # it being sent *after* T3.
+ expected_state_ids = ["T3", "PA2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
def do_check(self, events, edges, expected_state_ids):
"""Take a list of events and edges and calculate the state of the
graph at END, and asserts it matches `expected_state_ids`
@@ -587,6 +645,182 @@ class SimpleParamStateTestCase(unittest.TestCase):
self.assert_dict(self.expected_combined_state, state)
+class AuthChainDifferenceTestCase(unittest.TestCase):
+ """We test that `_get_auth_chain_difference` correctly handles unpersisted
+ events.
+ """
+
+ def test_simple(self):
+ # Test getting the auth difference for a simple chain with a single
+ # unpersisted event:
+ #
+ # Unpersisted | Persisted
+ # |
+ # C -|-> B -> A
+
+ a = FakeEvent(
+ id="A",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersisted_events = {c.event_id: c}
+
+ state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersisted_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {c.event_id})
+
+ def test_multiple_unpersisted_chain(self):
+ # Test getting the auth difference for a simple chain with multiple
+ # unpersisted events:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D -> C -|-> B -> A
+
+ a = FakeEvent(
+ id="A",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([c.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersisted_events = {c.event_id: c, d.event_id: d}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersisted_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, c.event_id})
+
+ def test_unpersisted_events_different_sets(self):
+ # Test getting the auth difference with multiple unpersisted events
+ # in different branches:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D --> C -|-> B -> A
+ # E ----^ -|---^
+ # |
+
+ a = FakeEvent(
+ id="A",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([c.event_id], [])
+
+ e = FakeEvent(
+ id="E",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
+ ).to_event([c.event_id, b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersisted_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersisted_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, e.event_id})
+
+
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
@@ -647,7 +881,7 @@ class TestStateResolutionStore:
return list(result)
- def get_auth_chain_difference(self, auth_sets):
+ def get_auth_chain_difference(self, room_id, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
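
The store-level `get_auth_chain_difference` above reduces the auth chain difference to set arithmetic: the union of each state set's recursively expanded auth chain minus their common intersection. A self-contained sketch of that calculation over a toy graph, independent of the Synapse store APIs (names are illustrative):

    def auth_chain(graph, event_id):
        """Recursively collect every auth ancestor of an event."""
        ancestors = set()
        for parent in graph.get(event_id, []):
            ancestors.add(parent)
            ancestors |= auth_chain(graph, parent)
        return ancestors

    def auth_chain_difference(graph, state_sets):
        """Union of the per-state-set auth chains minus their intersection."""
        chains = [
            set().union(*({e} | auth_chain(graph, e) for e in state_set))
            for state_set in state_sets
        ]
        return set.union(*chains) - set.intersection(*chains)

    # C -> B -> A: only C falls outside the common chain, as in test_simple above.
    graph = {"A": [], "B": ["A"], "C": ["B"]}
    assert auth_chain_difference(graph, [{"A", "B"}, {"C"}]) == {"C"}
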
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
new file mode 100644
index 0000000000..38444e48e2
--- /dev/null
+++ b/tests/storage/test_account_data.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Set
+
+from synapse.api.constants import AccountDataTypes
+
+from tests import unittest
+
+
+class IgnoredUsersTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = self.hs.get_datastore()
+ self.user = "@user:test"
+
+ def _update_ignore_list(
+ self, *ignored_user_ids: str, ignorer_user_id: Optional[str] = None
+ ) -> None:
+ """Update the account data to block the given users."""
+ if ignorer_user_id is None:
+ ignorer_user_id = self.user
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ ignorer_user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {u: {} for u in ignored_user_ids}},
+ )
+ )
+
+ def assert_ignorers(
+ self, ignored_user_id: str, expected_ignorer_user_ids: Set[str]
+ ) -> None:
+ self.assertEqual(
+ self.get_success(self.store.ignored_by(ignored_user_id)),
+ expected_ignorer_user_ids,
+ )
+
+ def test_ignoring_users(self):
+ """Basic adding/removing of users from the ignore list."""
+ self._update_ignore_list("@other:test", "@another:remote")
+
+ # Check a user which no one ignores.
+ self.assert_ignorers("@user:test", set())
+
+ # Check a local user which is ignored.
+ self.assert_ignorers("@other:test", {self.user})
+
+ # Check a remote user which is ignored.
+ self.assert_ignorers("@another:remote", {self.user})
+
+ # Add one user, remove one user, and leave one user.
+ self._update_ignore_list("@foo:test", "@another:remote")
+
+ # Check the removed user.
+ self.assert_ignorers("@other:test", set())
+
+ # Check the added user.
+ self.assert_ignorers("@foo:test", {self.user})
+
+ # Check the unchanged user.
+ self.assert_ignorers("@another:remote", {self.user})
+
+ def test_caching(self):
+ """Ensure that caching works properly between different users."""
+ # The first user ignores a user.
+ self._update_ignore_list("@other:test")
+ self.assert_ignorers("@other:test", {self.user})
+
+ # The second user ignores them.
+ self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+ self.assert_ignorers("@other:test", {self.user, "@second:test"})
+
+ # The first user un-ignores them.
+ self._update_ignore_list()
+ self.assert_ignorers("@other:test", {"@second:test"})
+
+ def test_invalid_data(self):
+ """Invalid data ends up clearing out the ignored users list."""
+ # Add some data and ensure it is there.
+ self._update_ignore_list("@other:test")
+ self.assert_ignorers("@other:test", {self.user})
+
+ # No ignored_users key.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {},
+ )
+ )
+
+ # No one ignores the user now.
+ self.assert_ignorers("@other:test", set())
+
+ # Add some data and ensure it is there.
+ self._update_ignore_list("@other:test")
+ self.assert_ignorers("@other:test", {self.user})
+
+ # Invalid data.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": "unexpected"},
+ )
+ )
+
+ # No one ignores the user now.
+ self.assert_ignorers("@other:test", set())
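
The new test above targets the `m.ignored_user_list` account data event, whose content maps each ignored user ID to an (empty) options dict. A rough sketch of the reverse lookup the tests expect from `ignored_by`, written against a plain dict rather than the real store (names and storage shape are illustrative assumptions):

    def ignored_by(account_data, ignored_user_id):
        """Return the users whose ignore list contains `ignored_user_id`."""
        ignorers = set()
        for ignorer, content in account_data.items():
            ignored = content.get("ignored_users")
            # Malformed content (e.g. a string) counts as an empty ignore list,
            # matching the behaviour asserted in test_invalid_data.
            if isinstance(ignored, dict) and ignored_user_id in ignored:
                ignorers.add(ignorer)
        return ignorers

    # One entry of account data per ignoring user.
    data = {"@user:test": {"ignored_users": {"@other:test": {}}}}
    assert ignored_by(data, "@other:test") == {"@user:test"}
    assert ignored_by(data, "@user:test") == set()
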
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 02aae1c13d..1b4fae0bb5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -67,7 +67,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
async def update(progress, count):
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
- count, target_background_update_duration_ms / duration_ms, places=0,
+ count,
+ target_background_update_duration_ms / duration_ms,
+ places=0,
)
await self.updates._end_background_update("test_update")
return count
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index c13a57dad1..7791138688 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -43,8 +43,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.room_id = info["room_id"]
def run_background_update(self):
- """Re run the background update to clean up the extremities.
- """
+ """Re run the background update to clean up the extremities."""
# Make sure we don't clash with in progress updates.
self.assertTrue(
self.store.db_pool.updates._all_done, "Background updates are still ongoing"
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index a69117c5a9..34e6526097 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -41,7 +41,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
device_id = "MY_DEVICE"
# Insert a user IP
- self.get_success(self.store.store_device(user_id, device_id, "display name",))
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", device_id
@@ -214,7 +220,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
device_id = "MY_DEVICE"
# Insert a user IP
- self.get_success(self.store.store_device(user_id, device_id, "display name",))
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", device_id
@@ -303,7 +315,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
device_id = "MY_DEVICE"
# Insert a user IP
- self.get_success(self.store.store_device(user_id, device_id, "display name",))
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", device_id
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ecb00f4e02..dabc1c5f09 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -80,6 +80,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_count_devices_by_users(self):
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device1", "display_name 1")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device2", "display_name 2")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id2", "device3", "display_name 3")
+ )
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ self.assertEqual(0, res)
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ self.assertEqual(0, res)
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ self.assertEqual(2, res)
+
+ res = yield defer.ensureDeferred(
+ self.store.count_devices_by_users(["user_id", "user_id2"])
+ )
+ self.assertEqual(3, res)
+
+ @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..16daa66cc9
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,746 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Set, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self._next_stream_ordering = 1
+
+ def test_simple(self):
+ """Test that the example in `docs/auth_chain_difference_algorithm.md`
+ works.
+ """
+
+ event_factory = self.hs.get_event_builder_factory()
+ bob = "@creator:test"
+ alice = "@alice:test"
+ room_id = "!room:test"
+
+ # Ensure that we have a rooms entry so that we generate the chain index.
+ self.get_success(
+ self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=True,
+ room_version=RoomVersions.V6,
+ )
+ )
+
+ create = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "create"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[])
+ )
+
+ bob_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+ )
+
+ power = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id],
+ )
+ )
+
+ alice_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "alice_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+ )
+ )
+
+ power_2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power_2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ bob_join_2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join_2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[
+ create.event_id,
+ alice_join.event_id,
+ power_2.event_id,
+ ],
+ )
+ )
+
+ events = [
+ create,
+ bob_join,
+ power,
+ alice_invite,
+ alice_join,
+ bob_join_2,
+ power_2,
+ alice_join2,
+ ]
+
+ expected_links = [
+ (bob_join, create),
+ (power, create),
+ (power, bob_join),
+ (alice_invite, create),
+ (alice_invite, power),
+ (alice_invite, bob_join),
+ (bob_join_2, power),
+ (alice_join2, power_2),
+ ]
+
+ self.persist(events)
+ chain_map, link_map = self.fetch_chains(events)
+
+ # Check that the expected links and only the expected links have been
+ # added.
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
+
+ # Test that everything can reach the create event, but the create event
+ # can't reach anything.
+ for event in events[1:]:
+ self.assertTrue(
+ link_map.exists_path_from(
+ chain_map[event.event_id], chain_map[create.event_id]
+ ),
+ )
+
+ self.assertFalse(
+ link_map.exists_path_from(
+ chain_map[create.event_id],
+ chain_map[event.event_id],
+ ),
+ )
+
+ def test_out_of_order_events(self):
+ """Test that we handle persisting events that we don't have the full
+ auth chain for yet (which should only happen for out of band memberships).
+ """
+ event_factory = self.hs.get_event_builder_factory()
+ bob = "@creator:test"
+ alice = "@alice:test"
+ room_id = "!room:test"
+
+ # Ensure that we have a rooms entry so that we generate the chain index.
+ self.get_success(
+ self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=True,
+ room_version=RoomVersions.V6,
+ )
+ )
+
+ # First persist the base room.
+ create = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "create"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[])
+ )
+
+ bob_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+ )
+
+ power = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id],
+ )
+ )
+
+ self.persist([create, bob_join, power])
+
+ # Now persist an invite and a couple of memberships out of order.
+ alice_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "alice_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+ )
+ )
+
+ alice_join2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+ )
+ )
+
+ self.persist([alice_join])
+ self.persist([alice_join2])
+ self.persist([alice_invite])
+
+ # The end result should be sane.
+ events = [create, bob_join, power, alice_invite, alice_join]
+
+ chain_map, link_map = self.fetch_chains(events)
+
+ expected_links = [
+ (bob_join, create),
+ (power, create),
+ (power, bob_join),
+ (alice_invite, create),
+ (alice_invite, power),
+ (alice_invite, bob_join),
+ ]
+
+ # Check that the expected links and only the expected links have been
+ # added.
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
+
+ def persist(
+ self,
+ events: List[EventBase],
+ ):
+ """Persist the given events and check that the links generated match
+ those given.
+ """
+
+ persist_events_store = self.hs.get_datastores().persist_events
+
+ for e in events:
+ e.internal_metadata.stream_ordering = self._next_stream_ordering
+ self._next_stream_ordering += 1
+
+ def _persist(txn):
+ # We need to persist the events to the events and state_events
+ # tables.
+ persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+ # Actually call the function that calculates the auth chain stuff.
+ persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+ self.get_success(
+ persist_events_store.db_pool.runInteraction(
+ "_persist",
+ _persist,
+ )
+ )
+
+ def fetch_chains(
+ self, events: List[EventBase]
+ ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+ # Fetch the map from event ID -> (chain ID, sequence number)
+ rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chains",
+ column="event_id",
+ iterable=[e.event_id for e in events],
+ retcols=("event_id", "chain_id", "sequence_number"),
+ keyvalues={},
+ )
+ )
+
+ chain_map = {
+ row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+ }
+
+ # Fetch all the links and pass them to the _LinkMap.
+ rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable=[chain_id for chain_id, _ in chain_map.values()],
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ keyvalues={},
+ )
+ )
+
+ link_map = _LinkMap()
+ for row in rows:
+ added = link_map.add_link(
+ (row["origin_chain_id"], row["origin_sequence_number"]),
+ (row["target_chain_id"], row["target_sequence_number"]),
+ )
+
+ # We shouldn't have persisted any redundant links
+ self.assertTrue(added)
+
+ return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+ def test_simple(self):
+ """Basic tests for the LinkMap."""
+ link_map = _LinkMap()
+
+ link_map.add_link((1, 1), (2, 1), new=False)
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+ self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+ self.assertCountEqual(link_map.get_additions(), [])
+ self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+ self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+ # Attempting to add a redundant link is ignored.
+ self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+ # Adding new non-redundant links works
+ self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+ self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+ self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+ self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.user_id = self.register_user("foo", "pass")
+ self.token = self.login("foo", "pass")
+ self.requester = create_requester(self.user_id)
+
+ def _generate_room(self) -> Tuple[str, List[Set[str]]]:
+ """Insert a room without a chain cover index."""
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Mark the room as not having a chain cover index
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": False},
+ desc="test",
+ )
+ )
+
+ # Create a fork in the DAG with different events.
+ event_handler = self.hs.get_event_creation_handler()
+ latest_event_ids = self.get_success(
+ self.store.get_prev_events_for_room(room_id)
+ )
+ event, context = self.get_success(
+ event_handler.create_event(
+ self.requester,
+ {
+ "type": "some_state_type",
+ "state_key": "",
+ "content": {},
+ "room_id": room_id,
+ "sender": self.user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(self.requester, event, context)
+ )
+ state1 = set(self.get_success(context.get_current_state_ids()).values())
+
+ event, context = self.get_success(
+ event_handler.create_event(
+ self.requester,
+ {
+ "type": "some_state_type",
+ "state_key": "",
+ "content": {},
+ "room_id": room_id,
+ "sender": self.user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(self.requester, event, context)
+ )
+ state2 = set(self.get_success(context.get_current_state_ids()).values())
+
+ # Delete the chain cover info.
+
+ def _delete_tables(txn):
+ txn.execute("DELETE FROM event_auth_chains")
+ txn.execute("DELETE FROM event_auth_chain_links")
+
+ self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+ return room_id, [state1, state2]
+
+ def test_background_update_single_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create a room
+ room_id, states = self._generate_room()
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test",
+ self.store._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ states,
+ )
+ )
+
+ def test_background_update_multiple_rooms(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+ # Create a room
+ room_id1, states1 = self._generate_room()
+ room_id2, states2 = self._generate_room()
+ room_id3, states3 = self._generate_room()
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test",
+ self.store._get_auth_chain_difference_using_cover_index_txn,
+ room_id1,
+ states1,
+ )
+ )
+
+ def test_background_update_single_large_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create a room
+ room_id, states = self._generate_room()
+
+ # Add a bunch of state so that it takes multiple iterations of the
+ # background update to process the room.
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ iterations = 0
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ iterations += 1
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Ensure that we did actually take multiple iterations to process the
+ # room.
+ self.assertGreater(iterations, 1)
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test",
+ self.store._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ states,
+ )
+ )
+
+ def test_background_update_multiple_large_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create the rooms
+ room_id1, _ = self._generate_room()
+ room_id2, _ = self._generate_room()
+
+ # Add a bunch of state so that it takes multiple iterations of the
+ # background update to process the room.
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id1, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id2, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ iterations = 0
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ iterations += 1
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Ensure that we did actually take multiple iterations to process the
+ # room.
+ self.assertGreater(iterations, 1)
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
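
The `_LinkMap` assertions in `LinkMapTestCase` above follow the chain cover reachability rule: within a chain, an event can reach everything with a lower sequence number, and a stored cross-chain link (origin_seq, target_seq) extends that to any origin event at or above origin_seq and any target event at or below target_seq. A simplified, single-hop sketch of that rule, not the real `_LinkMap` implementation:

    def exists_path_from(links, src, target):
        """links maps (src_chain, target_chain) -> list of (origin_seq, target_seq)."""
        src_chain, src_seq = src
        target_chain, target_seq = target
        if src_chain == target_chain:
            # Within a chain, higher sequence numbers reach lower ones.
            return target_seq <= src_seq
        return any(
            origin_seq <= src_seq and target_seq <= linked_seq
            for origin_seq, linked_seq in links.get((src_chain, target_chain), [])
        )

    links = {(1, 2): [(1, 1)]}  # mirrors link_map.add_link((1, 1), (2, 1))
    assert exists_path_from(links, (1, 5), (2, 1))      # reachable via the link
    assert not exists_path_from(links, (1, 5), (2, 2))  # target is newer than the link
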
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..d597d712d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
import tests.unittest
import tests.utils
@@ -113,7 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- def test_auth_difference(self):
+ def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -159,77 +164,361 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
+ # Mark the room as maybe having a cover index.
+
+ def store_room(txn):
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ "has_auth_chain_index": use_chain_cover_index,
+ },
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn, event_id, stream_ordering):
+ def insert_event(txn):
+ stream_ordering = 0
+
+ for event_id in auth_graph:
+ stream_ordering += 1
+ depth = depth_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ txn,
+ [
+ FakeEvent(event_id, room_id, auth_graph[event_id])
+ for event_id in auth_graph
+ ],
+ )
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "insert",
+ insert_event,
+ )
+ )
+
+ return room_id
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
+ # a and b have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["a", "b"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ # d and e have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+ self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+ self.assertEqual(auth_chain_ids, ["k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+ self.assertEqual(auth_chain_ids, ["j"])
+
+ # j and k have no parents.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+ self.assertEqual(auth_chain_ids, [])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+ self.assertEqual(auth_chain_ids, [])
+
+ # More complex input sequences.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["h", "i"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+ # e gets returned even though include_given is false, because it is in
+ # the auth chain of b.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "e"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ # Test include_given.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+ )
+ self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
+ # Now actually test that various combinations give the right result:
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
- depth = depth_map[event_id]
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}])
+ )
+ self.assertSetEqual(difference, set())
+
+ def test_auth_difference_partial_cover(self):
+ """Test that we correctly handle rooms where not all events have a chain
+ cover calculated. This can happen in some obscure edge cases, including
+ during the background update that calculates the chain cover for old
+ rooms.
+ """
+
+ room_id = "@ROOM:local"
+ # The silly auth graph we use to test the auth difference algorithm,
+ # where the top are the most recent events.
+ #
+ # A B
+ # \ /
+ # D E
+ # \ |
+ # ` F C
+ # | /|
+ # G ´ |
+ # | \ |
+ # H I
+ # | |
+ # K J
+
+ auth_graph = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+ }
+
+ depth_map = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+ }
+
+ # We rudely fiddle with the appropriate tables directly, as that's much
+ # easier than constructing events properly.
+
+ def insert_event(txn):
+ # First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
- table="events",
- values={
- "event_id": event_id,
+ "rooms",
+ {
"room_id": room_id,
- "depth": depth,
- "topological_ordering": depth,
- "type": "m.test",
- "processed": True,
- "outlier": False,
- "stream_ordering": stream_ordering,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ "has_auth_chain_index": True,
},
)
- self.store.db_pool.simple_insert_many_txn(
+ stream_ordering = 0
+
+ for event_id in auth_graph:
+ stream_ordering += 1
+ depth = depth_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ # Insert all events apart from 'B'
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
- table="event_auth",
- values=[
- {"event_id": event_id, "room_id": room_id, "auth_id": a}
- for a in auth_graph[event_id]
+ [
+ FakeEvent(event_id, room_id, auth_graph[event_id])
+ for event_id in auth_graph
+ if event_id != "b"
],
)
- next_stream_ordering = 0
- for event_id in auth_graph:
- next_stream_ordering += 1
- self.get_success(
- self.store.db_pool.runInteraction(
- "insert", insert_event, event_id, next_stream_ordering
- )
+ # Now we insert the event 'B' without a chain cover, by temporarily
+ # pretending the room doesn't have a chain cover.
+
+ self.store.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": False},
+ )
+
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ txn,
+ [FakeEvent("b", room_id, auth_graph["b"])],
)
+ self.store.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": True},
+ )
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "insert",
+ insert_event,
+ )
+ )
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
)
self.assertSetEqual(difference, {"a", "b"})
- difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}])
+ )
self.assertSetEqual(difference, set())
+
+
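+# A minimal stand-in for an event, providing only the attributes that the
+# auth chain persistence code reads.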
+@attr.s
+class FakeEvent:
+ event_id = attr.ib()
+ room_id = attr.ib()
+ auth_events = attr.ib()
+
+ type = "foo"
+ state_key = "foo"
+
+ internal_metadata = _EventInternalMetadata({})
+
+ def auth_event_ids(self):
+ return self.auth_events
+
+ def is_state(self):
+ return True
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index c0595963dd..485f1ee033 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -84,7 +84,9 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}, False,
+ event.event_id,
+ {user_id: action},
+ False,
)
)
yield defer.ensureDeferred(
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
new file mode 100644
index 0000000000..ed898b8dbb
--- /dev/null
+++ b/tests/storage/test_events.py
@@ -0,0 +1,332 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class ExtremPruneTestCase(HomeserverTestCase):
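+    """Tests that old forward extremities are pruned (or kept) correctly when
+    new events arrive after a gap in the DAG.
+    """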
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.state = self.hs.get_state_handler()
+ self.persistence = self.hs.get_storage().persistence
+ self.store = self.hs.get_datastore()
+
+ self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ "user", room_version=RoomVersions.V6.identifier, tok=self.token
+ )
+
+ body = self.helper.send(self.room_id, body="Test", tok=self.token)
+ local_message_event_id = body["event_id"]
+
+ # Fudge a remote event and persist it. This will be the extremity before
+ # the gap.
+ self.remote_event_1 = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "state_key": "@user:other",
+ "content": {},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 5,
+ "prev_events": [local_message_event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ self.persist_event(self.remote_event_1)
+
+        # Check that the current extremity is the remote event.
+ self.assert_extremities([self.remote_event_1.event_id])
+
+ def persist_event(self, event, state=None):
+ """Persist the event, with optional state"""
+ context = self.get_success(
+ self.state.compute_event_context(event, old_state=state)
+ )
+ self.get_success(self.persistence.persist_event(event, context))
+
+ def assert_extremities(self, expected_extremities):
+ """Assert the current extremities for the room"""
+ extremities = self.get_success(
+ self.store.get_prev_events_for_room(self.room_id)
+ )
+ self.assertCountEqual(extremities, expected_extremities)
+
+ def test_prune_gap(self):
+ """Test that we drop extremities after a gap when we see an event from
+ the same domain.
+ """
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 50,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+ def test_do_not_prune_gap_if_state_different(self):
+ """Test that we don't prune extremities after a gap if the resolved
+ state is different.
+ """
+
+ # Fudge a second event which points to an event we don't have.
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "state_key": "@user:other",
+ "content": {},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 10,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ # Now we persist it with state with a dropped history visibility
+ # setting. The state resolution across the old and new event will then
+ # include it, and so the resolved state won't match the new state.
+ state_before_gap = dict(
+ self.get_success(self.state.get_current_state(self.room_id))
+ )
+ state_before_gap.pop(("m.room.history_visibility", ""))
+
+ context = self.get_success(
+ self.state.compute_event_context(
+ remote_event_2, old_state=state_before_gap.values()
+ )
+ )
+
+ self.get_success(self.persistence.persist_event(remote_event_2, context))
+
+ # Check that we haven't dropped the old extremity.
+ self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+ def test_prune_gap_if_old(self):
+ """Test that we drop extremities after a gap when the previous extremity
+ is "old"
+ """
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+ def test_do_not_prune_gap_if_other_server(self):
+ """Test that we do not drop extremities after a gap when we see an event
+ from a different domain.
+ """
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check that we haven't dropped the old extremity.
+ self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+ def test_prune_gap_if_dummy_remote(self):
+ """Test that we drop extremities after a gap when the previous extremity
+ is a local dummy event and only points to remote events.
+ """
+
+ body = self.helper.send_event(
+ self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+ )
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+    def test_do_not_prune_gap_if_dummy_local(self):
+ """Test that we don't drop extremities after a gap when the previous
+ extremity is a local dummy event and points to local events.
+ """
+
+ body = self.helper.send(self.room_id, body="Test", tok=self.token)
+
+ body = self.helper.send_event(
+ self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+ )
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check that we haven't dropped the old local extremity.
+ self.assert_extremities([remote_event_2.event_id, local_message_event_id])
+
+ def test_do_not_prune_gap_if_not_dummy(self):
+ """Test that we do not drop extremities after a gap when the previous extremity
+ is not a dummy event.
+ """
+
+ body = self.helper.send(self.room_id, body="test", tok=self.token)
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check that we haven't dropped the old extremity.
+ self.assert_extremities([local_message_event_id, remote_event_2.event_id])
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index cc0612cf65..aad6bc907e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
- table="foobar",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
)
@@ -88,7 +86,11 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def _insert(txn):
txn.execute(
- "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ "INSERT INTO foobar VALUES (?, ?)",
+ (
+ stream_id,
+ instance_name,
+ ),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
txn.execute(
@@ -140,8 +142,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_out_of_order_finish(self):
- """Test that IDs persisted out of order are correctly handled
- """
+ """Test that IDs persisted out of order are correctly handled"""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
@@ -248,8 +249,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
def test_get_next_txn(self):
- """Test that the `get_next_txn` function works correctly.
- """
+ """Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
@@ -388,8 +388,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
def test_writer_config_change(self):
- """Test that changing the writer config correctly works.
- """
+ """Test that changing the writer config correctly works."""
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
@@ -436,8 +435,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
def test_sequence_consistency(self):
- """Test that we error out if the table and sequence diverges.
- """
+ """Test that we error out if the table and sequence diverges."""
# Prefill with some rows
self._insert_row_with_id("master", 3)
@@ -454,8 +452,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
- """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
- """
+ """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
@@ -487,9 +484,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
- table="foobar",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
positive=False,
@@ -498,12 +493,15 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_row(self, instance_name: str, stream_id: int):
- """Insert one row as the given instance with given stream_id.
- """
+ """Insert one row as the given instance with given stream_id."""
def _insert(txn):
txn.execute(
- "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ "INSERT INTO foobar VALUES (?, ?)",
+ (
+ stream_id,
+ instance_name,
+ ),
)
txn.execute(
"""
@@ -579,3 +577,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
+
+
+class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
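+    """Tests MultiWriterIdGenerator with a stream that is backed by multiple tables."""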
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db_pool = self.store.db_pool # type: DatabasePool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
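+        # Both tables share a single sequence, so stream IDs are unique across them.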
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar1 (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ txn.execute(
+ """
+ CREATE TABLE foobar2 (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ stream_name="test_stream",
+ instance_name=instance_name,
+ tables=[
+ ("foobar1", "instance_name", "stream_id"),
+ ("foobar2", "instance_name", "stream_id"),
+ ],
+ sequence_name="foobar_seq",
+ writers=writers,
+ )
+
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+ def _insert_rows(
+ self,
+ table: str,
+ instance_name: str,
+ number: int,
+ update_stream_table: bool = True,
+ ):
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
+
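+        # When update_stream_table is False the new position is deliberately not
+        # recorded in stream_positions, leaving rows "ahead" of the stored position.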
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
+ (instance_name,),
+ )
+ if update_stream_table:
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+ def test_load_existing_stream(self):
+ """Test creating ID gens with multiple tables that have rows from after
+        the position recorded in the `stream_positions` table.
+ """
+ self._insert_rows("foobar1", "first", 3)
+ self._insert_rows("foobar2", "second", 3)
+ self._insert_rows("foobar2", "second", 1, update_stream_table=False)
+
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+ # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in-progress writes...
+ self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+ self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+
+ # ... but the second ID gen doesn't know that.
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index fe37d2ed5a..30e46c650d 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -48,3 +48,10 @@ class DataStoreTestCase(unittest.TestCase):
self.assertEquals(1, total)
self.assertEquals(self.displayname, users.pop()["displayname"])
+
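+        # Filtering the pagination by name should still return the matching user.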
+ users, total = yield defer.ensureDeferred(
+ self.store.get_users_paginate(0, 10, name="BC", guests=False)
+ )
+
+ self.assertEquals(1, total)
+ self.assertEquals(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 8d97b6d4cd..5858c7fcc4 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -198,7 +198,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# value, although it gets stored on the config object as mau_limits.
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
def test_reap_monthly_active_users_reserved_users(self):
- """ Tests that reaping correctly handles reaping where reserved users are
+ """Tests that reaping correctly handles reaping where reserved users are
present"""
threepids = self.hs.config.mau_limits_reserved_threepids
initial_users = len(threepids)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 7a38022e71..b7dde51224 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase):
),
)
+ # test set to None
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.u_frank.localpart, None, 2)
+ )
+
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.u_frank.localpart)
+ )
+ )
+ )
+
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
@@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase):
)
),
)
+
+ # test set to None
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(self.u_frank.localpart, None, 2)
+ )
+
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.u_frank.localpart)
+ )
+ )
+ )
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index a06ad2c03e..41af8c4847 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -33,9 +31,12 @@ class PurgeTests(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
- def test_purge(self):
+ self.store = hs.get_datastore()
+ self.storage = self.hs.get_storage()
+
+ def test_purge_history(self):
"""
- Purging a room will delete everything before the topological point.
+ Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@@ -43,30 +44,27 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- store = self.hs.get_datastore()
- storage = self.hs.get_storage()
-
# Get the topological token
token = self.get_success(
- store.get_topological_token_for_event(last["event_id"])
+ self.store.get_topological_token_for_event(last["event_id"])
)
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
self.get_success(
- storage.purge_events.purge_history(self.room_id, token_str, True)
+ self.storage.purge_events.purge_history(self.room_id, token_str, True)
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.get_failure(store.get_event(first["event_id"]), NotFoundError)
- self.get_failure(store.get_event(second["event_id"]), NotFoundError)
- self.get_failure(store.get_event(third["event_id"]), NotFoundError)
- self.get_success(store.get_event(last["event_id"]))
+ self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_wont_delete_extrems(self):
+ def test_purge_history_wont_delete_extrems(self):
"""
- Purging a room will delete everything before the topological point.
+ Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@@ -74,22 +72,43 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- storage = self.hs.get_datastore()
-
# Set the topological token higher than it should be
token = self.get_success(
- storage.get_topological_token_for_event(last["event_id"])
+ self.store.get_topological_token_for_event(last["event_id"])
)
event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
- purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
- self.pump()
- f = self.failureResultOf(purge)
+ f = self.get_failure(
+ self.storage.purge_events.purge_history(self.room_id, event, True),
+ SynapseError,
+ )
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- self.get_success(storage.get_event(first["event_id"]))
- self.get_success(storage.get_event(second["event_id"]))
- self.get_success(storage.get_event(third["event_id"]))
- self.get_success(storage.get_event(last["event_id"]))
+ self.get_success(self.store.get_event(first["event_id"]))
+ self.get_success(self.store.get_event(second["event_id"]))
+ self.get_success(self.store.get_event(third["event_id"]))
+ self.get_success(self.store.get_event(last["event_id"]))
+
+ def test_purge_room(self):
+ """
+ Purging a room will delete everything about it.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+
+ # Get the current room state.
+ state_handler = self.hs.get_state_handler()
+ create_event = self.get_success(
+ state_handler.get_current_state(self.room_id, "m.room.create", "")
+ )
+ self.assertIsNotNone(create_event)
+
+ # Purge everything before this topological token
+ self.get_success(self.storage.purge_events.purge_room(self.room_id))
+
+ # The events aren't found.
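+        # Drop the create event from the in-memory cache so the lookup below hits the database.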
+ self.store._invalidate_get_event_cache(create_event.event_id)
+ self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
+ self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index fd0add5db3..b2a0e60856 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
-
from canonicaljson import json
from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["redaction_retention_period"] = "30d"
- return self.setup_test_homeserver(
- resource_for_federation=Mock(), federation_http_client=None, config=config
- )
+ return config
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
@@ -304,8 +299,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
def test_redact_censor(self):
- """Test that a redacted event gets censored in the DB after a month
- """
+ """Test that a redacted event gets censored in the DB after a month"""
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -375,8 +369,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.assert_dict({"content": {}}, json.loads(event_json))
def test_redact_redaction(self):
- """Tests that we can redact a redaction and can fetch it again.
- """
+ """Tests that we can redact a redaction and can fetch it again."""
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -409,8 +402,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
def test_store_redacted_redaction(self):
- """Tests that we can store a redacted redaction.
- """
+ """Tests that we can store a redacted redaction."""
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index c8c7a90e5d..4eb41c46e8 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -52,6 +52,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"creation_ts": 1000,
"user_type": None,
"deactivated": 0,
+ "shadow_banned": 0,
},
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@@ -145,7 +146,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
try:
yield defer.ensureDeferred(
self.store.validate_threepid_session(
- "fake_sid", "fake_client_secret", "fake_token", 0,
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
)
)
except ThreepidValidationError as e:
@@ -158,7 +162,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
try:
yield defer.ensureDeferred(
self.store.validate_threepid_session(
- "fake_sid", "fake_client_secret", "fake_token", 0,
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
)
)
except ThreepidValidationError as e:
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5ba1db2332..d2aed66f6d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
-
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(
- resource_for_federation=Mock(), federation_http_client=None
- )
- return hs
-
def prepare(self, reactor, clock, hs: TestHomeServer):
# We can't test the RoomMemberStore on its own without the other event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8bd12fa847..2471f1267d 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -377,14 +377,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
- (
- is_all,
- known_absent,
- state_dict_ids,
- ) = self.state_datastore._state_group_cache.get(group)
+ cache_entry = self.state_datastore._state_group_cache.get(group)
+ state_dict_ids = cache_entry.value
- self.assertEqual(is_all, True)
- self.assertEqual(known_absent, set())
+ self.assertEqual(cache_entry.full, True)
+ self.assertEqual(cache_entry.known_absent, set())
self.assertDictEqual(
state_dict_ids,
{
@@ -403,14 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
fetched_keys=((e1.type, e1.state_key),),
)
- (
- is_all,
- known_absent,
- state_dict_ids,
- ) = self.state_datastore._state_group_cache.get(group)
+ cache_entry = self.state_datastore._state_group_cache.get(group)
+ state_dict_ids = cache_entry.value
- self.assertEqual(is_all, False)
- self.assertEqual(known_absent, {(e1.type, e1.state_key)})
+ self.assertEqual(cache_entry.full, False)
+ self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 738e912468..a6f63f4aaf 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -21,6 +21,8 @@ from tests.utils import setup_test_homeserver
ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
+# The localpart is deliberately not 'Bela' so we can test looking up display names.
+BELA = "@somenickname:a"
class UserDirectoryStoreTestCase(unittest.TestCase):
@@ -41,6 +43,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
)
yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BELA, "Bela", None)
+ )
+ yield defer.ensureDeferred(
self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
)
@@ -72,3 +77,21 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
)
finally:
self.hs.config.user_directory_search_all_users = False
+
+ @defer.inlineCallbacks
+ def test_search_user_dir_stop_words(self):
+ """Tests that a user can look up another user by searching for the start if its
+ display name even if that name happens to be a common English word that would
+ usually be ignored in full text searches.
+ """
+ self.hs.config.user_directory_search_all_users = True
+ try:
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
+ finally:
+ self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 69b4c5d6c2..3f2691ee6b 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -85,7 +85,10 @@ class EventAuthTestCase(unittest.TestCase):
# king should be able to send state
event_auth.check(
- RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False,
+ RoomVersions.V1,
+ _random_state_event(king),
+ auth_events,
+ do_sig_check=False,
)
def test_alias_event(self):
@@ -99,7 +102,10 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send aliases
event_auth.check(
- RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False,
+ RoomVersions.V1,
+ _alias_event(creator),
+ auth_events,
+ do_sig_check=False,
)
# Reject an event with no state key.
@@ -122,7 +128,10 @@ class EventAuthTestCase(unittest.TestCase):
# Note that the member does *not* need to be in the room.
event_auth.check(
- RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False,
+ RoomVersions.V1,
+ _alias_event(other),
+ auth_events,
+ do_sig_check=False,
)
def test_msc2432_alias_event(self):
@@ -136,7 +145,10 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send aliases
event_auth.check(
- RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False,
+ RoomVersions.V6,
+ _alias_event(creator),
+ auth_events,
+ do_sig_check=False,
)
# No particular checks are done on the state key.
@@ -156,7 +168,10 @@ class EventAuthTestCase(unittest.TestCase):
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
event_auth.check(
- RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False,
+ RoomVersions.V6,
+ _alias_event(other),
+ auth_events,
+ do_sig_check=False,
)
def test_msc2209(self):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index fa45f8b3b7..fc9aab32d0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- with LoggingContext(request="lying_event"):
+ with LoggingContext():
failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_mau.py b/tests/test_mau.py
index c5ec6396a7..75d28a42df 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -19,6 +19,7 @@ import json
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
@@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ def test_as_ignores_mau(self):
+ """Test that application services can still create users when the MAU
+ limit has been reached. This only works when application service
+        user IP tracking is disabled.
+ """
+
+ # Create and sync so that the MAU counts get updated
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+ # check we're testing what we think we are: there should be two active users
+ self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
+
+ # We've created and activated two users, we shouldn't be able to
+ # register new users
+ with self.assertRaises(SynapseError) as cm:
+ self.create_user("kermit3")
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ # Cheekily add an application service that we use to register a new user
+ # with.
+ as_token = "foobartoken"
+ self.store.services_cache.append(
+ ApplicationService(
+ token=as_token,
+ hostname=self.hs.hostname,
+ id="SomeASID",
+ sender="@as_sender:test",
+ namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
+ )
+ )
+
+ self.create_user("as_kermit4", token=as_token)
+
def test_allowed_after_a_month_mau(self):
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
@@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count))
- def create_user(self, localpart):
+ def create_user(self, localpart, token=None):
request_data = json.dumps(
{
"username": localpart,
@@ -201,7 +241,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request("POST", "/register", request_data)
+ channel = self.make_request(
+ "POST",
+ "/register",
+ request_data,
+ access_token=token,
+ )
if channel.code != 200:
raise HttpResponseException(
@@ -213,7 +258,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
def do_sync_for_user(self, token):
- request, channel = self.make_request("GET", "/sync", access_token=token)
+ channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200:
raise HttpResponseException(
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 759e4cd048..f696fcf89e 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -21,7 +21,7 @@ from tests import unittest
def get_sample_labels_value(sample):
- """ Extract the labels and values of a sample.
+ """Extract the labels and values of a sample.
prometheus_client 0.5 changed the sample type to a named tuple with more
members than the plain tuple had in 0.4 and earlier. This function can
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 7f67ee9e1f..ea83299918 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -15,13 +15,22 @@
from synapse.rest.media.v1.preview_url_resource import (
decode_and_calc_og,
+ get_html_media_encoding,
summarize_paragraphs,
)
from . import unittest
+try:
+ import lxml
+except ImportError:
+ lxml = None
+
+
+class SummarizeTestCase(unittest.TestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
-class PreviewTestCase(unittest.TestCase):
def test_long_summarize(self):
example_paras = [
"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
@@ -56,7 +65,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -69,7 +78,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø lies in Northern Norway. The municipality has a population of"
" (2015) 72,066, but with an annual influx of students it has over 75,000"
@@ -96,7 +105,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -122,7 +131,7 @@ class PreviewTestCase(unittest.TestCase):
]
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -136,9 +145,12 @@ class PreviewTestCase(unittest.TestCase):
)
-class PreviewUrlTestCase(unittest.TestCase):
+class CalcOgTestCase(unittest.TestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
+
def test_simple(self):
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -149,10 +161,10 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment(self):
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -164,10 +176,10 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment2(self):
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -182,7 +194,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(
+ self.assertEqual(
og,
{
"og:title": "Foo",
@@ -191,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
)
def test_script(self):
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -203,10 +215,10 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_missing_title(self):
- html = """
+ html = b"""
<html>
<body>
Some text.
@@ -216,10 +228,10 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_h1_as_title(self):
- html = """
+ html = b"""
<html>
<meta property="og:description" content="Some text."/>
<body>
@@ -230,10 +242,10 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_missing_title_and_broken_h1(self):
- html = """
+ html = b"""
<html>
<body>
<h1><a href="foo"/></h1>
@@ -244,4 +256,118 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+
+ def test_empty(self):
+ """Test a body with no data in it."""
+ html = b""
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {})
+
+ def test_no_tree(self):
+ """A valid body with no tree in it."""
+ html = b"\x00"
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {})
+
+ def test_invalid_encoding(self):
+ """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
+ html = b"""
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ og = decode_and_calc_og(
+ html, "http://example.com/test.html", "invalid-encoding"
+ )
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
+
+ def test_invalid_encoding2(self):
+ """A body which doesn't match the sent character encoding."""
+ # Note that this contains an invalid UTF-8 sequence in the title.
+ html = b"""
+ <html>
+ <head><title>\xff\xff Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
+
+
+class MediaEncodingTestCase(unittest.TestCase):
+ def test_meta_charset(self):
+ """A character encoding is found via the meta tag."""
+ encoding = get_html_media_encoding(
+ b"""
+ <html>
+ <head><meta charset="ascii">
+ </head>
+ </html>
+ """,
+ "text/html",
+ )
+ self.assertEqual(encoding, "ascii")
+
+ # A less well-formed version.
+ encoding = get_html_media_encoding(
+ b"""
+ <html>
+ <head>< meta charset = ascii>
+ </head>
+ </html>
+ """,
+ "text/html",
+ )
+ self.assertEqual(encoding, "ascii")
+
+ def test_xml_encoding(self):
+ """A character encoding is found via the meta tag."""
+ encoding = get_html_media_encoding(
+ b"""
+ <?xml version="1.0" encoding="ascii"?>
+ <html>
+ </html>
+ """,
+ "text/html",
+ )
+ self.assertEqual(encoding, "ascii")
+
+ def test_meta_xml_encoding(self):
+ """Meta tags take precedence over XML encoding."""
+ encoding = get_html_media_encoding(
+ b"""
+ <?xml version="1.0" encoding="ascii"?>
+ <html>
+ <head><meta charset="UTF-16">
+ </head>
+ </html>
+ """,
+ "text/html",
+ )
+ self.assertEqual(encoding, "UTF-16")
+
+ def test_content_type(self):
+ """A character encoding is found via the Content-Type header."""
+ # Test a few variations of the header.
+ headers = (
+ 'text/html; charset="ascii";',
+ "text/html;charset=ascii;",
+ 'text/html; charset="ascii"',
+ "text/html; charset=ascii",
+ 'text/html; charset="ascii;',
+ 'text/html; charset=ascii";',
+ )
+ for header in headers:
+ encoding = get_html_media_encoding(b"", header)
+ self.assertEqual(encoding, "ascii")
+
+ def test_fallback(self):
+ """A character encoding cannot be found in the body or header."""
+ encoding = get_html_media_encoding(b"", "text/html")
+ self.assertEqual(encoding, "utf-8")
diff --git a/tests/test_server.py b/tests/test_server.py
index 6b2d2f0401..55cde7f62f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -64,11 +64,10 @@ class JsonResourceTests(unittest.TestCase):
"test_servlet",
)
- request, channel = make_request(
+ make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
)
- self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]})
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
def test_callback_direct_exception(self):
@@ -85,7 +84,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"500")
@@ -109,7 +108,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"500")
@@ -127,7 +126,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -149,9 +148,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- _, channel = make_request(
- self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
- )
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
self.assertEqual(channel.result["code"], b"400")
self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -169,11 +166,14 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver)
res.register_paths(
- "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+ "GET",
+ [re.compile("^/_matrix/foo$")],
+ _callback,
+ "test_servlet",
)
# The path was registered as GET, but this is a HEAD request.
- _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
+ channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
@@ -205,7 +205,7 @@ class OptionsResourceTests(unittest.TestCase):
)
# render the request and return the channel
- _, channel = make_request(self.reactor, site, method, path, shorthand=False)
+ channel = make_request(self.reactor, site, method, path, shorthand=False)
return channel
def test_unknown_options_request(self):
@@ -278,7 +278,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"]
@@ -296,7 +296,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"]
@@ -317,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"]
@@ -336,7 +336,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
+ channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 71580b454d..a743cdc3a9 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
def test_ui_auth(self):
# Do a UI auth request
request_data = json.dumps({"username": "kermit", "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler.check_username = Mock(return_value=True)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
# We don't bother checking that the response is correct - we'll leave that to
# other tests. We just want to make sure we're on the right path.
@@ -113,7 +113,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
},
}
)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
# We're interested in getting a response that looks like a successful
# registration, not so much that the details are exactly what we want.
diff --git a/tests/test_types.py b/tests/test_types.py
index d4a722a30f..67ceea6e43 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -65,6 +65,10 @@ class RoomAliasTestCase(unittest.HomeserverTestCase):
self.assertEquals(room.to_string(), "#channel:my.domain")
+ def test_validate(self):
+ id_string = "#test:domain,test"
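+        # A comma is not valid in a server name, so the alias should be rejected.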
+ self.assertFalse(RoomAlias.is_valid(id_string))
+
class GroupIDTestCase(unittest.TestCase):
def test_parse(self):
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,13 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
+from mock import Mock
+
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
TV = TypeVar("TV")
@@ -80,3 +87,35 @@ def setup_awaitable_errors() -> Callable[[], None]:
sys.unraisablehook = unraisablehook # type: ignore
return cleanup
+
+
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+    # AsyncMock is not available in Python 3.5; this mimics part of its behaviour.
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+@attr.s
+class FakeResponse:
+ """A fake twisted.web.IResponse object
+
+ there is a similar class at treq.test.test_response, but it lacks a `phrase`
+ attribute, and didn't support deliverBody until recently.
+ """
+
+ # HTTP response code
+ code = attr.ib(type=int)
+
+ # HTTP response phrase (eg b'OK' for a 200)
+ phrase = attr.ib(type=bytes)
+
+ # body of the response
+ body = attr.ib(type=bytes)
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
new file mode 100644
index 0000000000..ad563eb3f0
--- /dev/null
+++ b/tests/test_utils/html_parsers.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from html.parser import HTMLParser
+from typing import Dict, Iterable, List, Optional, Tuple
+
+
+class TestHtmlParser(HTMLParser):
+ """A generic HTML page parser which extracts useful things from the HTML"""
+
+ def __init__(self):
+ super().__init__()
+
+ # a list of links found in the doc
+ self.links = [] # type: List[str]
+
+ # the values of any hidden <input>s: map from name to value
+ self.hiddens = {} # type: Dict[str, Optional[str]]
+
+ # the values of any radio buttons: map from name to list of values
+ self.radios = {} # type: Dict[str, List[Optional[str]]]
+
+ def handle_starttag(
+ self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+ ) -> None:
+ attr_dict = dict(attrs)
+ if tag == "a":
+ href = attr_dict["href"]
+ if href:
+ self.links.append(href)
+ elif tag == "input":
+ input_name = attr_dict.get("name")
+ if attr_dict["type"] == "radio":
+ assert input_name
+ self.radios.setdefault(input_name, []).append(attr_dict["value"])
+ elif attr_dict["type"] == "hidden":
+ assert input_name
+ self.hiddens[input_name] = attr_dict["value"]
+
+ def error(_, message):
+ raise AssertionError(message)
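
A short usage sketch for the parser above; the HTML snippet is invented for this example:

    parser = TestHtmlParser()
    parser.feed('<a href="/next"></a><input type="hidden" name="token" value="abc">')
    parser.close()
    assert parser.links == ["/next"]
    assert parser.hiddens == {"token": "abc"}
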
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..74568b34f8 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
- self.tx_log.emit(
+ self.tx_log.emit( # type: ignore
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
@@ -48,7 +48,7 @@ def setup_logging():
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
- handler.addFilter(LoggingContextFilter(request=""))
+ handler.addFilter(LoggingContextFilter())
root_logger.addHandler(handler)
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..58a4daa1ec 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
from mock import Mock, patch
@@ -32,6 +32,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from twisted.web.resource import Resource
+from synapse import events
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
@@ -46,6 +47,7 @@ from synapse.logging.context import (
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -139,7 +141,7 @@ class TestCase(unittest.TestCase):
try:
self.assertEquals(attrs[key], getattr(obj, key))
except AssertionError as e:
- raise (type(e))(e.message + " for '.%s'" % key)
+ raise (type(e))("Assert error for '.{}':".format(key)) from e
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
@@ -228,6 +230,11 @@ class HomeserverTestCase(TestCase):
self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock)
+ # Honour the `use_frozen_dicts` config option. We have to do this
+ # manually because this is taken care of in the app `start` code, which
+ # we don't run. Plus we want to reset it on tearDown.
+ events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts
+
if self.hs is None:
raise Exception("No homeserver returned from make_homeserver.")
@@ -254,7 +261,10 @@ class HomeserverTestCase(TestCase):
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
self.hs.get_datastore().add_access_token_to_user(
- self.helper.auth_user_id, "some_fake_token", None, None,
+ self.helper.auth_user_id,
+ "some_fake_token",
+ None,
+ None,
)
)
@@ -288,6 +298,10 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)
+ def tearDown(self):
+ # Reset to not use frozen dicts.
+ events.USE_FROZEN_DICTS = False
+
def wait_on_thread(self, deferred, timeout=10):
"""
Wait until a Deferred is done, where it's waiting on a real thread.
@@ -320,15 +334,28 @@ class HomeserverTestCase(TestCase):
"""
Create the root resource for the test server.
- The default implementation creates a JsonResource and calls each function in
- `servlets` to register servletes against it
+ The default calls `self.create_resource_dict` and builds the resultant dict
+ into a tree.
"""
- resource = JsonResource(self.hs)
+ root_resource = Resource()
+ create_resource_tree(self.create_resource_dict(), root_resource)
+ return root_resource
- for servlet in self.servlets:
- servlet(self.hs, resource)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ """Create a resource tree for the test server
- return resource
+ A resource tree is a mapping from path to twisted.web.resource.
+
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servlets against it.
+ """
+ servlet_resource = JsonResource(self.hs)
+ for servlet in self.servlets:
+ servlet(self.hs, servlet_resource)
+ return {
+ "/_matrix/client": servlet_resource,
+ "/_synapse/admin": servlet_resource,
+ }
def default_config(self):
"""
@@ -358,24 +385,6 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses.
"""
- # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
- # when the `request` arg isn't given, so we define an explicit override to
- # cover that case.
- @overload
- def make_request(
- self,
- method: Union[bytes, str],
- path: Union[bytes, str],
- content: Union[bytes, dict] = b"",
- access_token: Optional[str] = None,
- shorthand: bool = True,
- federation_auth_origin: str = None,
- content_is_form: bool = False,
- await_result: bool = True,
- ) -> Tuple[SynapseRequest, FakeChannel]:
- ...
-
- @overload
def make_request(
self,
method: Union[bytes, str],
@@ -387,21 +396,11 @@ class HomeserverTestCase(TestCase):
federation_auth_origin: str = None,
content_is_form: bool = False,
await_result: bool = True,
- ) -> Tuple[T, FakeChannel]:
- ...
-
- def make_request(
- self,
- method: Union[bytes, str],
- path: Union[bytes, str],
- content: Union[bytes, dict] = b"",
- access_token: Optional[str] = None,
- request: Type[T] = SynapseRequest,
- shorthand: bool = True,
- federation_auth_origin: str = None,
- content_is_form: bool = False,
- await_result: bool = True,
- ) -> Tuple[T, FakeChannel]:
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ client_ip: str = "127.0.0.1",
+ ) -> FakeChannel:
"""
Create a SynapseRequest at the path using the method and containing the
given content.
@@ -423,8 +422,13 @@ class HomeserverTestCase(TestCase):
true (the default), will pump the test reactor until the renderer
tells the channel the request is finished.
+ custom_headers: (name, value) pairs to add as request headers
+
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
- Tuple[synapse.http.site.SynapseRequest, channel]
+ The FakeChannel object which stores the result of the request.
"""
return make_request(
self.reactor,
@@ -438,6 +442,8 @@ class HomeserverTestCase(TestCase):
federation_auth_origin,
content_is_form,
await_result,
+ custom_headers,
+ client_ip,
)
def setup_test_homeserver(self, *args, **kwargs):
@@ -554,7 +560,7 @@ class HomeserverTestCase(TestCase):
self.hs.config.registration_shared_secret = "shared"
# Create the user
- request, channel = self.make_request("GET", "/_synapse/admin/v1/register")
+ channel = self.make_request("GET", "/_synapse/admin/v1/register")
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
@@ -579,7 +585,7 @@ class HomeserverTestCase(TestCase):
"inhibit_login": True,
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/_synapse/admin/v1/register", body.encode("utf8")
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -597,7 +603,7 @@ class HomeserverTestCase(TestCase):
if device_id:
body["device_id"] = device_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
self.assertEqual(channel.code, 200, channel.result)
@@ -665,7 +671,7 @@ class HomeserverTestCase(TestCase):
"""
body = {"type": "m.login.password", "user": username, "password": password}
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
self.assertEqual(channel.code, 403, channel.result)
@@ -691,13 +697,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
- def prepare(self, reactor, clock, homeserver):
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+ return d
+
+
+class TestTransportLayerServer(JsonResource):
+ """A test implementation of TransportLayerServer
+
+    Authenticates incoming requests as `other.example.com`.
+ """
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
+ authenticator = Authenticator()
+
ratelimiter = FederationRateLimiter(
- clock,
+ hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@@ -706,11 +728,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
- federation_server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
- return super().prepare(reactor, clock, homeserver)
+ federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config):
@@ -735,3 +754,29 @@ def override_config(extra_config):
return func
return decorator
+
+
+TV = TypeVar("TV")
+
+
+def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
+ """A test decorator which will skip the decorated test unless a condition is set
+
+ For example:
+
+ class MyTestCase(TestCase):
+ @skip_unless(HAS_FOO, "Cannot test without foo")
+ def test_foo(self):
+ ...
+
+ Args:
+        condition: If false, the test will be skipped
+ reason: the reason to give for skipping the test
+ """
+
+ def decorator(f: TV) -> TV:
+ if not condition:
+ f.skip = reason # type: ignore
+ return f
+
+ return decorator
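
As a hedged illustration of the new `create_resource_dict` hook (the extra mount point below is hypothetical), a test case can expose additional resources alongside the default servlet tree by overriding it, in the same way `FederatingHomeserverTestCase` does above:

    class MyResourceTestCase(HomeserverTestCase):
        def create_resource_dict(self) -> Dict[str, Resource]:
            d = super().create_resource_dict()
            d["/_hypothetical/extra"] = Resource()  # made-up path for this sketch
            return d
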
diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
new file mode 100644
index 0000000000..f349b5ced0
--- /dev/null
+++ b/tests/util/caches/test_cached_call.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+
+from synapse.util.caches.cached_call import CachedCall, RetryOnExceptionCachedCall
+
+from tests.test_utils import get_awaitable_result
+from tests.unittest import TestCase
+
+
+class CachedCallTestCase(TestCase):
+ def test_get(self):
+ """
+ Happy-path test case: makes a couple of calls and makes sure they behave
+ correctly
+ """
+ d = Deferred()
+
+ async def f():
+ return await d
+
+ slow_call = Mock(side_effect=f)
+
+ cached_call = CachedCall(slow_call)
+
+ # the mock should not yet have been called
+ slow_call.assert_not_called()
+
+ # now fire off a couple of calls
+ completed_results = []
+
+ async def r():
+ res = await cached_call.get()
+ completed_results.append(res)
+
+ r1 = defer.ensureDeferred(r())
+ r2 = defer.ensureDeferred(r())
+
+ # neither result should be complete yet
+ self.assertNoResult(r1)
+ self.assertNoResult(r2)
+
+ # and the mock should have been called *once*, with no params
+ slow_call.assert_called_once_with()
+
+ # allow the deferred to complete, which should complete both the pending results
+ d.callback(123)
+ self.assertEqual(completed_results, [123, 123])
+ self.successResultOf(r1)
+ self.successResultOf(r2)
+
+ # another call to the getter should complete immediately
+ slow_call.reset_mock()
+ r3 = get_awaitable_result(cached_call.get())
+ self.assertEqual(r3, 123)
+ slow_call.assert_not_called()
+
+ def test_fast_call(self):
+ """
+ Test the behaviour when the underlying function completes immediately
+ """
+
+ async def f():
+ return 12
+
+ fast_call = Mock(side_effect=f)
+ cached_call = CachedCall(fast_call)
+
+ # the mock should not yet have been called
+ fast_call.assert_not_called()
+
+ # run the call a couple of times, which should complete immediately
+ self.assertEqual(get_awaitable_result(cached_call.get()), 12)
+ self.assertEqual(get_awaitable_result(cached_call.get()), 12)
+
+ # the mock should have been called once
+ fast_call.assert_called_once_with()
+
+
+class RetryOnExceptionCachedCallTestCase(TestCase):
+ def test_get(self):
+ # set up the RetryOnExceptionCachedCall around a function which will fail
+ # (after a while)
+ d = Deferred()
+
+ async def f1():
+ await d
+ raise ValueError("moo")
+
+ slow_call = Mock(side_effect=f1)
+ cached_call = RetryOnExceptionCachedCall(slow_call)
+
+ # the mock should not yet have been called
+ slow_call.assert_not_called()
+
+ # now fire off a couple of calls
+ completed_results = []
+
+ async def r():
+ try:
+ await cached_call.get()
+ except Exception as e1:
+ completed_results.append(e1)
+
+ r1 = defer.ensureDeferred(r())
+ r2 = defer.ensureDeferred(r())
+
+ # neither result should be complete yet
+ self.assertNoResult(r1)
+ self.assertNoResult(r2)
+
+ # and the mock should have been called *once*, with no params
+ slow_call.assert_called_once_with()
+
+ # complete the deferred, which should make the pending calls fail
+ d.callback(0)
+ self.assertEqual(len(completed_results), 2)
+ for e in completed_results:
+ self.assertIsInstance(e, ValueError)
+ self.assertEqual(e.args, ("moo",))
+
+ # reset the mock to return a successful result, and make another pair of calls
+ # to the getter
+ d = Deferred()
+
+ async def f2():
+ return await d
+
+ slow_call.reset_mock()
+ slow_call.side_effect = f2
+ r3 = defer.ensureDeferred(cached_call.get())
+ r4 = defer.ensureDeferred(cached_call.get())
+
+ self.assertNoResult(r3)
+ self.assertNoResult(r4)
+ slow_call.assert_called_once_with()
+
+ # let that call complete, and check the results
+ d.callback(123)
+ self.assertEqual(self.successResultOf(r3), 123)
+ self.assertEqual(self.successResultOf(r4), 123)
+
+ # and now more calls to the getter should complete immediately
+ slow_call.reset_mock()
+ self.assertEqual(get_awaitable_result(cached_call.get()), 123)
+ slow_call.assert_not_called()
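
The behaviour exercised above can be summarised with a small, hypothetical usage sketch: `CachedCall` memoises a single zero-argument coroutine so that all callers share one in-flight call and its result, while `RetryOnExceptionCachedCall` additionally forgets a failed result so that the next caller retries:

    async def _fetch_config():
        ...  # expensive; on the success path this runs at most once

    get_config = RetryOnExceptionCachedCall(_fetch_config)

    async def handler():
        config = await get_config.get()  # concurrent callers share the same call
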
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index dadfabd46d..c24c33ee91 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -25,13 +25,8 @@ from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
def test_empty(self):
cache = DeferredCache("test")
- failed = False
- try:
+ with self.assertRaises(KeyError):
cache.get("foo")
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
def test_hit(self):
cache = DeferredCache("test")
@@ -155,13 +150,8 @@ class DeferredCacheTestCase(TestCase):
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
- failed = False
- try:
+ with self.assertRaises(KeyError):
cache.get(("foo",))
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
def test_invalidate_all(self):
cache = DeferredCache("testcache")
@@ -215,13 +205,8 @@ class DeferredCacheTestCase(TestCase):
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
- failed = False
- try:
+ with self.assertRaises(KeyError):
cache.get(1)
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
cache.get(2)
cache.get(3)
@@ -239,13 +224,58 @@ class DeferredCacheTestCase(TestCase):
cache.prefill(3, "three")
- failed = False
- try:
+ with self.assertRaises(KeyError):
cache.get(2)
- except KeyError:
- failed = True
- self.assertTrue(failed)
+ cache.get(1)
+ cache.get(3)
+
+ def test_eviction_iterable(self):
+ cache = DeferredCache(
+ "test",
+ max_entries=3,
+ apply_cache_factor_from_config=False,
+ iterable=True,
+ )
+
+ cache.prefill(1, ["one", "two"])
+ cache.prefill(2, ["three"])
+ # Now access 1 again, thus causing 2 to be least-recently used
+ cache.get(1)
+
+ # Now add an item to the cache, which evicts 2.
+ cache.prefill(3, ["four"])
+ with self.assertRaises(KeyError):
+ cache.get(2)
+
+ # Ensure 1 & 3 are in the cache.
cache.get(1)
cache.get(3)
+
+ # Now access 1 again, thus causing 3 to be least-recently used
+ cache.get(1)
+
+ # Now add an item with multiple elements to the cache
+ cache.prefill(4, ["five", "six"])
+
+ # Both 1 and 3 are evicted since there's too many elements.
+ with self.assertRaises(KeyError):
+ cache.get(1)
+ with self.assertRaises(KeyError):
+ cache.get(3)
+
+ # Now add another item to fill the cache again.
+ cache.prefill(5, ["seven"])
+
+ # Now access 4, thus causing 5 to be least-recently used
+ cache.get(4)
+
+ # Add an empty item.
+ cache.prefill(6, [])
+
+ # 5 gets evicted and replaced since an empty element counts as an item.
+ with self.assertRaises(KeyError):
+ cache.get(5)
+ cache.get(4)
+ cache.get(6)
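
A brief note on the iterable-cache accounting the new test relies on: with `iterable=True` the cache size is the total number of elements across all cached values (an empty value still counts towards occupancy), so a multi-element prefill can evict several keys at once. A minimal sketch under the same assumptions as the test above:

    cache = DeferredCache(
        "sketch", max_entries=3, apply_cache_factor_from_config=False, iterable=True
    )
    cache.prefill("a", ["x", "y"])  # 2 elements
    cache.prefill("b", ["z"])       # 3 elements: at capacity
    cache.prefill("c", ["w"])       # over capacity, so the least-recently-used key ("a") is evicted
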
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index cf1e3203a4..afb11b9caf 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -143,8 +143,7 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
def test_cache_with_sync_exception(self):
- """If the wrapped function throws synchronously, things should continue to work
- """
+ """If the wrapped function throws synchronously, things should continue to work"""
class Cls:
@cached()
@@ -165,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase):
self.failureResultOf(d, SynapseError)
def test_cache_with_async_exception(self):
- """The wrapped function returns a failure
- """
+ """The wrapped function returns a failure"""
class Cls:
result = None
@@ -282,7 +280,8 @@ class DescriptorTestCase(unittest.TestCase):
try:
d = obj.fn(1)
self.assertEqual(
- current_context(), SENTINEL_CONTEXT,
+ current_context(),
+ SENTINEL_CONTEXT,
)
yield d
self.fail("No exception thrown")
@@ -374,8 +373,7 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
- """If the wrapped function throws synchronously, things should continue to work
- """
+ """If the wrapped function throws synchronously, things should continue to work"""
class Cls:
@descriptors.cached(iterable=True)
diff --git a/tests/util/caches/test_responsecache.py b/tests/util/caches/test_responsecache.py
new file mode 100644
index 0000000000..f9a187b8de
--- /dev/null
+++ b/tests/util/caches/test_responsecache.py
@@ -0,0 +1,131 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches.response_cache import ResponseCache
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+
+
+class ResponseCacheTestCase(TestCase):
+ """
+ A TestCase class for ResponseCache.
+
+    The test-case function naming has some logic to its parts; here are some notes about it:
+    wait: Denotes tests that have an element of "waiting" before the wrapped result becomes available
+        (generally these just use .delayed_return instead of .instant_return in the wrapped call).
+    expire: Denotes tests that check expiry after the entry is known to exist.
+        (These use a cache with a short timeout_ms, shorter than the time the clock is advanced by.)
+ """
+
+ def setUp(self):
+ self.reactor, self.clock = get_clock()
+
+ def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
+ return ResponseCache(self.clock, name, timeout_ms=ms)
+
+ @staticmethod
+ async def instant_return(o: str) -> str:
+ return o
+
+ async def delayed_return(self, o: str) -> str:
+ await self.clock.sleep(1)
+ return o
+
+ def test_cache_hit(self):
+ cache = self.with_cache("keeping_cache", ms=9001)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(wrap_d),
+ "initial wrap result should be the same",
+ )
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should have the result",
+ )
+
+ def test_cache_miss(self):
+ cache = self.with_cache("trashing_cache", ms=0)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(wrap_d),
+ "initial wrap result should be the same",
+ )
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+ def test_cache_expire(self):
+ cache = self.with_cache("short_cache", ms=1000)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should still have the result",
+ )
+
+ # cache eviction timer is handled
+ self.reactor.pump((2,))
+
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+ def test_cache_wait_hit(self):
+ cache = self.with_cache("neutral_cache")
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+ self.assertNoResult(wrap_d)
+
+ # function wakes up, returns result
+ self.reactor.pump((2,))
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+
+ def test_cache_wait_expire(self):
+ cache = self.with_cache("medium_cache", ms=3000)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+ self.assertNoResult(wrap_d)
+
+        # stop at 1 second to fire the cache-eviction callLater due at that time, then advance another second to set the time to 2
+ self.reactor.pump((1, 1))
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should still have the result",
+ )
+
+ # (1 + 1 + 2) > 3.0, cache eviction timer is handled
+ self.reactor.pump((2,))
+
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
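
For context, a hedged sketch of the wrap/get pattern these tests exercise; the handler, key and timeout below are hypothetical:

    # inside some handler that owns
    #     self._response_cache = ResponseCache(clock, "a_cache", timeout_ms=30000)
    async def on_request(self, room_id: str):
        # concurrent callers with the same key share a single call; the result is
        # then kept for timeout_ms after it completes
        return await self._response_cache.wrap(room_id, self._do_expensive_work, room_id)
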
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index 34fdc9a43a..2f41333f4c 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -27,7 +27,9 @@ class DictCacheTestCase(unittest.TestCase):
key = "test_simple_cache_hit_full"
v = self.cache.get(key)
- self.assertEqual((False, set(), {}), v)
+ self.assertIs(v.full, False)
+ self.assertEqual(v.known_absent, set())
+ self.assertEqual({}, v.value)
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..e931a7ec18 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
from tests.unittest import TestCase
@@ -22,26 +24,87 @@ class ChunkSeqTests(TestCase):
parts = chunk_seq("123", 8)
self.assertEqual(
- list(parts), ["123"],
+ list(parts),
+ ["123"],
)
def test_long_seq(self):
parts = chunk_seq("abcdefghijklmnop", 8)
self.assertEqual(
- list(parts), ["abcdefgh", "ijklmnop"],
+ list(parts),
+ ["abcdefgh", "ijklmnop"],
)
def test_uneven_parts(self):
parts = chunk_seq("abcdefghijklmnop", 5)
self.assertEqual(
- list(parts), ["abcde", "fghij", "klmno", "p"],
+ list(parts),
+ ["abcde", "fghij", "klmno", "p"],
)
def test_empty_input(self):
parts = chunk_seq([], 5)
self.assertEqual(
- list(parts), [],
+ list(parts),
+ [],
)
+
+
+class SortTopologically(TestCase):
+ def test_empty(self):
+ "Test that an empty graph works correctly"
+
+ graph = {} # type: Dict[int, List[int]]
+ self.assertEqual(list(sorted_topologically([], graph)), [])
+
+ def test_handle_empty_graph(self):
+ "Test that a graph where a node doesn't have an entry is treated as empty"
+
+ graph = {} # type: Dict[int, List[int]]
+
+ # For disconnected nodes the output is simply sorted.
+ self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+ def test_disconnected(self):
+ "Test that a graph with no edges work"
+
+ graph = {1: [], 2: []} # type: Dict[int, List[int]]
+
+ # For disconnected nodes the output is simply sorted.
+ self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+ def test_linear(self):
+ "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_subset(self):
+ "Test that only sorting a subset of the graph works"
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+ def test_fork(self):
+ "Test that a forked graph works"
+ graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]]
+
+ # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+ # always get the same one.
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_duplicates(self):
+ "Test that a graph with duplicate edges work"
+ graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_multiple_paths(self):
+ "Test that a graph with multiple paths between two nodes work"
+ graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
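
The contract these tests pin down (dependencies before dependents, ties broken in sorted order, tolerance of duplicate edges and of nodes missing from the graph) can be illustrated with a stand-alone sketch; this is not Synapse's implementation, just a reference model of the expected behaviour:

    import heapq
    from typing import Dict, Iterable, List, Set, TypeVar

    T = TypeVar("T")

    def toposort_sketch(nodes: Iterable[T], graph: Dict[T, List[T]]) -> List[T]:
        node_set = set(nodes)
        degree = {n: 0 for n in node_set}  # unresolved dependencies per node
        dependents = {n: set() for n in node_set}  # type: Dict[T, Set[T]]
        for n in node_set:
            for dep in set(graph.get(n, [])):  # duplicate edges count once
                if dep in node_set:  # edges leaving the requested subset are ignored
                    degree[n] += 1
                    dependents[dep].add(n)
        heap = [n for n, d in degree.items() if d == 0]
        heapq.heapify(heap)  # smallest ready node comes out first
        output = []
        while heap:
            n = heapq.heappop(heap)
            output.append(n)
            for m in dependents[n]:
                degree[m] -= 1
                if degree[m] == 0:
                    heapq.heappush(heap, m)
        return output
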
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 13b753e367..9ed01f7e0c 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -70,7 +70,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertTrue("user@foo.com" not in cache._entity_to_key)
self.assertEqual(
- cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"],
+ cache.get_all_entities_changed(2),
+ ["bar@baz.net", "user@elsewhere.org"],
)
self.assertIsNone(cache.get_all_entities_changed(1))
@@ -80,7 +81,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
self.assertEqual(
- cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"],
+ cache.get_all_entities_changed(2),
+ ["user@elsewhere.org", "bar@baz.net"],
)
self.assertIsNone(cache.get_all_entities_changed(1))
@@ -222,7 +224,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# Query a subset of the entries mid-way through the stream. We should
# only get back the subset.
self.assertEqual(
- cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"},
+ cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
+ {"bar@baz.net"},
)
def test_max_pos(self):
diff --git a/tests/utils.py b/tests/utils.py
index 1584eacb12..5d299f766f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,13 +20,12 @@ import os
import time
import uuid
import warnings
-from inspect import getcallargs
from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
@@ -34,15 +33,12 @@ from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.federation.transport import server as federation_server
-from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
-from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
#
@@ -118,7 +114,6 @@ def default_config(name, parse=False):
"server_name": name,
"send_federation": False,
"media_store_path": "media",
- "uploads_path": "uploads",
# the test signing key is just an arbitrary ed25519 key to keep the config
# parser happy
"signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
@@ -161,6 +156,11 @@ def default_config(name, parse=False):
"local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_invites": {
+ "per_room": {"per_second": 10000, "burst_count": 10000},
+ "per_user": {"per_second": 10000, "burst_count": 10000},
+ },
+ "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
"saml2_enabled": False,
"public_baseurl": None,
"default_identity_server": None,
@@ -268,7 +268,10 @@ def setup_test_homeserver(
db_conn.close()
hs = homeserver_to_use(
- name, config=config, version_string="Synapse/tests", reactor=reactor,
+ name,
+ config=config,
+ version_string="Synapse/tests",
+ reactor=reactor,
)
# Install @cache_in_self attributes
@@ -344,32 +347,9 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
- fed = kwargs.get("resource_for_federation", None)
- if fed:
- register_federation_servlets(hs, fed)
-
return hs
-def register_federation_servlets(hs, resource):
- federation_server.register_servlets(
- hs,
- resource=resource,
- authenticator=federation_server.Authenticator(hs),
- ratelimiter=FederationRateLimiter(
- hs.get_clock(), config=hs.config.rc_federation
- ),
- )
-
-
-def get_mock_call_args(pattern_func, mock_func):
- """ Return the arguments the mock function was called with interpreted
- by the pattern functions argument list.
- """
- invoked_args, invoked_kargs = mock_func.call_args
- return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
-
-
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
@@ -380,7 +360,7 @@ def mock_getRawHeaders(headers=None):
# This is a mock /resource/ not an entire server
-class MockHttpResource(HttpServer):
+class MockHttpResource:
def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix
@@ -393,7 +373,7 @@ class MockHttpResource(HttpServer):
def trigger(
self, http_method, path, content, mock_request, federation_auth_origin=None
):
- """ Fire an HTTP event.
+ """Fire an HTTP event.
Args:
http_method : The HTTP method
@@ -555,89 +535,8 @@ class MockClock:
return d
-def _format_call(args, kwargs):
- return ", ".join(
- ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
- )
-
-
-class DeferredMockCallable:
- """A callable instance that stores a set of pending call expectations and
- return values for them. It allows a unit test to assert that the given set
- of function calls are eventually made, by awaiting on them to be called.
- """
-
- def __init__(self):
- self.expectations = []
- self.calls = []
-
- def __call__(self, *args, **kwargs):
- self.calls.append((args, kwargs))
-
- if not self.expectations:
- raise ValueError(
- "%r has no pending calls to handle call(%s)"
- % (self, _format_call(args, kwargs))
- )
-
- for (call, result, d) in self.expectations:
- if args == call[1] and kwargs == call[2]:
- d.callback(None)
- return result
-
- failure = AssertionError(
- "Was not expecting call(%s)" % (_format_call(args, kwargs))
- )
-
- for _, _, d in self.expectations:
- try:
- d.errback(failure)
- except Exception:
- pass
-
- raise failure
-
- def expect_call_and_return(self, call, result):
- self.expectations.append((call, result, defer.Deferred()))
-
- @defer.inlineCallbacks
- def await_calls(self, timeout=1000):
- deferred = defer.DeferredList(
- [d for _, _, d in self.expectations], fireOnOneErrback=True
- )
-
- timer = reactor.callLater(
- timeout / 1000,
- deferred.errback,
- AssertionError(
- "%d pending calls left: %s"
- % (
- len([e for e in self.expectations if not e[2].called]),
- [e for e in self.expectations if not e[2].called],
- )
- ),
- )
-
- yield deferred
-
- timer.cancel()
-
- self.calls = []
-
- def assert_had_no_calls(self):
- if self.calls:
- calls = self.calls
- self.calls = []
-
- raise AssertionError(
- "Expected not to received any calls, got:\n"
- + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
- )
-
-
async def create_room(hs, room_id: str, creator_id: str):
- """Creates and persist a creation event for the given room
- """
+ """Creates and persist a creation event for the given room"""
persistence_store = hs.get_storage().persistence
store = hs.get_datastore()
diff --git a/tox.ini b/tox.ini
index adb0374f32..5365939e10 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,12 +2,13 @@
envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort
[base]
-extras = test
deps =
python-subunit
junitxml
coverage
- coverage-enable-subprocess
+
+ # this is pinned since it's a bit of an obscure package.
+ coverage-enable-subprocess==1.0
# cyptography 2.2 requires setuptools >= 18.5
#
@@ -25,29 +26,57 @@ deps =
pip>=10 ; python_version >= '3.6'
pip>=10,<21.0 ; python_version < '3.6'
-setenv =
- PYTHONDONTWRITEBYTECODE = no_byte_code
- COVERAGE_PROCESS_START = {toxinidir}/.coveragerc
-
+# directories/files we run the linters on.
+# if you update this list, make sure to do the same in scripts-dev/lint.sh
+lint_targets =
+ setup.py
+ synapse
+ tests
+ scripts
+ scripts-dev
+ stubs
+ contrib
+ synctl
+ synmark
+ .buildkite
+ docker
+
+# default settings for all tox environments
[testenv]
deps =
{[base]deps}
-extras = all, test
-
-whitelist_externals =
- sh
+extras =
+    # install the optional dependencies for tox environments without
+ # '-noextras' in their name
+ !noextras: all
+ test
setenv =
- {[base]setenv}
+ # use a postgres db for tox environments with "-postgres" in the name
+ # (see https://tox.readthedocs.io/en/3.20.1/config.html#factors-and-factor-conditional-settings)
postgres: SYNAPSE_POSTGRES = 1
+
+ # this is used by .coveragerc to refer to the top of our tree.
TOP={toxinidir}
passenv = *
commands =
- /usr/bin/find "{toxinidir}" -name '*.pyc' -delete
- # Add this so that coverage will run on subprocesses
- {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
+ # the "env" invocation enables coverage checking for sub-processes. This is
+ # particularly important when running trial with `-j`, since that will make
+ # it run tests in a subprocess, whose coverage would otherwise not be
+ # tracked. (It also makes an explicit `coverage run` command redundant.)
+ #
+ # (See https://coverage.readthedocs.io/en/coverage-5.3/subprocess.html.
+ # Note that the `coverage.process_startup()` call is done by
+ # `coverage-enable-subprocess`.)
+ #
+ # we use "env" rather than putting a value in `setenv` so that it is not
+ # inherited by other tox environments.
+ #
+ # keep this in sync with the copy in `testenv:py3-old`.
+ #
+ /usr/bin/env COVERAGE_PROCESS_START={toxinidir}/.coveragerc "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
# As of twisted 16.4, trial tries to import the tests as a package (previously
# it loaded the files explicitly), which means they need to be on the
@@ -74,8 +103,9 @@ usedevelop=true
# A test suite for the oldest supported versions of Python libraries, to catch
# any uses of APIs not available in them.
-[testenv:py35-old]
-skip_install=True
+[testenv:py3-old]
+skip_install = true
+usedevelop = false
deps =
# Old automat version for Twisted
Automat == 0.3.0
@@ -83,16 +113,19 @@ deps =
{[base]deps}
commands =
- /usr/bin/find "{toxinidir}" -name '*.pyc' -delete
# Make all greater-thans equals so we test the oldest version of our direct
# dependencies, but make the pyopenssl 17.0, which can work against an
# OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis).
- /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
+ /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "/psycopg2/d" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
# Install Synapse itself. This won't update any libraries.
pip install -e ".[test]"
- {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
+ # we have to duplicate the command from `testenv` rather than refer to it
+ # as `{[testenv]commands}`, because we run on ubuntu xenial, which has
+ # tox 2.3.1, and https://github.com/tox-dev/tox/issues/208.
+ #
+ /usr/bin/env COVERAGE_PROCESS_START={toxinidir}/.coveragerc "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
[testenv:benchmark]
deps =
@@ -113,13 +146,13 @@ commands =
[testenv:check_codestyle]
extras = lint
commands =
- python -m black --check --diff .
- /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
+ python -m black --check --diff {[base]lint_targets}
+ flake8 {[base]lint_targets} {env:PEP8SUFFIX:}
{toxinidir}/scripts-dev/config-lint.sh
[testenv:check_isort]
extras = lint
-commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts"
+commands = isort -c --df --sp setup.cfg {[base]lint_targets}
[testenv:check-newsfragment]
skip_install = True
@@ -157,12 +190,5 @@ commands=
[testenv:mypy]
deps =
{[base]deps}
- # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513
- # TODO: Remove after merging in the fixes from mainline
- twisted==20.3.0
extras = all,mypy
commands = mypy
-
-# To find all folders that pass mypy you run:
-#
-# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print
|