diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
new file mode 100644
index 0000000000..ff05a7428b
--- /dev/null
+++ b/.buildkite/pipeline.yml
@@ -0,0 +1,310 @@
+env:
+ COVERALLS_REPO_TOKEN: wsJWOby6j0uCYFiCes3r0XauxO27mx8lD
+
+steps:
+ - command:
+ - "python -m pip install tox"
+ - "tox -e check_codestyle"
+ label: "\U0001F9F9 Check Style"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - command:
+ - "python -m pip install tox"
+ - "tox -e packaging"
+ label: "\U0001F9F9 packaging"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - command:
+ - "python -m pip install tox"
+ - "tox -e check_isort"
+ label: "\U0001F9F9 isort"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - command:
+ - "python -m pip install tox"
+ - "scripts-dev/check-newsfragment"
+ label: ":newspaper: Newsfile"
+ branches: "!master !develop !release-*"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ propagate-environment: true
+ mount-buildkite-agent: false
+
+ - command:
+ - "python -m pip install tox"
+ - "tox -e check-sampleconfig"
+ label: "\U0001F9F9 check-sample-config"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ mount-buildkite-agent: false
+
+ - command:
+ - "python -m pip install tox"
+ - "tox -e mypy"
+ label: ":mypy: mypy"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.5"
+ mount-buildkite-agent: false
+
+ - wait
+
+ - command:
+ - "apt-get update && apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev"
+ - "python3.5 -m pip install tox"
+ - "tox -e py35-old,combine"
+ label: ":python: 3.5 / SQLite / Old Deps"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ LANG: "C.UTF-8"
+ plugins:
+ - docker#v3.0.1:
+ image: "ubuntu:xenial" # We use xenial to get an old sqlite and python
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - command:
+ - "python -m pip install tox"
+ - "tox -e py35,combine"
+ label: ":python: 3.5 / SQLite"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.5"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - command:
+ - "python -m pip install tox"
+ - "tox -e py36,combine"
+ label: ":python: 3.6 / SQLite"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.6"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - command:
+ - "python -m pip install tox"
+ - "tox -e py37,combine"
+ label: ":python: 3.7 / SQLite"
+ env:
+ TRIAL_FLAGS: "-j 2"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.7"
+ workdir: "/src"
+ mount-buildkite-agent: false
+ propagate-environment: true
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.5 / :postgres: 9.5"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ command:
+ - "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,combine'"
+ plugins:
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - .buildkite/docker-compose.py35.pg95.yaml
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.7 / :postgres: 9.5"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ command:
+ - "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,combine'"
+ plugins:
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - .buildkite/docker-compose.py37.pg95.yaml
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: ":python: 3.7 / :postgres: 11"
+ agents:
+ queue: "medium"
+ env:
+ TRIAL_FLAGS: "-j 8"
+ command:
+ - "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,combine'"
+ plugins:
+ - docker-compose#v2.1.0:
+ run: testenv
+ config:
+ - .buildkite/docker-compose.py37.pg11.yaml
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3.5 / SQLite / Monolith"
+ agents:
+ queue: "medium"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /synapse_sytest.sh"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:py35"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: ["/bin/sh", "-e", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/coverage.xml" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ class: "error"
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3.5 / :postgres: 9.6 / Monolith"
+ agents:
+ queue: "xlarge"
+ env:
+ POSTGRES: "1"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash /synapse_sytest.sh"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic-py3"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: ["/bin/sh", "-e", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/coverage.xml" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ class: "error"
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - label: "SyTest - :python: 3 / :postgres: 9.6 / Workers"
+ agents:
+ queue: "medium"
+ env:
+ POSTGRES: "1"
+ WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
+ command:
+ - "bash .buildkite/merge_base_branch.sh"
+ - "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
+ - "bash /synapse_sytest.sh"
+ plugins:
+ - docker#v3.0.1:
+ image: "matrixdotorg/sytest-synapse:dinsic-py3"
+ propagate-environment: true
+ always-pull: true
+ workdir: "/src"
+ entrypoint: ["/bin/sh", "-e", "-c"]
+ mount-buildkite-agent: false
+ volumes: ["./logs:/logs"]
+ - artifacts#v1.2.0:
+ upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/coverage.xml" ]
+ - matrix-org/annotate:
+ path: "logs/annotate.md"
+ class: "error"
+ - matrix-org/coveralls#v1.0:
+ parallel: "true"
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 2
+ - exit_status: 2
+ limit: 2
+
+ - wait: ~
+ continue_on_failure: true
+
+ - label: Trigger webhook
+ command: "curl -k https://coveralls.io/webhook?repo_token=$COVERALLS_REPO_TOKEN -d \"payload[build_num]=$BUILDKITE_BUILD_NUMBER&payload[status]=done\""
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 4b01b6ac8c..253a0ca648 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -60,7 +60,7 @@ python 3.6 and to install each tool:
```
# Install the dependencies
-pip install -U black flake8 isort
+pip install -U black flake8 flake8-comprehensions isort
# Run the linter script
./scripts-dev/lint.sh
diff --git a/INSTALL.md b/INSTALL.md
index aa5eb882bb..ffb82bdcc3 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -418,7 +418,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
for having Synapse automatically provision and renew federation
certificates through ACME can be found at [ACME.md](docs/ACME.md).
Note that, as pointed out in that document, this feature will not
- work with installs set up after November 2020.
+ work with installs set up after November 2019.
If you are using your own certificate, be sure to use a `.pem` file that
includes the full certificate chain including any intermediate certificates
diff --git a/MANIFEST.in b/MANIFEST.in
index 156d6f04f7..5eb8e62d34 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,5 @@
include synctl
+include sytest-blacklist
include LICENSE
include VERSION
include *.rst
@@ -50,3 +51,11 @@ prune docker
prune mypy.ini
prune snap
prune stubs
+
+exclude jenkins*
+recursive-exclude jenkins *.sh
+
+# FIXME: we shouldn't have these templates here
+recursive-include res/templates-dinsic *.css
+recursive-include res/templates-dinsic *.html
+recursive-include res/templates-dinsic *.txt
diff --git a/changelog.d/1.feature b/changelog.d/1.feature
new file mode 100644
index 0000000000..845642e445
--- /dev/null
+++ b/changelog.d/1.feature
@@ -0,0 +1 @@
+Forbid changing the name, avatar or topic of a direct room.
diff --git a/changelog.d/10.bugfix b/changelog.d/10.bugfix
new file mode 100644
index 0000000000..51f89f46dd
--- /dev/null
+++ b/changelog.d/10.bugfix
@@ -0,0 +1 @@
+Don't apply retention policy based filtering on state events.
diff --git a/changelog.d/11.feature b/changelog.d/11.feature
new file mode 100644
index 0000000000..362e4b1efd
--- /dev/null
+++ b/changelog.d/11.feature
@@ -0,0 +1 @@
+Allow server admins to configure a custom global rate-limiting for third party invites.
\ No newline at end of file
diff --git a/changelog.d/12.feature b/changelog.d/12.feature
new file mode 100644
index 0000000000..8e6e7a28af
--- /dev/null
+++ b/changelog.d/12.feature
@@ -0,0 +1 @@
+Add `/user/:user_id/info` CS servlet and to give user deactivated/expired information.
\ No newline at end of file
diff --git a/changelog.d/13.feature b/changelog.d/13.feature
new file mode 100644
index 0000000000..c2d2e93abf
--- /dev/null
+++ b/changelog.d/13.feature
@@ -0,0 +1 @@
+Hide expired users from the user directory, and optionally re-add them on renewal.
\ No newline at end of file
diff --git a/changelog.d/14.feature b/changelog.d/14.feature
new file mode 100644
index 0000000000..020d0bac1e
--- /dev/null
+++ b/changelog.d/14.feature
@@ -0,0 +1 @@
+User displaynames now have capitalised letters after - symbols.
\ No newline at end of file
diff --git a/changelog.d/15.misc b/changelog.d/15.misc
new file mode 100644
index 0000000000..4cc4a5175f
--- /dev/null
+++ b/changelog.d/15.misc
@@ -0,0 +1 @@
+Fix the ordering on `scripts/generate_signing_key.py`'s import statement.
diff --git a/changelog.d/17.misc b/changelog.d/17.misc
new file mode 100644
index 0000000000..58120ab5c7
--- /dev/null
+++ b/changelog.d/17.misc
@@ -0,0 +1 @@
+Blacklist some flaky sytests until they're fixed.
\ No newline at end of file
diff --git a/changelog.d/18.feature b/changelog.d/18.feature
new file mode 100644
index 0000000000..f5aa29a6e8
--- /dev/null
+++ b/changelog.d/18.feature
@@ -0,0 +1 @@
+Add option `limit_profile_requests_to_known_users` to prevent requirement of a user sharing a room with another user to query their profile information.
\ No newline at end of file
diff --git a/changelog.d/19.feature b/changelog.d/19.feature
new file mode 100644
index 0000000000..95a44a4a89
--- /dev/null
+++ b/changelog.d/19.feature
@@ -0,0 +1 @@
+Add `max_avatar_size` and `allowed_avatar_mimetypes` to restrict the size of user avatars and their file type respectively.
\ No newline at end of file
diff --git a/changelog.d/2.bugfix b/changelog.d/2.bugfix
new file mode 100644
index 0000000000..4fe5691468
--- /dev/null
+++ b/changelog.d/2.bugfix
@@ -0,0 +1 @@
+Don't treat 3PID revocation as a new 3PID invite.
diff --git a/changelog.d/20.bugfix b/changelog.d/20.bugfix
new file mode 100644
index 0000000000..8ba53c28f9
--- /dev/null
+++ b/changelog.d/20.bugfix
@@ -0,0 +1 @@
+Validate `client_secret` parameter against the regex provided by the C-S spec.
\ No newline at end of file
diff --git a/changelog.d/21.bugfix b/changelog.d/21.bugfix
new file mode 100644
index 0000000000..630d7812f7
--- /dev/null
+++ b/changelog.d/21.bugfix
@@ -0,0 +1 @@
+Fix resetting user passwords via a phone number.
diff --git a/changelog.d/3.bugfix b/changelog.d/3.bugfix
new file mode 100644
index 0000000000..cc4bcefa80
--- /dev/null
+++ b/changelog.d/3.bugfix
@@ -0,0 +1 @@
+Fix encoding on password reset HTML responses in Python 2.
diff --git a/changelog.d/4.bugfix b/changelog.d/4.bugfix
new file mode 100644
index 0000000000..fe717920a6
--- /dev/null
+++ b/changelog.d/4.bugfix
@@ -0,0 +1 @@
+Fix handling of filtered strings in Python 3.
diff --git a/changelog.d/5.bugfix b/changelog.d/5.bugfix
new file mode 100644
index 0000000000..53f57f46ca
--- /dev/null
+++ b/changelog.d/5.bugfix
@@ -0,0 +1 @@
+Fix room retention policy management in worker mode.
diff --git a/changelog.d/5083.feature b/changelog.d/5083.feature
new file mode 100644
index 0000000000..2ffdd37eef
--- /dev/null
+++ b/changelog.d/5083.feature
@@ -0,0 +1 @@
+Adds auth_profile_reqs option to require access_token to GET /profile endpoints on CS API.
diff --git a/changelog.d/5098.misc b/changelog.d/5098.misc
new file mode 100644
index 0000000000..9cd83bf226
--- /dev/null
+++ b/changelog.d/5098.misc
@@ -0,0 +1 @@
+Add workarounds for pep-517 install errors.
diff --git a/changelog.d/5214.feature b/changelog.d/5214.feature
new file mode 100644
index 0000000000..6c0f15c901
--- /dev/null
+++ b/changelog.d/5214.feature
@@ -0,0 +1 @@
+Allow server admins to define and enforce a password policy (MSC2000).
diff --git a/changelog.d/5416.misc b/changelog.d/5416.misc
new file mode 100644
index 0000000000..155e8c7cd3
--- /dev/null
+++ b/changelog.d/5416.misc
@@ -0,0 +1 @@
+Add unique index to the profile_replication_status table.
diff --git a/changelog.d/5420.feature b/changelog.d/5420.feature
new file mode 100644
index 0000000000..745864b903
--- /dev/null
+++ b/changelog.d/5420.feature
@@ -0,0 +1 @@
+Add configuration option to hide new users from the user directory.
diff --git a/changelog.d/5610.feature b/changelog.d/5610.feature
new file mode 100644
index 0000000000..b99514f97e
--- /dev/null
+++ b/changelog.d/5610.feature
@@ -0,0 +1 @@
+Implement new custom event rules for power levels.
diff --git a/changelog.d/5702.bugfix b/changelog.d/5702.bugfix
new file mode 100644
index 0000000000..43b6e39b13
--- /dev/null
+++ b/changelog.d/5702.bugfix
@@ -0,0 +1 @@
+Fix 3PID invite to invite association detection in the Tchap room access rules.
diff --git a/changelog.d/5760.feature b/changelog.d/5760.feature
new file mode 100644
index 0000000000..90302d793e
--- /dev/null
+++ b/changelog.d/5760.feature
@@ -0,0 +1 @@
+Force the access rule to be "restricted" if the join rule is "public".
diff --git a/changelog.d/6.bugfix b/changelog.d/6.bugfix
new file mode 100644
index 0000000000..43ab65cc95
--- /dev/null
+++ b/changelog.d/6.bugfix
@@ -0,0 +1 @@
+Don't forbid membership events which membership isn't 'join' or 'invite' in restricted rooms, so that users who got into these rooms before the access rules started to be enforced can leave them.
diff --git a/changelog.d/6315.feature b/changelog.d/6315.feature
new file mode 100644
index 0000000000..c5377dd1e9
--- /dev/null
+++ b/changelog.d/6315.feature
@@ -0,0 +1 @@
+Expose the `synctl`, `hash_password` and `generate_config` commands in the snapcraft package. Contributed by @devec0.
diff --git a/changelog.d/6572.bugfix b/changelog.d/6572.bugfix
new file mode 100644
index 0000000000..4f708f409f
--- /dev/null
+++ b/changelog.d/6572.bugfix
@@ -0,0 +1 @@
+When a user's profile is updated via the admin API, also generate a displayname/avatar update for that user in each room.
diff --git a/changelog.d/6615.misc b/changelog.d/6615.misc
new file mode 100644
index 0000000000..9f93152565
--- /dev/null
+++ b/changelog.d/6615.misc
@@ -0,0 +1 @@
+Add some clarifications to `README.md` in the database schema directory.
diff --git a/changelog.d/6941.removal b/changelog.d/6941.removal
new file mode 100644
index 0000000000..8573be84b3
--- /dev/null
+++ b/changelog.d/6941.removal
@@ -0,0 +1 @@
+Stop sending m.room.aliases events during room creation and upgrade.
diff --git a/changelog.d/6952.misc b/changelog.d/6952.misc
new file mode 100644
index 0000000000..e26dc5cab8
--- /dev/null
+++ b/changelog.d/6952.misc
@@ -0,0 +1 @@
+Improve perf of v2 state res for large rooms.
diff --git a/changelog.d/6953.misc b/changelog.d/6953.misc
new file mode 100644
index 0000000000..0ab52041cf
--- /dev/null
+++ b/changelog.d/6953.misc
@@ -0,0 +1 @@
+Reduce time spent doing GC by freezing objects on startup.
diff --git a/changelog.d/6954.misc b/changelog.d/6954.misc
new file mode 100644
index 0000000000..8b84ce2f19
--- /dev/null
+++ b/changelog.d/6954.misc
@@ -0,0 +1 @@
+Minor perf fixes to `get_auth_chain_ids`.
diff --git a/changelog.d/6956.misc b/changelog.d/6956.misc
new file mode 100644
index 0000000000..5cb0894182
--- /dev/null
+++ b/changelog.d/6956.misc
@@ -0,0 +1 @@
+Don't record remote cross-signing keys in the `devices` table.
diff --git a/changelog.d/6957.misc b/changelog.d/6957.misc
new file mode 100644
index 0000000000..4f98030110
--- /dev/null
+++ b/changelog.d/6957.misc
@@ -0,0 +1 @@
+Use flake8-comprehensions to enforce good hygiene of list/set/dict comprehensions.
diff --git a/changelog.d/6962.bugfix b/changelog.d/6962.bugfix
new file mode 100644
index 0000000000..9f5229d400
--- /dev/null
+++ b/changelog.d/6962.bugfix
@@ -0,0 +1 @@
+Fix a couple of bugs in email configuration handling.
diff --git a/changelog.d/6964.misc b/changelog.d/6964.misc
new file mode 100644
index 0000000000..ec5c004bbe
--- /dev/null
+++ b/changelog.d/6964.misc
@@ -0,0 +1 @@
+Merge worker apps together.
diff --git a/changelog.d/6965.feature b/changelog.d/6965.feature
new file mode 100644
index 0000000000..6ad9956e40
--- /dev/null
+++ b/changelog.d/6965.feature
@@ -0,0 +1 @@
+Publishing/removing a room from the room directory now requires the user to have a power level capable of modifying the canonical alias, instead of the room aliases.
diff --git a/changelog.d/6966.removal b/changelog.d/6966.removal
new file mode 100644
index 0000000000..69673d9139
--- /dev/null
+++ b/changelog.d/6966.removal
@@ -0,0 +1 @@
+Synapse no longer uses room alias events to calculate room names for email notifications.
diff --git a/changelog.d/6967.bugfix b/changelog.d/6967.bugfix
new file mode 100644
index 0000000000..b65f80cf1d
--- /dev/null
+++ b/changelog.d/6967.bugfix
@@ -0,0 +1 @@
+Fix an issue affecting worker-based deployments where replication would stop working, necessitating a full restart, after joining a large room.
diff --git a/changelog.d/6968.bugfix b/changelog.d/6968.bugfix
new file mode 100644
index 0000000000..9965bfc0c3
--- /dev/null
+++ b/changelog.d/6968.bugfix
@@ -0,0 +1 @@
+Fix `duplicate key` error which was logged when rejoining a room over federation.
diff --git a/changelog.d/6970.removal b/changelog.d/6970.removal
new file mode 100644
index 0000000000..89bd363b95
--- /dev/null
+++ b/changelog.d/6970.removal
@@ -0,0 +1 @@
+The room list endpoint no longer returns a list of aliases.
diff --git a/changelog.d/6979.misc b/changelog.d/6979.misc
new file mode 100644
index 0000000000..c57b398c2f
--- /dev/null
+++ b/changelog.d/6979.misc
@@ -0,0 +1 @@
+Remove redundant `store_room` call from `FederationHandler._process_received_pdu`.
diff --git a/changelog.d/6982.feature b/changelog.d/6982.feature
new file mode 100644
index 0000000000..934cc5141a
--- /dev/null
+++ b/changelog.d/6982.feature
@@ -0,0 +1 @@
+Check that server_name is correctly set before running database updates.
diff --git a/changelog.d/6983.misc b/changelog.d/6983.misc
new file mode 100644
index 0000000000..08aa80bcd9
--- /dev/null
+++ b/changelog.d/6983.misc
@@ -0,0 +1 @@
+Refactoring work in preparation for changing the event redaction algorithm.
diff --git a/changelog.d/6984.docker b/changelog.d/6984.docker
new file mode 100644
index 0000000000..84a55e1267
--- /dev/null
+++ b/changelog.d/6984.docker
@@ -0,0 +1 @@
+Fix `POSTGRES_INITDB_ARGS` in the `contrib/docker/docker-compose.yml` example docker-compose configuration.
diff --git a/changelog.d/6985.misc b/changelog.d/6985.misc
new file mode 100644
index 0000000000..ba367fa9af
--- /dev/null
+++ b/changelog.d/6985.misc
@@ -0,0 +1 @@
+Update warning for incorrect database collation/ctype to include link to documentation.
diff --git a/changelog.d/6987.misc b/changelog.d/6987.misc
new file mode 100644
index 0000000000..7ff74cda55
--- /dev/null
+++ b/changelog.d/6987.misc
@@ -0,0 +1 @@
+Add some type annotations to the database storage classes.
diff --git a/changelog.d/6990.bugfix b/changelog.d/6990.bugfix
new file mode 100644
index 0000000000..8c1c48f4d4
--- /dev/null
+++ b/changelog.d/6990.bugfix
@@ -0,0 +1 @@
+Prevent user from setting 'deactivated' to anything other than a bool on the v2 PUT /users Admin API.
\ No newline at end of file
diff --git a/changelog.d/6991.misc b/changelog.d/6991.misc
new file mode 100644
index 0000000000..5130f4e8af
--- /dev/null
+++ b/changelog.d/6991.misc
@@ -0,0 +1 @@
+Port `synapse.handlers.presence` to async/await.
diff --git a/changelog.d/6995.misc b/changelog.d/6995.misc
new file mode 100644
index 0000000000..884b4cf4ee
--- /dev/null
+++ b/changelog.d/6995.misc
@@ -0,0 +1 @@
+Add some type annotations to the federation base & client classes.
diff --git a/changelog.d/7002.misc b/changelog.d/7002.misc
new file mode 100644
index 0000000000..ec5c004bbe
--- /dev/null
+++ b/changelog.d/7002.misc
@@ -0,0 +1 @@
+Merge worker apps together.
diff --git a/changelog.d/7003.misc b/changelog.d/7003.misc
new file mode 100644
index 0000000000..08aa80bcd9
--- /dev/null
+++ b/changelog.d/7003.misc
@@ -0,0 +1 @@
+Refactoring work in preparation for changing the event redaction algorithm.
diff --git a/changelog.d/7015.misc b/changelog.d/7015.misc
new file mode 100644
index 0000000000..9709dc606e
--- /dev/null
+++ b/changelog.d/7015.misc
@@ -0,0 +1 @@
+Change date in INSTALL.md#tls-certificates for last date of getting TLS certificates to November 2019.
\ No newline at end of file
diff --git a/changelog.d/9.misc b/changelog.d/9.misc
new file mode 100644
index 0000000000..24fd12c978
--- /dev/null
+++ b/changelog.d/9.misc
@@ -0,0 +1 @@
+Add SyTest to the BuildKite CI.
diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml
index 5df29379c8..17354b6610 100644
--- a/contrib/docker/docker-compose.yml
+++ b/contrib/docker/docker-compose.yml
@@ -15,10 +15,9 @@ services:
restart: unless-stopped
# See the readme for a full documentation of the environment settings
environment:
- - SYNAPSE_CONFIG_PATH=/etc/homeserver.yaml
+ - SYNAPSE_CONFIG_PATH=/data/homeserver.yaml
volumes:
# You may either store all the files in a local folder
- - ./matrix-config/homeserver.yaml:/etc/homeserver.yaml
- ./files:/data
# .. or you may split this between different storage points
# - ./files:/data
@@ -58,7 +57,7 @@ services:
- POSTGRES_PASSWORD=changeme
# ensure the database gets created correctly
# https://github.com/matrix-org/synapse/blob/master/docs/postgres.md#set-up-database
- - POSTGRES_INITDB_ARGS="--encoding=UTF-8 --lc-collate=C --lc-ctype=C"
+ - POSTGRES_INITDB_ARGS=--encoding=UTF-8 --lc-collate=C --lc-ctype=C
volumes:
# You may store the database tables in a local folder..
- ./schemas:/var/lib/postgresql/data
diff --git a/contrib/systemd/README.md b/contrib/systemd/README.md
deleted file mode 100644
index 5d42b3464f..0000000000
--- a/contrib/systemd/README.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# Setup Synapse with Systemd
-This is a setup for managing synapse with a user contributed systemd unit
-file. It provides a `matrix-synapse` systemd unit file that should be tailored
-to accommodate your installation in accordance with the installation
-instructions provided in [installation instructions](../../INSTALL.md).
-
-## Setup
-1. Under the service section, ensure the `User` variable matches which user
-you installed synapse under and wish to run it as.
-2. Under the service section, ensure the `WorkingDirectory` variable matches
-where you have installed synapse.
-3. Under the service section, ensure the `ExecStart` variable matches the
-appropriate locations of your installation.
-4. Copy the `matrix-synapse.service` to `/etc/systemd/system/`
-5. Start Synapse: `sudo systemctl start matrix-synapse`
-6. Verify Synapse is running: `sudo systemctl status matrix-synapse`
-7. *optional* Enable Synapse to start at system boot: `sudo systemctl enable matrix-synapse`
diff --git a/docs/code_style.md b/docs/code_style.md
index 71aecd41f7..6ef6f80290 100644
--- a/docs/code_style.md
+++ b/docs/code_style.md
@@ -30,7 +30,7 @@ The necessary tools are detailed below.
Install `flake8` with:
- pip install --upgrade flake8
+ pip install --upgrade flake8 flake8-comprehensions
Check all application and test code with:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 8a036071e1..d839ca6ffc 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -333,6 +333,74 @@ listeners:
#
#allow_per_room_profiles: false
+# Whether to show the users on this homeserver in the user directory. Defaults to
+# 'true'.
+#
+#show_users_in_user_directory: false
+
+# Message retention policy at the server level.
+#
+# Room admins and mods can define a retention period for their rooms using the
+# 'm.room.retention' state event, and server admins can cap this period by setting
+# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+#
+# If this feature is enabled, Synapse will regularly look for and purge events
+# which are older than the room's maximum retention period. Synapse will also
+# filter events received over federation so that events that should have been
+# purged are ignored and not stored again.
+#
+retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but not want
+ # that purge to be performed by a job that's iterating over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
@@ -619,6 +687,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+# - one that ratelimits third-party invites requests based on the account
+# that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -644,6 +714,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# per_second: 0.17
# burst_count: 3
#
+#rc_third_party_invite:
+# per_second: 0.2
+# burst_count: 10
+#
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
@@ -711,6 +785,30 @@ media_store_path: "DATADIR/media_store"
#
#max_upload_size: 10M
+# The largest allowed size for a user avatar. If not defined, no
+# restriction will be imposed.
+#
+# Note that this only applies when an avatar is changed globally.
+# Per-room avatar changes are not affected. See allow_per_room_profiles
+# for disabling that functionality.
+#
+# Note that user avatar changes will not work if this is set without
+# using Synapse's local media repo.
+#
+#max_avatar_size: 10M
+
+# Allow mimetypes for a user avatar. If not defined, no restriction will
+# be imposed.
+#
+# Note that this only applies when an avatar is changed globally.
+# Per-room avatar changes are not affected. See allow_per_room_profiles
+# for disabling that functionality.
+#
+# Note that user avatar changes will not work if this is set without
+# using Synapse's local media repo.
+#
+#allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
@@ -971,9 +1069,32 @@ account_validity:
#
#disable_msisdn_registration: true
+# Derive the user's matrix ID from a type of 3PID used when registering.
+# This overrides any matrix ID the user proposes when calling /register
+# The 3PID type should be present in registrations_require_3pid to avoid
+# users failing to register if they don't specify the right kind of 3pid.
+#
+#register_mxid_from_3pid: email
+
+# Uncomment to set the display name of new users to their email address,
+# rather than using the default heuristic.
+#
+#register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+# Use an Identity Server to establish which 3PIDs are allowed to register?
+# Overrides allowed_local_3pids below.
+#
+#check_is_for_allowed_local_3pids: matrix.org
+#
+# If you are using an IS you can also check whether that IS registers
+# pending invites for the given 3PID (and then allow it to sign up on
+# the platform):
+#
+#allow_invited_3pids: False
+#
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\.org'
@@ -982,6 +1103,11 @@ account_validity:
# - medium: msisdn
# pattern: '\+44'
+# If true, stop users from trying to change the 3PIDs associated with
+# their accounts.
+#
+#disable_3pid_changes: False
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -1031,6 +1157,30 @@ account_validity:
# - matrix.org
# - vector.im
+# If enabled, user IDs, display names and avatar URLs will be replicated
+# to this server whenever they change.
+# This is an experimental API currently implemented by sydent to support
+# cross-homeserver user directories.
+#
+#replicate_user_profiles_to: example.com
+
+# If specified, attempt to replay registrations, profile changes & 3pid
+# bindings on the given target homeserver via the AS API. The HS is authed
+# via a given AS token.
+#
+#shadow_server:
+# hs_url: https://shadow.example.com
+# hs: shadow.example.com
+# as_token: 12u394refgbdhivsia
+
+# If enabled, don't let users set their own display names/avatars
+# other than for the very first time (unless they are a server admin).
+# Useful when provisioning users based on the contents of a 3rd party
+# directory and to avoid ambiguities.
+#
+#disable_set_displayname: False
+#disable_set_avatar_url: False
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -1384,6 +1534,36 @@ password_config:
#
#pepper: "EVEN_MORE_SECRET"
+ # Define and enforce a password policy. Each parameter is optional, boolean
+ # parameters default to 'false' and integer parameters default to 0.
+ # This is an early implementation of MSC2000.
+ #
+ #policy:
+ # Whether to enforce the password policy.
+ #
+ #enabled: true
+
+ # Minimum accepted length for a password.
+ #
+ #minimum_length: 15
+
+ # Whether a password must contain at least one digit.
+ #
+ #require_digit: true
+
+ # Whether a password must contain at least one symbol.
+ # A symbol is any character that's not a number or a letter.
+ #
+ #require_symbol: true
+
+ # Whether a password must contain at least one lowercase letter.
+ #
+ #require_lowercase: true
+
+ # Whether a password must contain at least one lowercase letter.
+ #
+ #require_uppercase: true
+
# Configuration for sending emails from Synapse.
#
@@ -1409,10 +1589,6 @@ email:
#
#require_transport_security: true
- # Enable sending emails for messages that the user has missed
- #
- #enable_notifs: false
-
# notif_from defines the "From" address to use when sending emails.
# It must be set if email sending is enabled.
#
@@ -1430,6 +1606,11 @@ email:
#
#app_name: my_branded_matrix_server
+ # Uncomment the following to enable sending emails for messages that the user
+ # has missed. Disabled by default.
+ #
+ #enable_notifs: true
+
# Uncomment the following to disable automatic subscription to email
# notifications for new users. Enabled by default.
#
@@ -1556,6 +1737,11 @@ email:
#user_directory:
# enabled: true
# search_all_users: false
+#
+# # If this is set, user search will be delegated to this ID server instead
+# # of synapse performing the search itself.
+# # This is an experimental API.
+# defer_to_id_server: https://id.example.com
# User Consent configuration
diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py
index ca4b879526..5c5a115ca9 100644
--- a/docs/sphinx/conf.py
+++ b/docs/sphinx/conf.py
@@ -12,8 +12,8 @@
# All configuration values have a default; values that are commented out
# serve to show the default.
-import sys
import os
+import sys
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
@@ -191,11 +191,11 @@ htmlhelp_basename = "Synapsedoc"
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
- #'papersize': 'letterpaper',
+ # 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
- #'pointsize': '10pt',
+ # 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
- #'preamble': '',
+ # 'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
diff --git a/res/templates-dinsic/mail-Vector.css b/res/templates-dinsic/mail-Vector.css
new file mode 100644
index 0000000000..6a3e36eda1
--- /dev/null
+++ b/res/templates-dinsic/mail-Vector.css
@@ -0,0 +1,7 @@
+.header {
+ border-bottom: 4px solid #e4f7ed ! important;
+}
+
+.notif_link a, .footer a {
+ color: #76CFA6 ! important;
+}
diff --git a/res/templates-dinsic/mail.css b/res/templates-dinsic/mail.css
new file mode 100644
index 0000000000..5ab3e1b06d
--- /dev/null
+++ b/res/templates-dinsic/mail.css
@@ -0,0 +1,156 @@
+body {
+ margin: 0px;
+}
+
+pre, code {
+ word-break: break-word;
+ white-space: pre-wrap;
+}
+
+#page {
+ font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
+ font-color: #454545;
+ font-size: 12pt;
+ width: 100%;
+ padding: 20px;
+}
+
+#inner {
+ width: 640px;
+}
+
+.header {
+ width: 100%;
+ height: 87px;
+ color: #454545;
+ border-bottom: 4px solid #e5e5e5;
+}
+
+.logo {
+ text-align: right;
+ margin-left: 20px;
+}
+
+.salutation {
+ padding-top: 10px;
+ font-weight: bold;
+}
+
+.summarytext {
+}
+
+.room {
+ width: 100%;
+ color: #454545;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.room_header td {
+ padding-top: 38px;
+ padding-bottom: 10px;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.room_name {
+ vertical-align: middle;
+ font-size: 18px;
+ font-weight: bold;
+}
+
+.room_header h2 {
+ margin-top: 0px;
+ margin-left: 75px;
+ font-size: 20px;
+}
+
+.room_avatar {
+ width: 56px;
+ line-height: 0px;
+ text-align: center;
+ vertical-align: middle;
+}
+
+.room_avatar img {
+ width: 48px;
+ height: 48px;
+ object-fit: cover;
+ border-radius: 24px;
+}
+
+.notif {
+ border-bottom: 1px solid #e5e5e5;
+ margin-top: 16px;
+ padding-bottom: 16px;
+}
+
+.historical_message .sender_avatar {
+ opacity: 0.3;
+}
+
+/* spell out opacity and historical_message class names for Outlook aka Word */
+.historical_message .sender_name {
+ color: #e3e3e3;
+}
+
+.historical_message .message_time {
+ color: #e3e3e3;
+}
+
+.historical_message .message_body {
+ color: #c7c7c7;
+}
+
+.historical_message td,
+.message td {
+ padding-top: 10px;
+}
+
+.sender_avatar {
+ width: 56px;
+ text-align: center;
+ vertical-align: top;
+}
+
+.sender_avatar img {
+ margin-top: -2px;
+ width: 32px;
+ height: 32px;
+ border-radius: 16px;
+}
+
+.sender_name {
+ display: inline;
+ font-size: 13px;
+ color: #a2a2a2;
+}
+
+.message_time {
+ text-align: right;
+ width: 100px;
+ font-size: 11px;
+ color: #a2a2a2;
+}
+
+.message_body {
+}
+
+.notif_link td {
+ padding-top: 10px;
+ padding-bottom: 10px;
+ font-weight: bold;
+}
+
+.notif_link a, .footer a {
+ color: #454545;
+ text-decoration: none;
+}
+
+.debug {
+ font-size: 10px;
+ color: #888;
+}
+
+.footer {
+ margin-top: 20px;
+ text-align: center;
+}
\ No newline at end of file
diff --git a/res/templates-dinsic/notif.html b/res/templates-dinsic/notif.html
new file mode 100644
index 0000000000..bcdfeea9da
--- /dev/null
+++ b/res/templates-dinsic/notif.html
@@ -0,0 +1,45 @@
+{% for message in notif.messages %}
+ <tr class="{{ "historical_message" if message.is_historical else "message" }}">
+ <td class="sender_avatar">
+ {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ {% if message.sender_avatar_url %}
+ <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
+ {% else %}
+ {% if message.sender_hash % 3 == 0 %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
+ {% elif message.sender_hash % 3 == 1 %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
+ {% else %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
+ {% endif %}
+ {% endif %}
+ {% endif %}
+ </td>
+ <td class="message_contents">
+ {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
+ {% endif %}
+ <div class="message_body">
+ {% if message.msgtype == "m.text" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.emote" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.notice" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.image" %}
+ <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
+ {% elif message.msgtype == "m.file" %}
+ <span class="filename">{{ message.body_text_plain }}</span>
+ {% endif %}
+ </div>
+ </td>
+ <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
+ </tr>
+{% endfor %}
+<tr class="notif_link">
+ <td></td>
+ <td>
+ <a href="{{ notif.link }}">Voir {{ room.title }}</a>
+ </td>
+ <td></td>
+</tr>
diff --git a/res/templates-dinsic/notif.txt b/res/templates-dinsic/notif.txt
new file mode 100644
index 0000000000..3dff1bb570
--- /dev/null
+++ b/res/templates-dinsic/notif.txt
@@ -0,0 +1,16 @@
+{% for message in notif.messages %}
+{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
+{% if message.msgtype == "m.text" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.emote" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.notice" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.image" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.file" %}
+{{ message.body_text_plain }}
+{% endif %}
+{% endfor %}
+
+Voir {{ room.title }} à {{ notif.link }}
diff --git a/res/templates-dinsic/notif_mail.html b/res/templates-dinsic/notif_mail.html
new file mode 100644
index 0000000000..1e1efa74b2
--- /dev/null
+++ b/res/templates-dinsic/notif_mail.html
@@ -0,0 +1,55 @@
+<!doctype html>
+<html lang="en">
+ <head>
+ <style type="text/css">
+ {% include 'mail.css' without context %}
+ {% include "mail-%s.css" % app_name ignore missing without context %}
+ </style>
+ </head>
+ <body>
+ <table id="page">
+ <tr>
+ <td> </td>
+ <td id="inner">
+ <table class="header">
+ <tr>
+ <td>
+ <div class="salutation">Bonjour {{ user_display_name }},</div>
+ <div class="summarytext">{{ summary_text }}</div>
+ </td>
+ <td class="logo">
+ {% if app_name == "Riot" %}
+ <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
+ {% elif app_name == "Vector" %}
+ <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
+ {% else %}
+ <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
+ {% endif %}
+ </td>
+ </tr>
+ </table>
+ {% for room in rooms %}
+ {% include 'room.html' with context %}
+ {% endfor %}
+ <div class="footer">
+ <a href="{{ unsubscribe_link }}">Se désinscrire</a>
+ <br/>
+ <br/>
+ <div class="debug">
+ Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
+ an event was received at {{ reason.received_at|format_ts("%c") }}
+ which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
+ {% if reason.last_sent_ts %}
+ and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
+ which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
+ {% else %}
+ and we don't have a last time we sent a mail for this room.
+ {% endif %}
+ </div>
+ </div>
+ </td>
+ <td> </td>
+ </tr>
+ </table>
+ </body>
+</html>
diff --git a/res/templates-dinsic/notif_mail.txt b/res/templates-dinsic/notif_mail.txt
new file mode 100644
index 0000000000..fae877426f
--- /dev/null
+++ b/res/templates-dinsic/notif_mail.txt
@@ -0,0 +1,10 @@
+Bonjour {{ user_display_name }},
+
+{{ summary_text }}
+
+{% for room in rooms %}
+{% include 'room.txt' with context %}
+{% endfor %}
+
+Vous pouvez désactiver ces notifications en cliquant ici {{ unsubscribe_link }}
+
diff --git a/res/templates-dinsic/room.html b/res/templates-dinsic/room.html
new file mode 100644
index 0000000000..0487b1b11c
--- /dev/null
+++ b/res/templates-dinsic/room.html
@@ -0,0 +1,33 @@
+<table class="room">
+ <tr class="room_header">
+ <td class="room_avatar">
+ {% if room.avatar_url %}
+ <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
+ {% else %}
+ {% if room.hash % 3 == 0 %}
+ <img alt="" src="https://vector.im/beta/img/76cfa6.png" />
+ {% elif room.hash % 3 == 1 %}
+ <img alt="" src="https://vector.im/beta/img/50e2c2.png" />
+ {% else %}
+ <img alt="" src="https://vector.im/beta/img/f4c371.png" />
+ {% endif %}
+ {% endif %}
+ </td>
+ <td class="room_name" colspan="2">
+ {{ room.title }}
+ </td>
+ </tr>
+ {% if room.invite %}
+ <tr>
+ <td></td>
+ <td>
+ <a href="{{ room.link }}">Rejoindre la conversation.</a>
+ </td>
+ <td></td>
+ </tr>
+ {% else %}
+ {% for notif in room.notifs %}
+ {% include 'notif.html' with context %}
+ {% endfor %}
+ {% endif %}
+</table>
diff --git a/res/templates-dinsic/room.txt b/res/templates-dinsic/room.txt
new file mode 100644
index 0000000000..dd36d01d21
--- /dev/null
+++ b/res/templates-dinsic/room.txt
@@ -0,0 +1,9 @@
+{{ room.title }}
+
+{% if room.invite %}
+ Vous avez été invité, rejoignez la conversation en cliquant sur le lien suivant {{ room.link }}
+{% else %}
+ {% for notif in room.notifs %}
+ {% include 'notif.txt' with context %}
+ {% endfor %}
+{% endif %}
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index 0ec5075e79..b8a85abe18 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -5,9 +5,9 @@
set -e
-# make sure that origin/develop is up to date
-git remote set-branches --add origin develop
-git fetch origin develop
+# make sure that origin/dinsic is up to date
+git remote set-branches --add origin dinsic
+git fetch origin dinsic
# if there are changes in the debian directory, check that the debian changelog
# has been updated
diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py
index 179be61c30..06b4c1e2ff 100644
--- a/scripts-dev/convert_server_keys.py
+++ b/scripts-dev/convert_server_keys.py
@@ -103,7 +103,7 @@ def main():
yaml.safe_dump(result, sys.stdout, default_flow_style=False)
- rows = list(row for server, json in result.items() for row in rows_v2(server, json))
+ rows = [row for server, json in result.items() for row in rows_v2(server, json)]
cursor = connection.cursor()
cursor.executemany(
diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml
index 9e644e8567..6b62b79114 100644
--- a/snap/snapcraft.yaml
+++ b/snap/snapcraft.yaml
@@ -1,20 +1,31 @@
name: matrix-synapse
base: core18
-version: git
+version: git
summary: Reference Matrix homeserver
description: |
Synapse is the reference Matrix homeserver.
Matrix is a federated and decentralised instant messaging and VoIP system.
-grade: stable
-confinement: strict
+grade: stable
+confinement: strict
apps:
- matrix-synapse:
+ matrix-synapse:
command: synctl --no-daemonize start $SNAP_COMMON/homeserver.yaml
stop-command: synctl -c $SNAP_COMMON stop
plugs: [network-bind, network]
- daemon: simple
+ daemon: simple
+ hash-password:
+ command: hash_password
+ generate-config:
+ command: generate_config
+ generate-signing-key:
+ command: generate_signing_key.py
+ register-new-matrix-user:
+ command: register_new_matrix_user
+ plugs: [network]
+ synctl:
+ command: synctl
parts:
matrix-synapse:
source: .
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f576d65388..dc3a863010 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -186,6 +186,7 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
+
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
@@ -250,11 +251,11 @@ class Auth(object):
except KeyError:
raise MissingClientTokenError()
- @defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
+
if app_service is None:
return None, None
@@ -272,8 +273,12 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (yield self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
+ # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+ # if not (yield self.store.get_user_by_id(user_id)):
+ # raise AuthError(
+ # 403,
+ # "Application service has not registered this user"
+ # )
return user_id, app_service
@defer.inlineCallbacks
@@ -538,13 +543,13 @@ class Auth(object):
return defer.succeed(auth_ids)
@defer.inlineCallbacks
- def check_can_change_room_list(self, room_id, user):
+ def check_can_change_room_list(self, room_id: str, user: UserID):
"""Check if the user is allowed to edit the room's entry in the
published room list.
Args:
- room_id (str)
- user (UserID)
+ room_id
+ user
"""
is_admin = yield self.is_server_admin(user)
@@ -556,7 +561,7 @@ class Auth(object):
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
- # m.room.aliases events
+ # m.room.canonical_alias events
power_level_event = yield self.state.get_current_state(
room_id, EventTypes.PowerLevels, ""
)
@@ -566,7 +571,7 @@ class Auth(object):
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = event_auth.get_send_level(
- EventTypes.Aliases, "", power_level_event
+ EventTypes.CanonicalAlias, "", power_level_event
)
user_level = event_auth.get_user_power_level(user_id, auth_events)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index cc8577552b..42eff8793b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -85,6 +85,7 @@ class EventTypes(object):
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
+ Encryption = "m.room.encryption"
# These are used for validation
Message = "m.room.message"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 0c20601600..de81cb9663 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -66,6 +67,13 @@ class Codes(object):
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
+ PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
+ PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
+ PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
+ PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
+ PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
+ PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
+ WEAK_PASSWORD = "M_WEAK_PASSWORD"
class CodeMessageException(RuntimeError):
@@ -438,6 +446,18 @@ class IncompatibleRoomVersionError(SynapseError):
return cs_error(self.msg, self.errcode, room_version=self._room_version)
+class PasswordRefusedError(SynapseError):
+ """A password has been refused, either during password reset/change or registration.
+ """
+
+ def __init__(
+ self,
+ msg="This password doesn't comply with the server's policy",
+ errcode=Codes.WEAK_PASSWORD,
+ ):
+ super(PasswordRefusedError, self).__init__(code=400, msg=msg, errcode=errcode)
+
+
class RequestSendFailed(RuntimeError):
"""Sending a HTTP request over federation failed due to not being able to
talk to the remote server for some reason.
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 0e8b467a3e..9ffd23c6df 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -141,7 +141,7 @@ def start_reactor(
def quit_with_error(error_string):
message_lines = error_string.split("\n")
- line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
+ line_length = max(len(l) for l in message_lines if len(l) < 80) + 2
sys.stderr.write("*" * line_length + "\n")
for line in message_lines:
sys.stderr.write(" %s\n" % (line.rstrip(),))
@@ -279,6 +279,15 @@ def start(hs, listeners=None):
setup_sentry(hs)
setup_sdnotify(hs)
+
+ # We now freeze all allocated objects in the hopes that (almost)
+ # everything currently allocated are things that will be used for the
+ # rest of time. Doing so means less work each GC (hopefully).
+ #
+ # This only works on Python 3.7
+ if sys.version_info >= (3, 7):
+ gc.collect()
+ gc.freeze()
except Exception:
traceback.print_exc(file=sys.stderr)
reactor = hs.get_reactor()
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 2217d4a4fb..add43147b3 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -13,161 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext, run_in_background
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.server import HomeServer
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.appservice")
-
-
-class AppserviceSlaveStore(
- DirectoryStore,
- SlavedEventStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
-):
- pass
-
-
-class AppserviceServer(HomeServer):
- DATASTORE_CLASS = AppserviceSlaveStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- )
-
- logger.info("Synapse appservice now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ASReplicationHandler(self)
+import sys
-class ASReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(ASReplicationHandler, self).__init__(hs.get_datastore())
- self.appservice_handler = hs.get_application_service_handler()
-
- async def on_rdata(self, stream_name, token, rows):
- await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
-
- if stream_name == "events":
- max_stream_id = self.store.get_room_max_stream_ordering()
- run_in_background(self._notify_app_services, max_stream_id)
-
- @defer.inlineCallbacks
- def _notify_app_services(self, room_stream_id):
- try:
- yield self.appservice_handler.notify_interested_services(room_stream_id)
- except Exception:
- logger.exception("Error notifying application services of event")
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config("Synapse appservice", config_options)
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.appservice"
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- if config.notify_appservices:
- sys.stderr.write(
- "\nThe appservices must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``notify_appservices: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.notify_appservices = True
-
- ps = AppserviceServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ps, config, use_worker_options=True)
-
- ps.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ps, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-appservice", config)
-
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 7fa91a3b11..add43147b3 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -13,195 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.groups import SlavedGroupServerStore
-from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v1.login import LoginRestServlet
-from synapse.rest.client.v1.push_rule import PushRuleRestServlet
-from synapse.rest.client.v1.room import (
- JoinedRoomMemberListRestServlet,
- PublicRoomListRestServlet,
- RoomEventContextServlet,
- RoomMemberListRestServlet,
- RoomMessageListRestServlet,
- RoomStateRestServlet,
-)
-from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups
-from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
-from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
-from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-from synapse.rest.client.versions import VersionsRestServlet
-from synapse.server import HomeServer
-from synapse.storage.data_stores.main.monthly_active_users import (
- MonthlyActiveUsersWorkerStore,
-)
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.client_reader")
-
-
-class ClientReaderSlavedStore(
- SlavedDeviceInboxStore,
- SlavedDeviceStore,
- SlavedReceiptsStore,
- SlavedPushRuleStore,
- SlavedGroupServerStore,
- SlavedAccountDataStore,
- SlavedEventStore,
- SlavedKeyStore,
- RoomStore,
- DirectoryStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedTransactionStore,
- SlavedProfileStore,
- SlavedClientIpStore,
- MonthlyActiveUsersWorkerStore,
- BaseSlavedStore,
-):
- pass
-
-
-class ClientReaderServer(HomeServer):
- DATASTORE_CLASS = ClientReaderSlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
-
- PublicRoomListRestServlet(self).register(resource)
- RoomMemberListRestServlet(self).register(resource)
- JoinedRoomMemberListRestServlet(self).register(resource)
- RoomStateRestServlet(self).register(resource)
- RoomEventContextServlet(self).register(resource)
- RoomMessageListRestServlet(self).register(resource)
- RegisterRestServlet(self).register(resource)
- LoginRestServlet(self).register(resource)
- ThreepidRestServlet(self).register(resource)
- KeyQueryServlet(self).register(resource)
- KeyChangesServlet(self).register(resource)
- VoipRestServlet(self).register(resource)
- PushRuleRestServlet(self).register(resource)
- VersionsRestServlet(self).register(resource)
-
- groups.register_servlets(self, resource)
-
- resources.update({"/_matrix/client": resource})
-
- root_resource = create_resource_tree(resources, NoResource())
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- )
-
- logger.info("Synapse client reader now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config("Synapse client reader", config_options)
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.client_reader"
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- ss = ClientReaderServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ss, config, use_worker_options=True)
-
- ss.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-client-reader", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 58e5b354f6..e9c098c4e7 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -13,191 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v1.profile import (
- ProfileAvatarURLRestServlet,
- ProfileDisplaynameRestServlet,
- ProfileRestServlet,
-)
-from synapse.rest.client.v1.room import (
- JoinRoomAliasServlet,
- RoomMembershipRestServlet,
- RoomSendEventRestServlet,
- RoomStateEventRestServlet,
-)
-from synapse.server import HomeServer
-from synapse.storage.data_stores.main.monthly_active_users import (
- MonthlyActiveUsersWorkerStore,
-)
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.event_creator")
-
-
-class EventCreatorSlavedStore(
- # FIXME(#3714): We need to add UserDirectoryStore as we write directly
- # rather than going via the correct worker.
- UserDirectoryStore,
- DirectoryStore,
- SlavedTransactionStore,
- SlavedProfileStore,
- SlavedAccountDataStore,
- SlavedPusherStore,
- SlavedReceiptsStore,
- SlavedPushRuleStore,
- SlavedDeviceStore,
- SlavedClientIpStore,
- SlavedApplicationServiceStore,
- SlavedEventStore,
- SlavedRegistrationStore,
- RoomStore,
- MonthlyActiveUsersWorkerStore,
- BaseSlavedStore,
-):
- pass
-
-
-class EventCreatorServer(HomeServer):
- DATASTORE_CLASS = EventCreatorSlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
- RoomSendEventRestServlet(self).register(resource)
- RoomMembershipRestServlet(self).register(resource)
- RoomStateEventRestServlet(self).register(resource)
- JoinRoomAliasServlet(self).register(resource)
- ProfileAvatarURLRestServlet(self).register(resource)
- ProfileDisplaynameRestServlet(self).register(resource)
- ProfileRestServlet(self).register(resource)
- resources.update(
- {
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- }
- )
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- )
-
- logger.info("Synapse event creator now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config("Synapse event creator", config_options)
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.event_creator"
-
- assert config.worker_replication_http_port is not None
-
- # This should only be done on the user directory worker or the master
- config.update_user_directory = False
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- ss = EventCreatorServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ss, config, use_worker_options=True)
-
- ss.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-event-creator", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index d055d11b23..add43147b3 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -13,177 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.api.urls import FEDERATION_PREFIX, SERVER_KEY_V2_PREFIX
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.federation.transport.server import TransportLayerServer
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.groups import SlavedGroupServerStore
-from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.server import HomeServer
-from synapse.storage.data_stores.main.monthly_active_users import (
- MonthlyActiveUsersWorkerStore,
-)
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.federation_reader")
-
-
-class FederationReaderSlavedStore(
- SlavedAccountDataStore,
- SlavedProfileStore,
- SlavedApplicationServiceStore,
- SlavedPusherStore,
- SlavedPushRuleStore,
- SlavedReceiptsStore,
- SlavedEventStore,
- SlavedKeyStore,
- SlavedRegistrationStore,
- SlavedGroupServerStore,
- SlavedDeviceStore,
- RoomStore,
- DirectoryStore,
- SlavedTransactionStore,
- MonthlyActiveUsersWorkerStore,
- BaseSlavedStore,
-):
- pass
-
-
-class FederationReaderServer(HomeServer):
- DATASTORE_CLASS = FederationReaderSlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "federation":
- resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
- if name == "openid" and "federation" not in res["names"]:
- # Only load the openid resource separately if federation resource
- # is not specified since federation resource includes openid
- # resource.
- resources.update(
- {
- FEDERATION_PREFIX: TransportLayerServer(
- self, servlet_groups=["openid"]
- )
- }
- )
-
- if name in ["keys", "federation"]:
- resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- reactor=self.get_reactor(),
- )
- logger.info("Synapse federation reader now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse federation reader", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.federation_reader"
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- ss = FederationReaderServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ss, config, use_worker_options=True)
-
- ss.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-federation-reader", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 63a91f1177..add43147b3 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -13,308 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.federation import send_queue
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext, run_in_background
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams._base import (
- DeviceListsStream,
- ReceiptsStream,
- ToDeviceStream,
-)
-from synapse.server import HomeServer
-from synapse.storage.database import Database
-from synapse.types import ReadReceipt
-from synapse.util.async_helpers import Linearizer
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.federation_sender")
-
-
-class FederationSenderSlaveStore(
- SlavedDeviceInboxStore,
- SlavedTransactionStore,
- SlavedReceiptsStore,
- SlavedEventStore,
- SlavedRegistrationStore,
- SlavedDeviceStore,
- SlavedPresenceStore,
-):
- def __init__(self, database: Database, db_conn, hs):
- super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs)
-
- # We pull out the current federation stream position now so that we
- # always have a known value for the federation position in memory so
- # that we don't have to bounce via a deferred once when we start the
- # replication streams.
- self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
-
- def _get_federation_out_pos(self, db_conn):
- sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?"
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, ("federation",))
- rows = txn.fetchall()
- txn.close()
-
- return rows[0][0] if rows else -1
-
-
-class FederationSenderServer(HomeServer):
- DATASTORE_CLASS = FederationSenderSlaveStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- )
-
- logger.info("Synapse federation_sender now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return FederationSenderReplicationHandler(self)
-
-
-class FederationSenderReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
- self.send_handler = FederationSenderHandler(hs, self)
-
- async def on_rdata(self, stream_name, token, rows):
- await super(FederationSenderReplicationHandler, self).on_rdata(
- stream_name, token, rows
- )
- self.send_handler.process_replication_rows(stream_name, token, rows)
-
- def get_streams_to_replicate(self):
- args = super(
- FederationSenderReplicationHandler, self
- ).get_streams_to_replicate()
- args.update(self.send_handler.stream_positions())
- return args
-
- def on_remote_server_up(self, server: str):
- """Called when get a new REMOTE_SERVER_UP command."""
-
- # Let's wake up the transaction queue for the server in case we have
- # pending stuff to send to it.
- self.send_handler.wake_destination(server)
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse federation sender", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
- assert config.worker_app == "synapse.app.federation_sender"
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- if config.send_federation:
- sys.stderr.write(
- "\nThe send_federation must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``send_federation: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.send_federation = True
-
- ss = FederationSenderServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ss, config, use_worker_options=True)
-
- ss.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-federation-sender", config)
-
-
-class FederationSenderHandler(object):
- """Processes the replication stream and forwards the appropriate entries
- to the federation sender.
- """
-
- def __init__(self, hs: FederationSenderServer, replication_client):
- self.store = hs.get_datastore()
- self._is_mine_id = hs.is_mine_id
- self.federation_sender = hs.get_federation_sender()
- self.replication_client = replication_client
-
- self.federation_position = self.store.federation_out_pos_startup
- self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
-
- self._last_ack = self.federation_position
-
- self._room_serials = {}
- self._room_typing = {}
-
- def on_start(self):
- # There may be some events that are persisted but haven't been sent,
- # so send them now.
- self.federation_sender.notify_new_events(
- self.store.get_room_max_stream_ordering()
- )
-
- def wake_destination(self, server: str):
- self.federation_sender.wake_destination(server)
-
- def stream_positions(self):
- return {"federation": self.federation_position}
-
- def process_replication_rows(self, stream_name, token, rows):
- # The federation stream contains things that we want to send out, e.g.
- # presence, typing, etc.
- if stream_name == "federation":
- send_queue.process_rows_for_federation(self.federation_sender, rows)
- run_in_background(self.update_token, token)
-
- # We also need to poke the federation sender when new events happen
- elif stream_name == "events":
- self.federation_sender.notify_new_events(token)
-
- # ... and when new receipts happen
- elif stream_name == ReceiptsStream.NAME:
- run_as_background_process(
- "process_receipts_for_federation", self._on_new_receipts, rows
- )
-
- # ... as well as device updates and messages
- elif stream_name == DeviceListsStream.NAME:
- hosts = set(row.destination for row in rows)
- for host in hosts:
- self.federation_sender.send_device_messages(host)
-
- elif stream_name == ToDeviceStream.NAME:
- # The to_device stream includes stuff to be pushed to both local
- # clients and remote servers, so we ignore entities that start with
- # '@' (since they'll be local users rather than destinations).
- hosts = set(row.entity for row in rows if not row.entity.startswith("@"))
- for host in hosts:
- self.federation_sender.send_device_messages(host)
-
- @defer.inlineCallbacks
- def _on_new_receipts(self, rows):
- """
- Args:
- rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
- new receipts to be processed
- """
- for receipt in rows:
- # we only want to send on receipts for our own users
- if not self._is_mine_id(receipt.user_id):
- continue
- receipt_info = ReadReceipt(
- receipt.room_id,
- receipt.receipt_type,
- receipt.user_id,
- [receipt.event_id],
- receipt.data,
- )
- yield self.federation_sender.send_read_receipt(receipt_info)
-
- @defer.inlineCallbacks
- def update_token(self, token):
- try:
- self.federation_position = token
-
- # We linearize here to ensure we don't have races updating the token
- with (yield self._fed_position_linearizer.queue(None)):
- if self._last_ack < self.federation_position:
- yield self.store.update_federation_out_pos(
- "federation", self.federation_position
- )
-
- # We ACK this token over replication so that the master can drop
- # its in memory queues
- self.replication_client.send_federation_ack(
- self.federation_position
- )
- self._last_ack = self.federation_position
- except Exception:
- logger.exception("Error updating federation stream position")
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 30e435eead..add43147b3 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -13,241 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.api.errors import HttpResponseException, SynapseError
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v2_alpha._base import client_patterns
-from synapse.server import HomeServer
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.frontend_proxy")
-
-
-class PresenceStatusStubServlet(RestServlet):
- PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
-
- def __init__(self, hs):
- super(PresenceStatusStubServlet, self).__init__()
- self.http_client = hs.get_simple_http_client()
- self.auth = hs.get_auth()
- self.main_uri = hs.config.worker_main_http_uri
-
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- # Pass through the auth headers, if any, in case the access token
- # is there.
- auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
- headers = {"Authorization": auth_headers}
-
- try:
- result = yield self.http_client.get_json(
- self.main_uri + request.uri.decode("ascii"), headers=headers
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
-
- return 200, result
-
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- yield self.auth.get_user_by_req(request)
- return 200, {}
-
-
-class KeyUploadServlet(RestServlet):
- PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- super(KeyUploadServlet, self).__init__()
- self.auth = hs.get_auth()
- self.store = hs.get_datastore()
- self.http_client = hs.get_simple_http_client()
- self.main_uri = hs.config.worker_main_http_uri
-
- @defer.inlineCallbacks
- def on_POST(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- user_id = requester.user.to_string()
- body = parse_json_object_from_request(request)
-
- if device_id is not None:
- # passing the device_id here is deprecated; however, we allow it
- # for now for compatibility with older clients.
- if requester.device_id is not None and device_id != requester.device_id:
- logger.warning(
- "Client uploading keys for a different device "
- "(logged in as %s, uploading for %s)",
- requester.device_id,
- device_id,
- )
- else:
- device_id = requester.device_id
-
- if device_id is None:
- raise SynapseError(
- 400, "To upload keys, you must pass device_id when authenticating"
- )
-
- if body:
- # They're actually trying to upload something, proxy to main synapse.
- # Pass through the auth headers, if any, in case the access token
- # is there.
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
- headers = {"Authorization": auth_headers}
- result = yield self.http_client.post_json_get_json(
- self.main_uri + request.uri.decode("ascii"), body, headers=headers
- )
-
- return 200, result
- else:
- # Just interested in counts.
- result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- return 200, {"one_time_key_counts": result}
-
-
-class FrontendProxySlavedStore(
- SlavedDeviceStore,
- SlavedClientIpStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- BaseSlavedStore,
-):
- pass
+import sys
-class FrontendProxyServer(HomeServer):
- DATASTORE_CLASS = FrontendProxySlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
- KeyUploadServlet(self).register(resource)
-
- # If presence is disabled, use the stub servlet that does
- # not allow sending presence
- if not self.config.use_presence:
- PresenceStatusStubServlet(self).register(resource)
-
- resources.update(
- {
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- }
- )
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- reactor=self.get_reactor(),
- )
-
- logger.info("Synapse client reader now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config("Synapse frontend proxy", config_options)
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.frontend_proxy"
-
- assert config.worker_main_http_uri is not None
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- ss = FrontendProxyServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ss, config, use_worker_options=True)
-
- ss.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-frontend-proxy", config)
-
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
new file mode 100644
index 0000000000..b2c764bfe8
--- /dev/null
+++ b/synapse/app/generic_worker.py
@@ -0,0 +1,923 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+import logging
+import sys
+
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
+
+import synapse
+import synapse.events
+from synapse.api.constants import EventTypes
+from synapse.api.errors import HttpResponseException, SynapseError
+from synapse.api.urls import (
+ CLIENT_API_PREFIX,
+ FEDERATION_PREFIX,
+ LEGACY_MEDIA_PREFIX,
+ MEDIA_PREFIX,
+ SERVER_KEY_V2_PREFIX,
+)
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.federation import send_queue
+from synapse.federation.transport.server import TransportLayerServer
+from synapse.handlers.presence import PresenceHandler, get_interested_parties
+from synapse.http.server import JsonResource
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseSite
+from synapse.logging.context import LoggingContext, run_in_background
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.filtering import SlavedFilteringStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
+from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.streams._base import (
+ DeviceListsStream,
+ ReceiptsStream,
+ ToDeviceStream,
+)
+from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
+from synapse.rest.admin import register_servlets_for_media_repo
+from synapse.rest.client.v1 import events
+from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.rest.client.v1.login import LoginRestServlet
+from synapse.rest.client.v1.profile import (
+ ProfileAvatarURLRestServlet,
+ ProfileDisplaynameRestServlet,
+ ProfileRestServlet,
+)
+from synapse.rest.client.v1.push_rule import PushRuleRestServlet
+from synapse.rest.client.v1.room import (
+ JoinedRoomMemberListRestServlet,
+ JoinRoomAliasServlet,
+ PublicRoomListRestServlet,
+ RoomEventContextServlet,
+ RoomInitialSyncRestServlet,
+ RoomMemberListRestServlet,
+ RoomMembershipRestServlet,
+ RoomMessageListRestServlet,
+ RoomSendEventRestServlet,
+ RoomStateEventRestServlet,
+ RoomStateRestServlet,
+)
+from synapse.rest.client.v1.voip import VoipRestServlet
+from synapse.rest.client.v2_alpha import groups, sync, user_directory
+from synapse.rest.client.v2_alpha._base import client_patterns
+from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
+from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
+from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.versions import VersionsRestServlet
+from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.server import HomeServer
+from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
+from synapse.storage.data_stores.main.monthly_active_users import (
+ MonthlyActiveUsersWorkerStore,
+)
+from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.types import ReadReceipt
+from synapse.util.async_helpers import Linearizer
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.manhole import manhole
+from synapse.util.stringutils import random_string
+from synapse.util.versionstring import get_version_string
+
+logger = logging.getLogger("synapse.app.generic_worker")
+
+
+class PresenceStatusStubServlet(RestServlet):
+ """If presence is disabled this servlet can be used to stub out setting
+ presence status, while proxying the getters to the master instance.
+ """
+
+ PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
+
+ def __init__(self, hs):
+ super(PresenceStatusStubServlet, self).__init__()
+ self.http_client = hs.get_simple_http_client()
+ self.auth = hs.get_auth()
+ self.main_uri = hs.config.worker_main_http_uri
+
+ async def on_GET(self, request, user_id):
+ # Pass through the auth headers, if any, in case the access token
+ # is there.
+ auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
+ headers = {"Authorization": auth_headers}
+
+ try:
+ result = await self.http_client.get_json(
+ self.main_uri + request.uri.decode("ascii"), headers=headers
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+
+ return 200, result
+
+ async def on_PUT(self, request, user_id):
+ await self.auth.get_user_by_req(request)
+ return 200, {}
+
+
+class KeyUploadServlet(RestServlet):
+ """An implementation of the `KeyUploadServlet` that responds to read only
+ requests, but otherwise proxies through to the master instance.
+ """
+
+ PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(KeyUploadServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
+ self.main_uri = hs.config.worker_main_http_uri
+
+ async def on_POST(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ if device_id is not None:
+ # passing the device_id here is deprecated; however, we allow it
+ # for now for compatibility with older clients.
+ if requester.device_id is not None and device_id != requester.device_id:
+ logger.warning(
+ "Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id,
+ device_id,
+ )
+ else:
+ device_id = requester.device_id
+
+ if device_id is None:
+ raise SynapseError(
+ 400, "To upload keys, you must pass device_id when authenticating"
+ )
+
+ if body:
+ # They're actually trying to upload something, proxy to main synapse.
+ # Pass through the auth headers, if any, in case the access token
+ # is there.
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
+ headers = {"Authorization": auth_headers}
+ result = await self.http_client.post_json_get_json(
+ self.main_uri + request.uri.decode("ascii"), body, headers=headers
+ )
+
+ return 200, result
+ else:
+ # Just interested in counts.
+ result = await self.store.count_e2e_one_time_keys(user_id, device_id)
+ return 200, {"one_time_key_counts": result}
+
+
+UPDATE_SYNCING_USERS_MS = 10 * 1000
+
+
+class GenericWorkerPresence(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.is_mine_id = hs.is_mine_id
+ self.http_client = hs.get_simple_http_client()
+ self.store = hs.get_datastore()
+ self.user_to_num_current_syncs = {}
+ self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+
+ active_presence = self.store.take_presence_startup_info()
+ self.user_to_current_state = {state.user_id: state for state in active_presence}
+
+ # user_id -> last_sync_ms. Lists the users that have stopped syncing
+ # but we haven't notified the master of that yet
+ self.users_going_offline = {}
+
+ self._send_stop_syncing_loop = self.clock.looping_call(
+ self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
+ )
+
+ self.process_id = random_string(16)
+ logger.info("Presence process_id is %r", self.process_id)
+
+ def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+ if self.hs.config.use_presence:
+ self.hs.get_tcp_replication().send_user_sync(
+ user_id, is_syncing, last_sync_ms
+ )
+
+ def mark_as_coming_online(self, user_id):
+ """A user has started syncing. Send a UserSync to the master, unless they
+ had recently stopped syncing.
+
+ Args:
+ user_id (str)
+ """
+ going_offline = self.users_going_offline.pop(user_id, None)
+ if not going_offline:
+ # Safe to skip because we haven't yet told the master they were offline
+ self.send_user_sync(user_id, True, self.clock.time_msec())
+
+ def mark_as_going_offline(self, user_id):
+ """A user has stopped syncing. We wait before notifying the master as
+ its likely they'll come back soon. This allows us to avoid sending
+ a stopped syncing immediately followed by a started syncing notification
+ to the master
+
+ Args:
+ user_id (str)
+ """
+ self.users_going_offline[user_id] = self.clock.time_msec()
+
+ def send_stop_syncing(self):
+ """Check if there are any users who have stopped syncing a while ago
+ and haven't come back yet. If there are poke the master about them.
+ """
+ now = self.clock.time_msec()
+ for user_id, last_sync_ms in list(self.users_going_offline.items()):
+ if now - last_sync_ms > UPDATE_SYNCING_USERS_MS:
+ self.users_going_offline.pop(user_id, None)
+ self.send_user_sync(user_id, False, last_sync_ms)
+
+ def set_state(self, user, state, ignore_status_msg=False):
+ # TODO Hows this supposed to work?
+ return defer.succeed(None)
+
+ get_states = __func__(PresenceHandler.get_states)
+ get_state = __func__(PresenceHandler.get_state)
+ current_state_for_users = __func__(PresenceHandler.current_state_for_users)
+
+ def user_syncing(self, user_id, affect_presence):
+ if affect_presence:
+ curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
+ self.user_to_num_current_syncs[user_id] = curr_sync + 1
+
+ # If we went from no in flight sync to some, notify replication
+ if self.user_to_num_current_syncs[user_id] == 1:
+ self.mark_as_coming_online(user_id)
+
+ def _end():
+ # We check that the user_id is in user_to_num_current_syncs because
+ # user_to_num_current_syncs may have been cleared if we are
+ # shutting down.
+ if affect_presence and user_id in self.user_to_num_current_syncs:
+ self.user_to_num_current_syncs[user_id] -= 1
+
+ # If we went from one in flight sync to non, notify replication
+ if self.user_to_num_current_syncs[user_id] == 0:
+ self.mark_as_going_offline(user_id)
+
+ @contextlib.contextmanager
+ def _user_syncing():
+ try:
+ yield
+ finally:
+ _end()
+
+ return defer.succeed(_user_syncing())
+
+ @defer.inlineCallbacks
+ def notify_from_replication(self, states, stream_id):
+ parties = yield get_interested_parties(self.store, states)
+ room_ids_to_states, users_to_states = parties
+
+ self.notifier.on_new_event(
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=users_to_states.keys(),
+ )
+
+ @defer.inlineCallbacks
+ def process_replication_rows(self, token, rows):
+ states = [
+ UserPresenceState(
+ row.user_id,
+ row.state,
+ row.last_active_ts,
+ row.last_federation_update_ts,
+ row.last_user_sync_ts,
+ row.status_msg,
+ row.currently_active,
+ )
+ for row in rows
+ ]
+
+ for state in states:
+ self.user_to_current_state[state.user_id] = state
+
+ stream_id = token
+ yield self.notify_from_replication(states, stream_id)
+
+ def get_currently_syncing_users(self):
+ if self.hs.config.use_presence:
+ return [
+ user_id
+ for user_id, count in self.user_to_num_current_syncs.items()
+ if count > 0
+ ]
+ else:
+ return set()
+
+
+class GenericWorkerTyping(object):
+ def __init__(self, hs):
+ self._latest_room_serial = 0
+ self._reset()
+
+ def _reset(self):
+ """
+ Reset the typing handler's data caches.
+ """
+ # map room IDs to serial numbers
+ self._room_serials = {}
+ # map room IDs to sets of users currently typing
+ self._room_typing = {}
+
+ def stream_positions(self):
+ # We must update this typing token from the response of the previous
+ # sync. In particular, the stream id may "reset" back to zero/a low
+ # value which we *must* use for the next replication request.
+ return {"typing": self._latest_room_serial}
+
+ def process_replication_rows(self, token, rows):
+ if self._latest_room_serial > token:
+ # The master has gone backwards. To prevent inconsistent data, just
+ # clear everything.
+ self._reset()
+
+ # Set the latest serial token to whatever the server gave us.
+ self._latest_room_serial = token
+
+ for row in rows:
+ self._room_serials[row.room_id] = token
+ self._room_typing[row.room_id] = row.user_ids
+
+
+class GenericWorkerSlavedStore(
+ # FIXME(#3714): We need to add UserDirectoryStore as we write directly
+ # rather than going via the correct worker.
+ UserDirectoryStore,
+ SlavedDeviceInboxStore,
+ SlavedDeviceStore,
+ SlavedReceiptsStore,
+ SlavedPushRuleStore,
+ SlavedGroupServerStore,
+ SlavedAccountDataStore,
+ SlavedPusherStore,
+ SlavedEventStore,
+ SlavedKeyStore,
+ RoomStore,
+ DirectoryStore,
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ SlavedTransactionStore,
+ SlavedProfileStore,
+ SlavedClientIpStore,
+ SlavedPresenceStore,
+ SlavedFilteringStore,
+ MonthlyActiveUsersWorkerStore,
+ MediaRepositoryStore,
+ BaseSlavedStore,
+):
+ def __init__(self, database, db_conn, hs):
+ super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs)
+
+ # We pull out the current federation stream position now so that we
+ # always have a known value for the federation position in memory so
+ # that we don't have to bounce via a deferred once when we start the
+ # replication streams.
+ self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
+
+ def _get_federation_out_pos(self, db_conn):
+ sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?"
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, ("federation",))
+ rows = txn.fetchall()
+ txn.close()
+
+ return rows[0][0] if rows else -1
+
+
+class GenericWorkerServer(HomeServer):
+ DATASTORE_CLASS = GenericWorkerSlavedStore
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_addresses = listener_config["bind_addresses"]
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+ elif name == "client":
+ resource = JsonResource(self, canonical_json=False)
+
+ PublicRoomListRestServlet(self).register(resource)
+ RoomMemberListRestServlet(self).register(resource)
+ JoinedRoomMemberListRestServlet(self).register(resource)
+ RoomStateRestServlet(self).register(resource)
+ RoomEventContextServlet(self).register(resource)
+ RoomMessageListRestServlet(self).register(resource)
+ RegisterRestServlet(self).register(resource)
+ LoginRestServlet(self).register(resource)
+ ThreepidRestServlet(self).register(resource)
+ KeyQueryServlet(self).register(resource)
+ KeyChangesServlet(self).register(resource)
+ VoipRestServlet(self).register(resource)
+ PushRuleRestServlet(self).register(resource)
+ VersionsRestServlet(self).register(resource)
+ RoomSendEventRestServlet(self).register(resource)
+ RoomMembershipRestServlet(self).register(resource)
+ RoomStateEventRestServlet(self).register(resource)
+ JoinRoomAliasServlet(self).register(resource)
+ ProfileAvatarURLRestServlet(self).register(resource)
+ ProfileDisplaynameRestServlet(self).register(resource)
+ ProfileRestServlet(self).register(resource)
+ KeyUploadServlet(self).register(resource)
+
+ sync.register_servlets(self, resource)
+ events.register_servlets(self, resource)
+ InitialSyncRestServlet(self).register(resource)
+ RoomInitialSyncRestServlet(self).register(resource)
+
+ user_directory.register_servlets(self, resource)
+
+ # If presence is disabled, use the stub servlet that does
+ # not allow sending presence
+ if not self.config.use_presence:
+ PresenceStatusStubServlet(self).register(resource)
+
+ groups.register_servlets(self, resource)
+
+ resources.update({CLIENT_API_PREFIX: resource})
+ elif name == "federation":
+ resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
+ elif name == "media":
+ if self.config.can_load_media_repo:
+ media_repo = self.get_media_repository_resource()
+
+ # We need to serve the admin servlets for media on the
+ # worker.
+ admin_resource = JsonResource(self, canonical_json=False)
+ register_servlets_for_media_repo(self, admin_resource)
+
+ resources.update(
+ {
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ "/_synapse/admin": admin_resource,
+ }
+ )
+ else:
+ logger.warning(
+ "A 'media' listener is configured but the media"
+ " repository is disabled. Ignoring."
+ )
+
+ if name == "openid" and "federation" not in res["names"]:
+ # Only load the openid resource separately if federation resource
+ # is not specified since federation resource includes openid
+ # resource.
+ resources.update(
+ {
+ FEDERATION_PREFIX: TransportLayerServer(
+ self, servlet_groups=["openid"]
+ )
+ }
+ )
+
+ if name in ["keys", "federation"]:
+ resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
+ ),
+ reactor=self.get_reactor(),
+ )
+
+ logger.info("Synapse worker now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warning(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
+ else:
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ else:
+ logger.warning("Unrecognized listener type: %s", listener["type"])
+
+ self.get_tcp_replication().start_replication(self)
+
+ def remove_pusher(self, app_id, push_key, user_id):
+ self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
+
+ def build_tcp_replication(self):
+ return GenericWorkerReplicationHandler(self)
+
+ def build_presence_handler(self):
+ return GenericWorkerPresence(self)
+
+ def build_typing_handler(self):
+ return GenericWorkerTyping(self)
+
+
+class GenericWorkerReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
+
+ self.store = hs.get_datastore()
+ self.typing_handler = hs.get_typing_handler()
+ # NB this is a SynchrotronPresence, not a normal PresenceHandler
+ self.presence_handler = hs.get_presence_handler()
+ self.notifier = hs.get_notifier()
+
+ self.notify_pushers = hs.config.start_pushers
+ self.pusher_pool = hs.get_pusherpool()
+
+ if hs.config.send_federation:
+ self.send_handler = FederationSenderHandler(hs, self)
+ else:
+ self.send_handler = None
+
+ async def on_rdata(self, stream_name, token, rows):
+ await super(GenericWorkerReplicationHandler, self).on_rdata(
+ stream_name, token, rows
+ )
+ run_in_background(self.process_and_notify, stream_name, token, rows)
+
+ def get_streams_to_replicate(self):
+ args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
+ args.update(self.typing_handler.stream_positions())
+ if self.send_handler:
+ args.update(self.send_handler.stream_positions())
+ return args
+
+ def get_currently_syncing_users(self):
+ return self.presence_handler.get_currently_syncing_users()
+
+ async def process_and_notify(self, stream_name, token, rows):
+ try:
+ if self.send_handler:
+ self.send_handler.process_replication_rows(stream_name, token, rows)
+
+ if stream_name == "events":
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ if row.type != EventsStreamEventRow.TypeId:
+ continue
+ assert isinstance(row, EventsStreamRow)
+
+ event = await self.store.get_event(
+ row.data.event_id, allow_rejected=True
+ )
+ if event.rejected_reason:
+ continue
+
+ extra_users = ()
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(
+ event, token, max_token, extra_users
+ )
+
+ await self.pusher_pool.on_new_notifications(token, token)
+ elif stream_name == "push_rules":
+ self.notifier.on_new_event(
+ "push_rules_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name in ("account_data", "tag_account_data"):
+ self.notifier.on_new_event(
+ "account_data_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "receipt_key", token, rooms=[row.room_id for row in rows]
+ )
+ await self.pusher_pool.on_new_receipts(
+ token, token, {row.room_id for row in rows}
+ )
+ elif stream_name == "typing":
+ self.typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows]
+ )
+ elif stream_name == "to_device":
+ entities = [row.entity for row in rows if row.entity.startswith("@")]
+ if entities:
+ self.notifier.on_new_event("to_device_key", token, users=entities)
+ elif stream_name == "device_lists":
+ all_room_ids = set()
+ for row in rows:
+ room_ids = await self.store.get_rooms_for_user(row.user_id)
+ all_room_ids.update(room_ids)
+ self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
+ elif stream_name == "presence":
+ await self.presence_handler.process_replication_rows(token, rows)
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "groups_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name == "pushers":
+ for row in rows:
+ if row.deleted:
+ self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+ else:
+ await self.start_pusher(row.user_id, row.app_id, row.pushkey)
+ except Exception:
+ logger.exception("Error processing replication")
+
+ def stop_pusher(self, user_id, app_id, pushkey):
+ if not self.notify_pushers:
+ return
+
+ key = "%s:%s" % (app_id, pushkey)
+ pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
+ pusher = pushers_for_user.pop(key, None)
+ if pusher is None:
+ return
+ logger.info("Stopping pusher %r / %r", user_id, key)
+ pusher.on_stop()
+
+ async def start_pusher(self, user_id, app_id, pushkey):
+ if not self.notify_pushers:
+ return
+
+ key = "%s:%s" % (app_id, pushkey)
+ logger.info("Starting pusher %r / %r", user_id, key)
+ return await self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+
+ # Let's wake up the transaction queue for the server in case we have
+ # pending stuff to send to it.
+ if self.send_handler:
+ self.send_handler.wake_destination(server)
+
+
+class FederationSenderHandler(object):
+ """Processes the replication stream and forwards the appropriate entries
+ to the federation sender.
+ """
+
+ def __init__(self, hs: GenericWorkerServer, replication_client):
+ self.store = hs.get_datastore()
+ self._is_mine_id = hs.is_mine_id
+ self.federation_sender = hs.get_federation_sender()
+ self.replication_client = replication_client
+
+ self.federation_position = self.store.federation_out_pos_startup
+ self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
+ self._last_ack = self.federation_position
+
+ self._room_serials = {}
+ self._room_typing = {}
+
+ def on_start(self):
+ # There may be some events that are persisted but haven't been sent,
+ # so send them now.
+ self.federation_sender.notify_new_events(
+ self.store.get_room_max_stream_ordering()
+ )
+
+ def wake_destination(self, server: str):
+ self.federation_sender.wake_destination(server)
+
+ def stream_positions(self):
+ return {"federation": self.federation_position}
+
+ def process_replication_rows(self, stream_name, token, rows):
+ # The federation stream contains things that we want to send out, e.g.
+ # presence, typing, etc.
+ if stream_name == "federation":
+ send_queue.process_rows_for_federation(self.federation_sender, rows)
+ run_in_background(self.update_token, token)
+
+ # We also need to poke the federation sender when new events happen
+ elif stream_name == "events":
+ self.federation_sender.notify_new_events(token)
+
+ # ... and when new receipts happen
+ elif stream_name == ReceiptsStream.NAME:
+ run_as_background_process(
+ "process_receipts_for_federation", self._on_new_receipts, rows
+ )
+
+ # ... as well as device updates and messages
+ elif stream_name == DeviceListsStream.NAME:
+ hosts = {row.destination for row in rows}
+ for host in hosts:
+ self.federation_sender.send_device_messages(host)
+
+ elif stream_name == ToDeviceStream.NAME:
+ # The to_device stream includes stuff to be pushed to both local
+ # clients and remote servers, so we ignore entities that start with
+ # '@' (since they'll be local users rather than destinations).
+ hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+ for host in hosts:
+ self.federation_sender.send_device_messages(host)
+
+ async def _on_new_receipts(self, rows):
+ """
+ Args:
+ rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
+ new receipts to be processed
+ """
+ for receipt in rows:
+ # we only want to send on receipts for our own users
+ if not self._is_mine_id(receipt.user_id):
+ continue
+ receipt_info = ReadReceipt(
+ receipt.room_id,
+ receipt.receipt_type,
+ receipt.user_id,
+ [receipt.event_id],
+ receipt.data,
+ )
+ await self.federation_sender.send_read_receipt(receipt_info)
+
+ async def update_token(self, token):
+ try:
+ self.federation_position = token
+
+ # We linearize here to ensure we don't have races updating the token
+ with (await self._fed_position_linearizer.queue(None)):
+ if self._last_ack < self.federation_position:
+ await self.store.update_federation_out_pos(
+ "federation", self.federation_position
+ )
+
+ # We ACK this token over replication so that the master can drop
+ # its in memory queues
+ self.replication_client.send_federation_ack(
+ self.federation_position
+ )
+ self._last_ack = self.federation_position
+ except Exception:
+ logger.exception("Error updating federation stream position")
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config("Synapse worker", config_options)
+ except ConfigError as e:
+ sys.stderr.write("\n" + str(e) + "\n")
+ sys.exit(1)
+
+ # For backwards compatibility let any of the old app names.
+ assert config.worker_app in (
+ "synapse.app.appservice",
+ "synapse.app.client_reader",
+ "synapse.app.event_creator",
+ "synapse.app.federation_reader",
+ "synapse.app.federation_sender",
+ "synapse.app.frontend_proxy",
+ "synapse.app.generic_worker",
+ "synapse.app.media_repository",
+ "synapse.app.pusher",
+ "synapse.app.synchrotron",
+ "synapse.app.user_dir",
+ )
+
+ if config.worker_app == "synapse.app.appservice":
+ if config.notify_appservices:
+ sys.stderr.write(
+ "\nThe appservices must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``notify_appservices: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the appservice to start since they will be disabled in the main config
+ config.notify_appservices = True
+
+ if config.worker_app == "synapse.app.pusher":
+ if config.start_pushers:
+ sys.stderr.write(
+ "\nThe pushers must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``start_pushers: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.start_pushers = True
+
+ if config.worker_app == "synapse.app.user_dir":
+ if config.update_user_directory:
+ sys.stderr.write(
+ "\nThe update_user_directory must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``update_user_directory: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.update_user_directory = True
+
+ if config.worker_app == "synapse.app.federation_sender":
+ if config.send_federation:
+ sys.stderr.write(
+ "\nThe send_federation must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``send_federation: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.send_federation = True
+
+ synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ ss = GenericWorkerServer(
+ config.server_name,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ )
+
+ setup_logging(ss, config, use_worker_options=True)
+
+ ss.setup()
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
+
+ _base.start_worker_reactor("synapse-generic-worker", config)
+
+
+if __name__ == "__main__":
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 5b5832214a..add43147b3 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -13,162 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.api.urls import LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.admin import register_servlets_for_media_repo
-from synapse.server import HomeServer
-from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.media_repository")
-
-
-class MediaRepositorySlavedStore(
- RoomStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedClientIpStore,
- SlavedTransactionStore,
- BaseSlavedStore,
- MediaRepositoryStore,
-):
- pass
-
-
-class MediaRepositoryServer(HomeServer):
- DATASTORE_CLASS = MediaRepositorySlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "media":
- media_repo = self.get_media_repository_resource()
-
- # We need to serve the admin servlets for media on the
- # worker.
- admin_resource = JsonResource(self, canonical_json=False)
- register_servlets_for_media_repo(self, admin_resource)
-
- resources.update(
- {
- MEDIA_PREFIX: media_repo,
- LEGACY_MEDIA_PREFIX: media_repo,
- "/_synapse/admin": admin_resource,
- }
- )
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- )
- logger.info("Synapse media repository now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse media repository", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.media_repository"
-
- if config.enable_media_repo:
- _base.quit_with_error(
- "enable_media_repo must be disabled in the main synapse process\n"
- "before the media repo can be run in a separate worker.\n"
- "Please add ``enable_media_repo: false`` to the main config\n"
- )
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- ss = MediaRepositoryServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ss, config, use_worker_options=True)
-
- ss.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-media-repository", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index e46b6ac598..add43147b3 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -13,213 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext, run_in_background
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage._base import __func__
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.server import HomeServer
-from synapse.storage import DataStore
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.pusher")
-
-
-class PusherSlaveStore(
- SlavedEventStore,
- SlavedPusherStore,
- SlavedReceiptsStore,
- SlavedAccountDataStore,
- RoomStore,
-):
- update_pusher_last_stream_ordering_and_success = __func__(
- DataStore.update_pusher_last_stream_ordering_and_success
- )
-
- update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since)
-
- update_pusher_last_stream_ordering = __func__(
- DataStore.update_pusher_last_stream_ordering
- )
-
- get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room)
-
- set_throttle_params = __func__(DataStore.set_throttle_params)
-
- get_time_of_last_push_action_before = __func__(
- DataStore.get_time_of_last_push_action_before
- )
-
- get_profile_displayname = __func__(DataStore.get_profile_displayname)
-
-
-class PusherServer(HomeServer):
- DATASTORE_CLASS = PusherSlaveStore
-
- def remove_pusher(self, app_id, push_key, user_id):
- self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- )
-
- logger.info("Synapse pusher now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
- def build_tcp_replication(self):
- return PusherReplicationHandler(self)
-
-
-class PusherReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(PusherReplicationHandler, self).__init__(hs.get_datastore())
-
- self.pusher_pool = hs.get_pusherpool()
-
- async def on_rdata(self, stream_name, token, rows):
- await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
- run_in_background(self.poke_pushers, stream_name, token, rows)
-
- @defer.inlineCallbacks
- def poke_pushers(self, stream_name, token, rows):
- try:
- if stream_name == "pushers":
- for row in rows:
- if row.deleted:
- yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
- else:
- yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
- elif stream_name == "events":
- yield self.pusher_pool.on_new_notifications(token, token)
- elif stream_name == "receipts":
- yield self.pusher_pool.on_new_receipts(
- token, token, set(row.room_id for row in rows)
- )
- except Exception:
- logger.exception("Error poking pushers")
-
- def stop_pusher(self, user_id, app_id, pushkey):
- key = "%s:%s" % (app_id, pushkey)
- pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
- pusher = pushers_for_user.pop(key, None)
- if pusher is None:
- return
- logger.info("Stopping pusher %r / %r", user_id, key)
- pusher.on_stop()
-
- def start_pusher(self, user_id, app_id, pushkey):
- key = "%s:%s" % (app_id, pushkey)
- logger.info("Starting pusher %r / %r", user_id, key)
- return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config("Synapse pusher", config_options)
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.pusher"
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- if config.start_pushers:
- sys.stderr.write(
- "\nThe pushers must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``start_pushers: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.start_pushers = True
-
- ps = PusherServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ps, config, use_worker_options=True)
-
- ps.setup()
-
- def start():
- _base.start(ps, config.worker_listeners)
- ps.get_pusherpool().start()
-
- reactor.addSystemEventTrigger("before", "startup", start)
-
- _base.start_worker_reactor("synapse-pusher", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
- ps = start(sys.argv[1:])
+ start(sys.argv[1:])
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 8982c0676e..add43147b3 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -13,454 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import contextlib
-import logging
-import sys
-
-from six import iteritems
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse.api.constants import EventTypes
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.handlers.presence import PresenceHandler, get_interested_parties
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext, run_in_background
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.filtering import SlavedFilteringStore
-from synapse.replication.slave.storage.groups import SlavedGroupServerStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
-from synapse.rest.client.v1 import events
-from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
-from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
-from synapse.rest.client.v2_alpha import sync
-from synapse.server import HomeServer
-from synapse.storage.data_stores.main.monthly_active_users import (
- MonthlyActiveUsersWorkerStore,
-)
-from synapse.storage.data_stores.main.presence import UserPresenceState
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.stringutils import random_string
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.synchrotron")
-
-
-class SynchrotronSlavedStore(
- SlavedReceiptsStore,
- SlavedAccountDataStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedFilteringStore,
- SlavedPresenceStore,
- SlavedGroupServerStore,
- SlavedDeviceInboxStore,
- SlavedDeviceStore,
- SlavedPushRuleStore,
- SlavedEventStore,
- SlavedClientIpStore,
- RoomStore,
- MonthlyActiveUsersWorkerStore,
- BaseSlavedStore,
-):
- pass
-
-
-UPDATE_SYNCING_USERS_MS = 10 * 1000
-
-
-class SynchrotronPresence(object):
- def __init__(self, hs):
- self.hs = hs
- self.is_mine_id = hs.is_mine_id
- self.http_client = hs.get_simple_http_client()
- self.store = hs.get_datastore()
- self.user_to_num_current_syncs = {}
- self.clock = hs.get_clock()
- self.notifier = hs.get_notifier()
-
- active_presence = self.store.take_presence_startup_info()
- self.user_to_current_state = {state.user_id: state for state in active_presence}
-
- # user_id -> last_sync_ms. Lists the users that have stopped syncing
- # but we haven't notified the master of that yet
- self.users_going_offline = {}
-
- self._send_stop_syncing_loop = self.clock.looping_call(
- self.send_stop_syncing, 10 * 1000
- )
-
- self.process_id = random_string(16)
- logger.info("Presence process_id is %r", self.process_id)
-
- def send_user_sync(self, user_id, is_syncing, last_sync_ms):
- if self.hs.config.use_presence:
- self.hs.get_tcp_replication().send_user_sync(
- user_id, is_syncing, last_sync_ms
- )
-
- def mark_as_coming_online(self, user_id):
- """A user has started syncing. Send a UserSync to the master, unless they
- had recently stopped syncing.
-
- Args:
- user_id (str)
- """
- going_offline = self.users_going_offline.pop(user_id, None)
- if not going_offline:
- # Safe to skip because we haven't yet told the master they were offline
- self.send_user_sync(user_id, True, self.clock.time_msec())
-
- def mark_as_going_offline(self, user_id):
- """A user has stopped syncing. We wait before notifying the master as
- its likely they'll come back soon. This allows us to avoid sending
- a stopped syncing immediately followed by a started syncing notification
- to the master
-
- Args:
- user_id (str)
- """
- self.users_going_offline[user_id] = self.clock.time_msec()
-
- def send_stop_syncing(self):
- """Check if there are any users who have stopped syncing a while ago
- and haven't come back yet. If there are poke the master about them.
- """
- now = self.clock.time_msec()
- for user_id, last_sync_ms in list(self.users_going_offline.items()):
- if now - last_sync_ms > 10 * 1000:
- self.users_going_offline.pop(user_id, None)
- self.send_user_sync(user_id, False, last_sync_ms)
-
- def set_state(self, user, state, ignore_status_msg=False):
- # TODO Hows this supposed to work?
- return defer.succeed(None)
-
- get_states = __func__(PresenceHandler.get_states)
- get_state = __func__(PresenceHandler.get_state)
- current_state_for_users = __func__(PresenceHandler.current_state_for_users)
-
- def user_syncing(self, user_id, affect_presence):
- if affect_presence:
- curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
- self.user_to_num_current_syncs[user_id] = curr_sync + 1
-
- # If we went from no in flight sync to some, notify replication
- if self.user_to_num_current_syncs[user_id] == 1:
- self.mark_as_coming_online(user_id)
-
- def _end():
- # We check that the user_id is in user_to_num_current_syncs because
- # user_to_num_current_syncs may have been cleared if we are
- # shutting down.
- if affect_presence and user_id in self.user_to_num_current_syncs:
- self.user_to_num_current_syncs[user_id] -= 1
-
- # If we went from one in flight sync to non, notify replication
- if self.user_to_num_current_syncs[user_id] == 0:
- self.mark_as_going_offline(user_id)
-
- @contextlib.contextmanager
- def _user_syncing():
- try:
- yield
- finally:
- _end()
-
- return defer.succeed(_user_syncing())
-
- @defer.inlineCallbacks
- def notify_from_replication(self, states, stream_id):
- parties = yield get_interested_parties(self.store, states)
- room_ids_to_states, users_to_states = parties
-
- self.notifier.on_new_event(
- "presence_key",
- stream_id,
- rooms=room_ids_to_states.keys(),
- users=users_to_states.keys(),
- )
-
- @defer.inlineCallbacks
- def process_replication_rows(self, token, rows):
- states = [
- UserPresenceState(
- row.user_id,
- row.state,
- row.last_active_ts,
- row.last_federation_update_ts,
- row.last_user_sync_ts,
- row.status_msg,
- row.currently_active,
- )
- for row in rows
- ]
-
- for state in states:
- self.user_to_current_state[state.user_id] = state
-
- stream_id = token
- yield self.notify_from_replication(states, stream_id)
-
- def get_currently_syncing_users(self):
- if self.hs.config.use_presence:
- return [
- user_id
- for user_id, count in iteritems(self.user_to_num_current_syncs)
- if count > 0
- ]
- else:
- return set()
-
-class SynchrotronTyping(object):
- def __init__(self, hs):
- self._latest_room_serial = 0
- self._reset()
-
- def _reset(self):
- """
- Reset the typing handler's data caches.
- """
- # map room IDs to serial numbers
- self._room_serials = {}
- # map room IDs to sets of users currently typing
- self._room_typing = {}
-
- def stream_positions(self):
- # We must update this typing token from the response of the previous
- # sync. In particular, the stream id may "reset" back to zero/a low
- # value which we *must* use for the next replication request.
- return {"typing": self._latest_room_serial}
-
- def process_replication_rows(self, token, rows):
- if self._latest_room_serial > token:
- # The master has gone backwards. To prevent inconsistent data, just
- # clear everything.
- self._reset()
-
- # Set the latest serial token to whatever the server gave us.
- self._latest_room_serial = token
-
- for row in rows:
- self._room_serials[row.room_id] = token
- self._room_typing[row.room_id] = row.user_ids
-
-
-class SynchrotronApplicationService(object):
- def notify_interested_services(self, event):
- pass
-
-
-class SynchrotronServer(HomeServer):
- DATASTORE_CLASS = SynchrotronSlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
- sync.register_servlets(self, resource)
- events.register_servlets(self, resource)
- InitialSyncRestServlet(self).register(resource)
- RoomInitialSyncRestServlet(self).register(resource)
- resources.update(
- {
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- }
- )
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- )
-
- logger.info("Synapse synchrotron now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return SyncReplicationHandler(self)
-
- def build_presence_handler(self):
- return SynchrotronPresence(self)
-
- def build_typing_handler(self):
- return SynchrotronTyping(self)
-
-
-class SyncReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(SyncReplicationHandler, self).__init__(hs.get_datastore())
-
- self.store = hs.get_datastore()
- self.typing_handler = hs.get_typing_handler()
- # NB this is a SynchrotronPresence, not a normal PresenceHandler
- self.presence_handler = hs.get_presence_handler()
- self.notifier = hs.get_notifier()
-
- async def on_rdata(self, stream_name, token, rows):
- await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
- run_in_background(self.process_and_notify, stream_name, token, rows)
-
- def get_streams_to_replicate(self):
- args = super(SyncReplicationHandler, self).get_streams_to_replicate()
- args.update(self.typing_handler.stream_positions())
- return args
-
- def get_currently_syncing_users(self):
- return self.presence_handler.get_currently_syncing_users()
-
- async def process_and_notify(self, stream_name, token, rows):
- try:
- if stream_name == "events":
- # We shouldn't get multiple rows per token for events stream, so
- # we don't need to optimise this for multiple rows.
- for row in rows:
- if row.type != EventsStreamEventRow.TypeId:
- continue
- assert isinstance(row, EventsStreamRow)
-
- event = await self.store.get_event(
- row.data.event_id, allow_rejected=True
- )
- if event.rejected_reason:
- continue
-
- extra_users = ()
- if event.type == EventTypes.Member:
- extra_users = (event.state_key,)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(
- event, token, max_token, extra_users
- )
- elif stream_name == "push_rules":
- self.notifier.on_new_event(
- "push_rules_key", token, users=[row.user_id for row in rows]
- )
- elif stream_name in ("account_data", "tag_account_data"):
- self.notifier.on_new_event(
- "account_data_key", token, users=[row.user_id for row in rows]
- )
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "receipt_key", token, rooms=[row.room_id for row in rows]
- )
- elif stream_name == "typing":
- self.typing_handler.process_replication_rows(token, rows)
- self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows]
- )
- elif stream_name == "to_device":
- entities = [row.entity for row in rows if row.entity.startswith("@")]
- if entities:
- self.notifier.on_new_event("to_device_key", token, users=entities)
- elif stream_name == "device_lists":
- all_room_ids = set()
- for row in rows:
- room_ids = await self.store.get_rooms_for_user(row.user_id)
- all_room_ids.update(room_ids)
- self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
- elif stream_name == "presence":
- await self.presence_handler.process_replication_rows(token, rows)
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "groups_key", token, users=[row.user_id for row in rows]
- )
- except Exception:
- logger.exception("Error processing replication")
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config("Synapse synchrotron", config_options)
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.synchrotron"
-
- synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- ss = SynchrotronServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- application_service_handler=SynchrotronApplicationService(),
- )
-
- setup_logging(ss, config, use_worker_options=True)
-
- ss.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-synchrotron", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index ba536d6f04..503d44f687 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -14,217 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import sys
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext, run_in_background
-from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams.events import (
- EventsStream,
- EventsStreamCurrentStateRow,
-)
-from synapse.rest.client.v2_alpha import user_directory
-from synapse.server import HomeServer
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-from synapse.storage.database import Database
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.user_dir")
-
-
-class UserDirectorySlaveStore(
- SlavedEventStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedClientIpStore,
- UserDirectoryStore,
- BaseSlavedStore,
-):
- def __init__(self, database: Database, db_conn, hs):
- super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs)
-
- events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
- db_conn,
- "current_state_delta_stream",
- entity_column="room_id",
- stream_column="stream_id",
- max_value=events_max, # As we share the stream id with events token
- limit=1000,
- )
- self._curr_state_delta_stream_cache = StreamChangeCache(
- "_curr_state_delta_stream_cache",
- min_curr_state_delta_id,
- prefilled_cache=curr_state_delta_prefill,
- )
-
- def stream_positions(self):
- result = super(UserDirectorySlaveStore, self).stream_positions()
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
- if stream_name == EventsStream.NAME:
- self._stream_id_gen.advance(token)
- for row in rows:
- if row.type != EventsStreamCurrentStateRow.TypeId:
- continue
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
- return super(UserDirectorySlaveStore, self).process_replication_rows(
- stream_name, token, rows
- )
-
-
-class UserDirectoryServer(HomeServer):
- DATASTORE_CLASS = UserDirectorySlaveStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
- user_directory.register_servlets(self, resource)
- resources.update(
- {
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- }
- )
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- )
-
- logger.info("Synapse user_dir now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
- )
- else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
- else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return UserDirectoryReplicationHandler(self)
-
-
-class UserDirectoryReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
- self.user_directory = hs.get_user_directory_handler()
-
- async def on_rdata(self, stream_name, token, rows):
- await super(UserDirectoryReplicationHandler, self).on_rdata(
- stream_name, token, rows
- )
- if stream_name == EventsStream.NAME:
- run_in_background(self._notify_directory)
-
- @defer.inlineCallbacks
- def _notify_directory(self):
- try:
- yield self.user_directory.notify_new_event()
- except Exception:
- logger.exception("Error notifiying user directory of state update")
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config("Synapse user directory", config_options)
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.user_dir"
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- if config.update_user_directory:
- sys.stderr.write(
- "\nThe update_user_directory must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``update_user_directory: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.update_user_directory = True
-
- ss = UserDirectoryServer(
- config.server_name,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- )
-
- setup_logging(ss, config, use_worker_options=True)
-
- ss.setup()
- reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
- )
-
- _base.start_worker_reactor("synapse-user-dir", config)
-
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
if __name__ == "__main__":
with LoggingContext("main"):
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index aea3985a5f..1b13e84425 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -270,7 +270,7 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
- def get_exlusive_user_regexes(self):
+ def get_exclusive_user_regexes(self):
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index ba846042c4..eda2b65ef7 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,6 +18,7 @@
import argparse
import errno
import os
+from io import open as io_open
from collections import OrderedDict
from textwrap import dedent
from typing import Any, MutableMapping, Optional
@@ -181,7 +182,7 @@ class Config(object):
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
+ with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 74853f9faa..f31fc85ec8 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -27,6 +27,12 @@ import pkg_resources
from ._base import Config, ConfigError
+MISSING_PASSWORD_RESET_CONFIG_ERROR = """\
+Password reset emails are enabled on this homeserver due to a partial
+'email' block. However, the following required keys are missing:
+ %s
+"""
+
class EmailConfig(Config):
section = "email"
@@ -142,24 +148,18 @@ class EmailConfig(Config):
bleach
if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- required = ["smtp_host", "smtp_port", "notif_from"]
-
missing = []
- for k in required:
- if k not in email_config:
- missing.append("email." + k)
+ if not self.email_notif_from:
+ missing.append("email.notif_from")
# public_baseurl is required to build password reset and validation links that
# will be emailed to users
if config.get("public_baseurl") is None:
missing.append("public_baseurl")
- if len(missing) > 0:
- raise RuntimeError(
- "Password resets emails are configured to be sent from "
- "this homeserver due to a partial 'email' block. "
- "However, the following required keys are missing: %s"
- % (", ".join(missing),)
+ if missing:
+ raise ConfigError(
+ MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),)
)
# These email templates have placeholders in them, and thus must be
@@ -245,32 +245,25 @@ class EmailConfig(Config):
)
if self.email_enable_notifs:
- required = [
- "smtp_host",
- "smtp_port",
- "notif_from",
- "notif_template_html",
- "notif_template_text",
- ]
-
missing = []
- for k in required:
- if k not in email_config:
- missing.append(k)
-
- if len(missing) > 0:
- raise RuntimeError(
- "email.enable_notifs is True but required keys are missing: %s"
- % (", ".join(["email." + k for k in missing]),)
- )
+ if not self.email_notif_from:
+ missing.append("email.notif_from")
if config.get("public_baseurl") is None:
- raise RuntimeError(
- "email.enable_notifs is True but no public_baseurl is set"
+ missing.append("public_baseurl")
+
+ if missing:
+ raise ConfigError(
+ "email.enable_notifs is True but required keys are missing: %s"
+ % (", ".join(missing),)
)
- self.email_notif_template_html = email_config["notif_template_html"]
- self.email_notif_template_text = email_config["notif_template_text"]
+ self.email_notif_template_html = email_config.get(
+ "notif_template_html", "notif_mail.html"
+ )
+ self.email_notif_template_text = email_config.get(
+ "notif_template_text", "notif_mail.txt"
+ )
for f in self.email_notif_template_text, self.email_notif_template_html:
p = os.path.join(self.email_template_dir, f)
@@ -323,10 +316,6 @@ class EmailConfig(Config):
#
#require_transport_security: true
- # Enable sending emails for messages that the user has missed
- #
- #enable_notifs: false
-
# notif_from defines the "From" address to use when sending emails.
# It must be set if email sending is enabled.
#
@@ -344,6 +333,11 @@ class EmailConfig(Config):
#
#app_name: my_branded_matrix_server
+ # Uncomment the following to enable sending emails for messages that the user
+ # has missed. Disabled by default.
+ #
+ #enable_notifs: true
+
# Uncomment the following to disable automatic subscription to email
# notifications for new users. Enabled by default.
#
diff --git a/synapse/config/password.py b/synapse/config/password.py
index 2a634ac751..2c13810ab8 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +33,10 @@ class PasswordConfig(Config):
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
self.password_pepper = password_config.get("pepper", "")
+ # Password policy
+ self.password_policy = password_config.get("policy", {})
+ self.password_policy_enabled = self.password_policy.pop("enabled", False)
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -48,4 +54,34 @@ class PasswordConfig(Config):
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#
#pepper: "EVEN_MORE_SECRET"
+
+ # Define and enforce a password policy. Each parameter is optional, boolean
+ # parameters default to 'false' and integer parameters default to 0.
+ # This is an early implementation of MSC2000.
+ #
+ #policy:
+ # Whether to enforce the password policy.
+ #
+ #enabled: true
+
+ # Minimum accepted length for a password.
+ #
+ #minimum_length: 15
+
+ # Whether a password must contain at least one digit.
+ #
+ #require_digit: true
+
+ # Whether a password must contain at least one symbol.
+ # A symbol is any character that's not a number or a letter.
+ #
+ #require_symbol: true
+
+ # Whether a password must contain at least one lowercase letter.
+ #
+ #require_lowercase: true
+
+ # Whether a password must contain at least one lowercase letter.
+ #
+ #require_uppercase: true
"""
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4a3bfc4354..dbc3dd7a2c 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -70,6 +70,9 @@ class RatelimitConfig(Config):
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {})
+ )
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -109,6 +112,8 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+ # - one that ratelimits third-party invites requests based on the account
+ # that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -134,6 +139,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
+ #
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 9bb3beedbc..7dba213d74 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -99,8 +99,19 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.check_is_for_allowed_local_3pids = config.get(
+ "check_is_for_allowed_local_3pids", None
+ )
+ self.allow_invited_3pids = config.get("allow_invited_3pids", False)
+
+ self.disable_3pid_changes = config.get("disable_3pid_changes", False)
+
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
+ self.register_just_use_email_for_display_name = config.get(
+ "register_just_use_email_for_display_name", False
+ )
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
@@ -129,6 +140,18 @@ class RegistrationConfig(Config):
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
+ self.disable_set_displayname = config.get("disable_set_displayname", False)
+ self.disable_set_avatar_url = config.get("disable_set_avatar_url", False)
+
+ self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
+ if not isinstance(self.replicate_user_profiles_to, list):
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
+
+ self.shadow_server = config.get("shadow_server", None)
+ self.rewrite_identity_server_urls = config.get(
+ "rewrite_identity_server_urls", {}
+ )
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -244,9 +267,32 @@ class RegistrationConfig(Config):
#
#disable_msisdn_registration: true
+ # Derive the user's matrix ID from a type of 3PID used when registering.
+ # This overrides any matrix ID the user proposes when calling /register
+ # The 3PID type should be present in registrations_require_3pid to avoid
+ # users failing to register if they don't specify the right kind of 3pid.
+ #
+ #register_mxid_from_3pid: email
+
+ # Uncomment to set the display name of new users to their email address,
+ # rather than using the default heuristic.
+ #
+ #register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+ # Use an Identity Server to establish which 3PIDs are allowed to register?
+ # Overrides allowed_local_3pids below.
+ #
+ #check_is_for_allowed_local_3pids: matrix.org
+ #
+ # If you are using an IS you can also check whether that IS registers
+ # pending invites for the given 3PID (and then allow it to sign up on
+ # the platform):
+ #
+ #allow_invited_3pids: False
+ #
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
@@ -255,6 +301,11 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+ # If true, stop users from trying to change the 3PIDs associated with
+ # their accounts.
+ #
+ #disable_3pid_changes: False
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -304,6 +355,30 @@ class RegistrationConfig(Config):
# - matrix.org
# - vector.im
+ # If enabled, user IDs, display names and avatar URLs will be replicated
+ # to this server whenever they change.
+ # This is an experimental API currently implemented by sydent to support
+ # cross-homeserver user directories.
+ #
+ #replicate_user_profiles_to: example.com
+
+ # If specified, attempt to replay registrations, profile changes & 3pid
+ # bindings on the given target homeserver via the AS API. The HS is authed
+ # via a given AS token.
+ #
+ #shadow_server:
+ # hs_url: https://shadow.example.com
+ # hs: shadow.example.com
+ # as_token: 12u394refgbdhivsia
+
+ # If enabled, don't let users set their own display names/avatars
+ # other than for the very first time (unless they are a server admin).
+ # Useful when provisioning users based on the contents of a 3rd party
+ # directory and to avoid ambiguities.
+ #
+ #disable_set_displayname: False
+ #disable_set_avatar_url: False
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 7d2dd27fd0..5ebc2ea1f1 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -97,6 +97,12 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
+
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
@@ -234,6 +240,30 @@ class ContentRepositoryConfig(Config):
#
#max_upload_size: 10M
+ # The largest allowed size for a user avatar. If not defined, no
+ # restriction will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #max_avatar_size: 10M
+
+ # Allow mimetypes for a user avatar. If not defined, no restriction will
+ # be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 0ec1b0fadd..f5942c45c2 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -253,6 +253,12 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # True.
+ self.show_users_in_user_directory = config.get(
+ "show_users_in_user_directory", True
+ )
+
retention_config = config.get("retention")
if retention_config is None:
retention_config = {}
@@ -892,6 +898,74 @@ class ServerConfig(Config):
#
#allow_per_room_profiles: false
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # 'true'.
+ #
+ #show_users_in_user_directory: false
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but not want
+ # that purge to be performed by a job that's iterating over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
@@ -1066,12 +1140,12 @@ KNOWN_RESOURCES = (
def _check_resource_config(listeners):
- resource_names = set(
+ resource_names = {
res_name
for listener in listeners
for res in listener.get("resources", [])
for res_name in res.get("names", [])
- )
+ }
for resource in resource_names:
if resource not in KNOWN_RESOURCES:
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 97a12d51f6..a65538562b 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -260,7 +260,7 @@ class TlsConfig(Config):
crypto.FILETYPE_ASN1, self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
- sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
+ sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints}
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({"sha256": sha256_fingerprint})
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c8d19c5d6b..43b6c40456 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -26,6 +26,7 @@ class UserDirectoryConfig(Config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
+ self.user_directory_defer_to_id_server = None
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
@@ -34,6 +35,9 @@ class UserDirectoryConfig(Config):
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
@@ -52,4 +56,9 @@ class UserDirectoryConfig(Config):
#user_directory:
# enabled: true
# search_all_users: false
+ #
+ # # If this is set, user search will be delegated to this ID server instead
+ # # of synapse performing the search itself.
+ # # This is an experimental API.
+ # defer_to_id_server: https://id.example.com
"""
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 6fe5a6a26a..983f0ead8c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -326,9 +326,7 @@ class Keyring(object):
verify_requests (list[VerifyJsonRequest]): list of verify requests
"""
- remaining_requests = set(
- (rq for rq in verify_requests if not rq.key_ready.called)
- )
+ remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@defer.inlineCallbacks
def do_iterations():
@@ -396,7 +394,7 @@ class Keyring(object):
results = yield fetcher.get_keys(missing_keys)
- completed = list()
+ completed = []
for verify_request in remaining_requests:
server_name = verify_request.server_name
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index a23b6b7b61..6e6d75bdcb 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,7 +15,7 @@
# limitations under the License.
import inspect
-from typing import Dict
+from typing import Dict, Optional, List
from synapse.spam_checker_api import SpamCheckerApi
@@ -64,16 +64,32 @@ class SpamChecker(object):
return self.spam_checker.check_event_for_spam(event)
def user_may_invite(
- self, inviter_userid: str, invitee_userid: str, room_id: str
+ self,
+ inviter_userid: str,
+ invitee_userid: str,
+ third_party_invite: Optional[Dict],
+ room_id: str,
+ new_room: bool,
+ published_room: bool,
) -> bool:
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- inviter_userid: The user ID of the sender of the invitation
- invitee_userid: The user ID targeted in the invitation
- room_id: The room ID
+ inviter_userid:
+ invitee_userid: The user ID of the invitee. Is None
+ if this is a third party invite and the 3PID is not bound to a
+ user ID.
+ third_party_invite: If a third party invite then is a
+ dict containing the medium and address of the invitee.
+ room_id:
+ new_room: Whether the user is being invited to the room as
+ part of a room creation, if so the invitee would have been
+ included in the call to `user_may_create_room`.
+ published_room: Whether the room the user is being invited
+ to has been published in the local homeserver's public room
+ directory.
Returns:
True if the user may send an invite, otherwise False
@@ -82,16 +98,33 @@ class SpamChecker(object):
return True
return self.spam_checker.user_may_invite(
- inviter_userid, invitee_userid, room_id
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
)
- def user_may_create_room(self, userid: str) -> bool:
+ def user_may_create_room(
+ self,
+ userid: str,
+ invite_list: List[str],
+ third_party_invite_list: List[Dict],
+ cloning: bool,
+ ) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid: The ID of the user attempting to create a room
+ invite_list: List of user IDs that would be invited to
+ the new room.
+ third_party_invite_list: List of third party invites
+ for the new room.
+ cloning: Whether the user is cloning an existing room, e.g.
+ upgrading a room.
Returns:
True if the user may create a room, otherwise False
@@ -99,7 +132,9 @@ class SpamChecker(object):
if self.spam_checker is None:
return True
- return self.spam_checker.user_may_create_room(userid)
+ return self.spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
@@ -135,6 +170,24 @@ class SpamChecker(object):
return self.spam_checker.user_may_publish_room(userid, room_id)
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Checks if a given users is allowed to join a room.
+
+ Is not called when the user creates a room.
+
+ Args:
+ userid (str)
+ room_id (str)
+ is_invited (bool): Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_join_room(userid, room_id, is_invited)
+
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 9fff65716a..190ea1fba1 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -15,11 +15,13 @@
# limitations under the License.
import logging
from collections import namedtuple
+from typing import Iterable, List
import six
from twisted.internet import defer
-from twisted.internet.defer import DeferredList
+from twisted.internet.defer import Deferred, DeferredList
+from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -29,6 +31,7 @@ from synapse.api.room_versions import (
RoomVersion,
)
from synapse.crypto.event_signing import check_event_content_hash
+from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
@@ -56,7 +59,12 @@ class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(
- self, origin, pdus, room_version, outlier=False, include_none=False
+ self,
+ origin: str,
+ pdus: List[EventBase],
+ room_version: str,
+ outlier: bool = False,
+ include_none: bool = False,
):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
@@ -69,11 +77,11 @@ class FederationBase(object):
a new list.
Args:
- origin (str)
- pdu (list)
- room_version (str)
- outlier (bool): Whether the events are outliers or not
- include_none (str): Whether to include None in the returned list
+ origin
+ pdu
+ room_version
+ outlier: Whether the events are outliers or not
+ include_none: Whether to include None in the returned list
for events that have failed their checks
Returns:
@@ -82,7 +90,7 @@ class FederationBase(object):
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
- def handle_check_result(pdu, deferred):
+ def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
res = yield make_deferred_yieldable(deferred)
except SynapseError:
@@ -96,8 +104,10 @@ class FederationBase(object):
if not res and pdu.origin != origin:
try:
+ # This should not exist in the base implementation, until
+ # this is fixed, ignore it for typing. See issue #6997.
res = yield defer.ensureDeferred(
- self.get_pdu(
+ self.get_pdu( # type: ignore
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
@@ -127,21 +137,23 @@ class FederationBase(object):
else:
return [p for p in valid_pdus if p]
- def _check_sigs_and_hash(self, room_version, pdu):
+ def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
- def _check_sigs_and_hashes(self, room_version, pdus):
+ def _check_sigs_and_hashes(
+ self, room_version: str, pdus: List[EventBase]
+ ) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
Args:
- room_version (str): The room version of the PDUs
- pdus (list[FrozenEvent]): the events to be checked
+ room_version: The room version of the PDUs
+ pdus: the events to be checked
Returns:
- list[Deferred]: for each input event, a deferred which:
+ For each input event, a deferred which:
* returns the original event if the checks pass
* returns a redacted version of the event (if the signature
matched but the hash did not)
@@ -152,7 +164,7 @@ class FederationBase(object):
ctx = LoggingContext.current_context()
- def callback(_, pdu):
+ def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
# let's try to distinguish between failures because the event was
@@ -189,7 +201,7 @@ class FederationBase(object):
return pdu
- def errback(failure, pdu):
+ def errback(failure: Failure, pdu: EventBase):
failure.trap(SynapseError)
with PreserveLoggingContext(ctx):
logger.warning(
@@ -215,16 +227,18 @@ class PduToCheckSig(
pass
-def _check_sigs_on_pdus(keyring, room_version, pdus):
+def _check_sigs_on_pdus(
+ keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
+) -> List[Deferred]:
"""Check that the given events are correctly signed
Args:
- keyring (synapse.crypto.Keyring): keyring object to do the checks
- room_version (str): the room version of the PDUs
- pdus (Collection[EventBase]): the events to be checked
+ keyring: keyring object to do the checks
+ room_version: the room version of the PDUs
+ pdus: the events to be checked
Returns:
- List[Deferred]: a Deferred for each event in pdus, which will either succeed if
+ A Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not.
"""
@@ -329,7 +343,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
-def _flatten_deferred_list(deferreds):
+def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
"""Given a list of deferreds, either return the single deferred,
combine into a DeferredList, or return an already resolved deferred.
"""
@@ -341,7 +355,7 @@ def _flatten_deferred_list(deferreds):
return defer.succeed(None)
-def _is_invite_via_3pid(event):
+def _is_invite_via_3pid(event: EventBase) -> bool:
return (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4870e39652..b5538bc07a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -187,7 +187,7 @@ class FederationClient(FederationBase):
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
- ) -> List[EventBase]:
+ ) -> Optional[List[EventBase]]:
"""Requests some more historic PDUs for the given room from the
given destination server.
@@ -199,9 +199,9 @@ class FederationClient(FederationBase):
"""
logger.debug("backfill extrem=%s", extremities)
- # If there are no extremeties then we've (probably) reached the start.
+ # If there are no extremities then we've (probably) reached the start.
if not extremities:
- return
+ return None
transaction_data = await self.transport_layer.backfill(
dest, room_id, extremities, limit
@@ -284,7 +284,7 @@ class FederationClient(FederationBase):
pdu_list = [
event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"]
- ]
+ ] # type: List[EventBase]
if pdu_list and pdu_list[0]:
pdu = pdu_list[0]
@@ -615,7 +615,7 @@ class FederationClient(FederationBase):
]
if auth_chain_create_events != [create_event.event_id]:
raise InvalidResponseError(
- "Unexpected create event(s) in auth chain"
+ "Unexpected create event(s) in auth chain: %s"
% (auth_chain_create_events,)
)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 001bb304ae..876fb0e245 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -129,9 +129,9 @@ class FederationRemoteSendQueue(object):
for key in keys[:i]:
del self.presence_changed[key]
- user_ids = set(
+ user_ids = {
user_id for uids in self.presence_changed.values() for user_id in uids
- )
+ }
keys = self.presence_destinations.keys()
i = self.presence_destinations.bisect_left(position_to_delete)
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index c106abae21..4f0dc0a209 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -608,7 +608,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
user_results = yield self.store.get_users_in_group(
group_id, include_private=True
)
- if user_id in [user_result["user_id"] for user_result in user_results]:
+ if user_id in (user_result["user_id"] for user_result in user_results):
raise SynapseError(400, "User already in group")
content = {
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 829f52eca1..6c46c995d2 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -20,6 +20,8 @@ from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import List
+from twisted.internet import defer
+
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -43,6 +45,8 @@ class AccountValidityHandler(object):
self.clock = self.hs.get_clock()
self._account_validity = self.hs.config.account_validity
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
# Don't do email-specific configuration if renewal by email is disabled.
@@ -82,6 +86,9 @@ class AccountValidityHandler(object):
self.clock.looping_call(send_emails, 30 * 60 * 1000)
+ # Check every hour to remove expired users from the user directory
+ self.clock.looping_call(self._mark_expired_users_as_inactive, 60 * 60 * 1000)
+
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
@@ -262,4 +269,27 @@ class AccountValidityHandler(object):
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ await self.profile_handler.set_active(
+ UserID.from_string(user_id), True, True
+ )
+
return expiration_ts
+
+ @defer.inlineCallbacks
+ def _mark_expired_users_as_inactive(self):
+ """Iterate over expired users. Mark them as inactive in order to hide them from the
+ user directory.
+
+ Returns:
+ Deferred
+ """
+ # Get expired users
+ expired_user_ids = yield self.store.get_expired_users()
+ expired_users = [UserID.from_string(user_id) for user_id in expired_user_ids]
+
+ # Mark each one as non-active
+ for user in expired_users:
+ yield self.profile_handler.set_active(user, False, True)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 2afb390a92..f624c2a3f9 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -33,6 +33,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_handlers().identity_handler
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
@@ -104,6 +105,9 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ await self._profile_handler.set_active(user, False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
await self.store.add_user_pending_deactivation(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 50cea3f378..a514c30714 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -742,6 +742,6 @@ class DeviceListUpdater(object):
# We clobber the seen updates since we've re-synced from a given
# point.
- self._seen_updates[user_id] = set([stream_id])
+ self._seen_updates[user_id] = {stream_id}
defer.returnValue(result)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index db2104c5f6..0b23ca919a 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,6 +14,7 @@
# limitations under the License.
+import collections
import logging
import string
from typing import List
@@ -71,7 +72,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Check if there is a current association.
if not servers:
users = yield self.state.get_current_users_in_room(room_id)
- servers = set(get_domain_from_id(u) for u in users)
+ servers = {get_domain_from_id(u) for u in users}
if not servers:
raise SynapseError(400, "Failed to get server list")
@@ -254,7 +255,7 @@ class DirectoryHandler(BaseHandler):
)
users = yield self.state.get_current_users_in_room(room_id)
- extra_servers = set(get_domain_from_id(u) for u in users)
+ extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.
@@ -283,22 +284,6 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
- def send_room_alias_update_event(self, requester, room_id):
- aliases = yield self.store.get_aliases_for_room(room_id)
-
- yield self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Aliases,
- "state_key": self.hs.hostname,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "content": {"aliases": aliases},
- },
- ratelimit=False,
- )
-
- @defer.inlineCallbacks
def _update_canonical_alias(self, requester, user_id, room_id, room_alias):
"""
Send an updated canonical alias event if the removed alias was set as
@@ -326,7 +311,7 @@ class DirectoryHandler(BaseHandler):
alt_aliases = content.pop("alt_aliases", None)
# If the aliases are not a list (or not found) do not attempt to modify
# the list.
- if isinstance(alt_aliases, list):
+ if isinstance(alt_aliases, collections.Sequence):
send_update = True
alt_aliases = [alias for alias in alt_aliases if alias != alias_str]
if alt_aliases:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eb20ef4aec..5d686ab2c9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -41,7 +41,6 @@ from synapse.api.errors import (
FederationDeniedError,
FederationError,
RequestSendFailed,
- StoreError,
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
@@ -61,6 +60,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet,
+ ReplicationStoreRoomOnInviteRestServlet,
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
@@ -161,8 +161,12 @@ class FederationHandler(BaseHandler):
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs
)
+ self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
+ hs
+ )
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
+ self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@@ -187,7 +191,7 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("handling received PDU: %s", pdu)
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = await self.store.get_event(
@@ -302,6 +306,14 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
+ elif missing_prevs:
+ logger.info(
+ "[%s %s] Not recursively fetching %d missing prev_events: %s",
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -346,12 +358,6 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- logger.info(
- "Event %s is missing prev_events: calculating state for a "
- "backwards extremity",
- event_id,
- )
-
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
@@ -369,7 +375,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "Requesting state at missing prev_event %s", event_id,
+ "[%s %s] Requesting state at missing prev_event %s",
+ room_id,
+ event_id,
+ p,
)
with nested_logging_context(p):
@@ -405,7 +414,6 @@ class FederationHandler(BaseHandler):
evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
@@ -659,11 +667,11 @@ class FederationHandler(BaseHandler):
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
- bad_events = list(
+ bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
- )
+ ]
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
@@ -707,28 +715,6 @@ class FederationHandler(BaseHandler):
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- room = await self.store.get_room(room_id)
-
- if not room:
- try:
- prev_state_ids = await context.get_prev_state_ids()
- create_event = await self.store.get_event(
- prev_state_ids[(EventTypes.Create, "")]
- )
-
- room_version_id = create_event.content.get(
- "room_version", RoomVersions.V1.identifier
- )
-
- await self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False,
- room_version=KNOWN_ROOM_VERSIONS[room_version_id],
- )
- except StoreError:
- logger.exception("Failed to store room.")
-
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has acutally
@@ -856,7 +842,7 @@ class FederationHandler(BaseHandler):
# Don't bother processing events we already have.
seen_events = await self.store.have_events_in_timeline(
- set(e.event_id for e in events)
+ {e.event_id for e in events}
)
events = [e for e in events if e.event_id not in seen_events]
@@ -866,7 +852,7 @@ class FederationHandler(BaseHandler):
event_map = {e.event_id: e for e in events}
- event_ids = set(e.event_id for e in events)
+ event_ids = {e.event_id for e in events}
# build a list of events whose prev_events weren't in the batch.
# (XXX: this will include events whose prev_events we already have; that doesn't
@@ -892,13 +878,13 @@ class FederationHandler(BaseHandler):
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
- required_auth = set(
+ required_auth = {
a_id
for event in events
+ list(state_events.values())
+ list(auth_events.values())
for a_id in event.auth_event_ids()
- )
+ }
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
@@ -1247,7 +1233,7 @@ class FederationHandler(BaseHandler):
async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
auth = await self.store.get_auth_chain(
- [auth_id for auth_id in event.auth_event_ids()], include_given=True
+ list(event.auth_event_ids()), include_given=True
)
return list(auth)
@@ -1323,16 +1309,18 @@ class FederationHandler(BaseHandler):
logger.debug("do_invite_join event: %s", event)
- try:
- await self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False,
- room_version=room_version_obj,
- )
- except Exception:
- # FIXME
- pass
+ # if this is the first time we've joined this room, it's time to add
+ # a row to `rooms` with the correct room version. If there's already a
+ # row there, we should override it, since it may have been populated
+ # based on an invite request which lied about the room version.
+ #
+ # federation_client.send_join has already checked that the room
+ # version in the received create event is the same as room_version_obj,
+ # so we can rely on it now.
+ #
+ await self.store.upsert_room_on_join(
+ room_id=room_id, room_version=room_version_obj,
+ )
await self._persist_auth_tree(
origin, auth_chain, state, event, room_version_obj
@@ -1534,8 +1522,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = await self.store.is_room_published(event.room_id)
+
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1558,6 +1553,13 @@ class FederationHandler(BaseHandler):
if event.state_key == self._server_notices_mxid:
raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
+ # keep a record of the room version, if we don't yet know it.
+ # (this may get overwritten if we later get a different room version in a
+ # join dance).
+ await self._maybe_store_room_on_invite(
+ room_id=event.room_id, room_version=room_version
+ )
+
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
@@ -2152,7 +2154,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
- [auth_id for auth_id in event.auth_event_ids()], include_given=True
+ list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
@@ -2654,7 +2656,7 @@ class FederationHandler(BaseHandler):
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
- destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
+ destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
yield self.federation_client.forward_third_party_invite(
destinations, room_id, event_dict
)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 23f07832e7..94b5279aa6 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
from synapse.config.emailconfig import ThreepidBehaviour
@@ -51,14 +52,21 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
- self.http_client = SimpleHttpClient(hs)
+ self.hs = hs
+ self.http_client = hs.get_simple_http_client()
# We create a blacklisting instance of SimpleHttpClient for contacting identity
# servers specified by clients
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
self.federation_http_client = hs.get_http_client()
- self.hs = hs
+
+ self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
+ self.trust_any_id_server_just_for_testing_do_not_use = (
+ hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
+ )
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
@defer.inlineCallbacks
def threepid_from_creds(self, id_server, creds):
@@ -94,7 +102,15 @@ class IdentityHandler(BaseHandler):
query_params = {"sid": session_id, "client_secret": client_secret}
- url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
+
+ url = "https://%s%s" % (
+ id_server,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
try:
data = yield self.http_client.get_json(url, query_params)
@@ -149,14 +165,24 @@ class IdentityHandler(BaseHandler):
if id_access_token is None:
use_v2 = False
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ if id_server in self.rewrite_identity_server_urls:
+ id_server_host = self.rewrite_identity_server_urls[id_server]
+ else:
+ id_server_host = id_server
+
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
- headers["Authorization"] = create_id_access_token_header(id_access_token)
+ bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server_host,)
+ headers["Authorization"] = create_id_access_token_header(
+ id_access_token
+ )
else:
- bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+ bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server_host,)
try:
# Use the blacklisting http client as this call is only to identity servers
@@ -263,6 +289,16 @@ class IdentityHandler(BaseHandler):
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
+
+ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
@@ -400,6 +436,12 @@ class IdentityHandler(BaseHandler):
"client_secret": client_secret,
"send_attempt": send_attempt,
}
+
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
+
if next_link:
params["next_link"] = next_link
@@ -466,6 +508,10 @@ class IdentityHandler(BaseHandler):
"details and update your config file."
)
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
try:
data = yield self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
@@ -566,6 +612,89 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
+ # TODO: The following two methods are used for proxying IS requests using
+ # the CS API. They should be consolidated with those in RoomMemberHandler
+ # https://github.com/matrix-org/synapse-dinsic/issues/25
+
+ @defer.inlineCallbacks
+ def proxy_lookup_3pid(self, id_server, medium, address):
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ target = self.rewrite_identity_server_urls.get(id_server, id_server)
+
+ try:
+ data = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/lookup" % (target,),
+ {"medium": medium, "address": address},
+ )
+
+ if "mxid" in data:
+ if "signatures" not in data:
+ raise AuthError(401, "No signatures on 3pid binding")
+ yield self._verify_any_signature(data, id_server)
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %r: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ defer.returnValue(data)
+
+ @defer.inlineCallbacks
+ def proxy_bulk_lookup_3pid(self, id_server, threepids):
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ threepids ([[str, str]]): The third party identifiers to lookup, as
+ a list of 2-string sized lists ([medium, address]).
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ target = self.rewrite_identity_server_urls.get(id_server, id_server)
+
+ try:
+ data = yield self.http_client.post_json_get_json(
+ "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %r: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ defer.returnValue(data)
+
@defer.inlineCallbacks
def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
@@ -581,6 +710,9 @@ class IdentityHandler(BaseHandler):
Returns:
str|None: the matrix ID of the 3pid, or None if it is not recognized.
"""
+ # Rewrite id_server URL if necessary
+ id_server = self._get_id_server_target(id_server)
+
if id_access_token is not None:
try:
results = yield self._lookup_3pid_v2(
@@ -618,7 +750,7 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = yield self.blacklisting_http_client.get_json(
+ data = yield self.http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)
@@ -651,7 +783,7 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = yield self.blacklisting_http_client.get_json(
+ hash_details = yield self.http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
@@ -669,7 +801,7 @@ class IdentityHandler(BaseHandler):
400,
"Non-dict object from %s%s during v2 hash_details request: %s"
% (id_server_scheme, id_server, hash_details),
- )
+ )
# Extract information from hash_details
supported_lookup_algorithms = hash_details.get("algorithms")
@@ -684,7 +816,7 @@ class IdentityHandler(BaseHandler):
400,
"Invalid hash details received from identity server %s%s: %s"
% (id_server_scheme, id_server, hash_details),
- )
+ )
# Check if any of the supported lookup algorithms are present
if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
@@ -718,7 +850,7 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = yield self.blacklisting_http_client.post_json_get_json(
+ lookup_results = yield self.http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
@@ -726,7 +858,7 @@ class IdentityHandler(BaseHandler):
"pepper": lookup_pepper,
},
headers=headers,
- )
+ )
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except Exception as e:
@@ -750,14 +882,15 @@ class IdentityHandler(BaseHandler):
def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
+
for key_name, signature in data["signatures"][server_hostname].items():
- try:
- key_data = yield self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s"
- % (id_server_scheme, server_hostname, key_name)
- )
- except TimeoutError:
- raise SynapseError(500, "Timed out contacting identity server")
+ target = self.rewrite_identity_server_urls.get(
+ server_hostname, server_hostname
+ )
+
+ key_data = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/pubkey/%s" % (target, key_name)
+ )
if "public_key" not in key_data:
raise AuthError(
401, "No public key named %s from %s" % (key_name, server_hostname)
@@ -771,6 +904,23 @@ class IdentityHandler(BaseHandler):
)
return
+ raise AuthError(401, "No signature from server %s" % (server_hostname,))
+
+ def _get_id_server_target(self, id_server):
+ """Looks up an id_server's actual http endpoint
+
+ Args:
+ id_server (str): the server name to lookup.
+
+ Returns:
+ the http endpoint to connect to.
+ """
+ if id_server in self.rewrite_identity_server_urls:
+ return self.rewrite_identity_server_urls[id_server]
+
+ return id_server
+
+
@defer.inlineCallbacks
def ask_id_server_for_third_party_invite(
self,
@@ -831,6 +981,9 @@ class IdentityHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url,
}
+ # Rewrite the identity server URL if necessary
+ id_server = self._get_id_server_target(id_server)
+
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d6be280952..a0103addd3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1016,11 +1016,10 @@ class EventCreationHandler(object):
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
- @defer.inlineCallbacks
- def _bump_active_time(self, user):
+ async def _bump_active_time(self, user):
try:
presence = self.hs.get_presence_handler()
- yield presence.bump_presence_active_time(user)
+ await presence.bump_presence_active_time(user)
except Exception:
logger.exception("Error bumping presence active time")
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
new file mode 100644
index 0000000000..d06b110269
--- /dev/null
+++ b/synapse/handlers/password_policy.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from synapse.api.errors import Codes, PasswordRefusedError
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyHandler(object):
+ def __init__(self, hs):
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ # Regexps for the spec'd policy parameters.
+ self.regexp_digit = re.compile("[0-9]")
+ self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
+ self.regexp_uppercase = re.compile("[A-Z]")
+ self.regexp_lowercase = re.compile("[a-z]")
+
+ def validate_password(self, password):
+ """Checks whether a given password complies with the server's policy.
+
+ Args:
+ password (str): The password to check against the server's policy.
+
+ Raises:
+ PasswordRefusedError: The password doesn't comply with the server's policy.
+ """
+
+ if not self.enabled:
+ return
+
+ minimum_accepted_length = self.policy.get("minimum_length", 0)
+ if len(password) < minimum_accepted_length:
+ raise PasswordRefusedError(
+ msg=(
+ "The password must be at least %d characters long"
+ % minimum_accepted_length
+ ),
+ errcode=Codes.PASSWORD_TOO_SHORT,
+ )
+
+ if (
+ self.policy.get("require_digit", False)
+ and self.regexp_digit.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one digit",
+ errcode=Codes.PASSWORD_NO_DIGIT,
+ )
+
+ if (
+ self.policy.get("require_symbol", False)
+ and self.regexp_symbol.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one symbol",
+ errcode=Codes.PASSWORD_NO_SYMBOL,
+ )
+
+ if (
+ self.policy.get("require_uppercase", False)
+ and self.regexp_uppercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one uppercase letter",
+ errcode=Codes.PASSWORD_NO_UPPERCASE,
+ )
+
+ if (
+ self.policy.get("require_lowercase", False)
+ and self.regexp_lowercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one lowercase letter",
+ errcode=Codes.PASSWORD_NO_LOWERCASE,
+ )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 202aa9294f..5526015ddb 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,11 +24,12 @@ The methods that define policy are:
import logging
from contextlib import contextmanager
-from typing import Dict, Set
+from typing import Dict, List, Set
from six import iteritems, itervalues
from prometheus_client import Counter
+from typing_extensions import ContextManager
from twisted.internet import defer
@@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
+MYPY = False
+if MYPY:
+ import synapse.server
+
logger = logging.getLogger(__name__)
@@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
- self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
self.clock = hs.get_clock()
@@ -150,7 +154,7 @@ class PresenceHandler(object):
# Set of users who have presence in the `user_to_current_state` that
# have not yet been persisted
- self.unpersisted_users_changes = set()
+ self.unpersisted_users_changes = set() # type: Set[str]
hs.get_reactor().addSystemEventTrigger(
"before",
@@ -160,12 +164,11 @@ class PresenceHandler(object):
self._on_shutdown,
)
- self.serial_to_user = {}
self._next_serial = 1
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
- self.user_to_num_current_syncs = {}
+ self.user_to_num_current_syncs = {} # type: Dict[str, int]
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
@@ -213,8 +216,7 @@ class PresenceHandler(object):
self._event_pos = self.store.get_current_events_token()
self._event_processing = False
- @defer.inlineCallbacks
- def _on_shutdown(self):
+ async def _on_shutdown(self):
"""Gets called when shutting down. This lets us persist any updates that
we haven't yet persisted, e.g. updates that only changes some internal
timers. This allows changes to persist across startup without having to
@@ -235,7 +237,7 @@ class PresenceHandler(object):
if self.unpersisted_users_changes:
- yield self.store.update_presence(
+ await self.store.update_presence(
[
self.user_to_current_state[user_id]
for user_id in self.unpersisted_users_changes
@@ -243,8 +245,7 @@ class PresenceHandler(object):
)
logger.info("Finished _on_shutdown")
- @defer.inlineCallbacks
- def _persist_unpersisted_changes(self):
+ async def _persist_unpersisted_changes(self):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
@@ -253,12 +254,11 @@ class PresenceHandler(object):
if unpersisted:
logger.info("Persisting %d unpersisted presence updates", len(unpersisted))
- yield self.store.update_presence(
+ await self.store.update_presence(
[self.user_to_current_state[user_id] for user_id in unpersisted]
)
- @defer.inlineCallbacks
- def _update_states(self, new_states):
+ async def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
should be sent to clients/servers.
@@ -267,7 +267,7 @@ class PresenceHandler(object):
with Measure(self.clock, "presence_update_states"):
- # NOTE: We purposefully don't yield between now and when we've
+ # NOTE: We purposefully don't await between now and when we've
# calculated what we want to do with the new states, to avoid races.
to_notify = {} # Changes we want to notify everyone about
@@ -311,9 +311,9 @@ class PresenceHandler(object):
if to_notify:
notified_presence_counter.inc(len(to_notify))
- yield self._persist_and_notify(list(to_notify.values()))
+ await self._persist_and_notify(list(to_notify.values()))
- self.unpersisted_users_changes |= set(s.user_id for s in new_states)
+ self.unpersisted_users_changes |= {s.user_id for s in new_states}
self.unpersisted_users_changes -= set(to_notify.keys())
to_federation_ping = {
@@ -326,7 +326,7 @@ class PresenceHandler(object):
self._push_to_remotes(to_federation_ping.values())
- def _handle_timeouts(self):
+ async def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as
appropriate.
"""
@@ -368,10 +368,9 @@ class PresenceHandler(object):
now=now,
)
- return self._update_states(changes)
+ return await self._update_states(changes)
- @defer.inlineCallbacks
- def bump_presence_active_time(self, user):
+ async def bump_presence_active_time(self, user):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
@@ -383,16 +382,17 @@ class PresenceHandler(object):
bump_active_time_counter.inc()
- prev_state = yield self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
new_fields = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE
- yield self._update_states([prev_state.copy_and_replace(**new_fields)])
+ await self._update_states([prev_state.copy_and_replace(**new_fields)])
- @defer.inlineCallbacks
- def user_syncing(self, user_id, affect_presence=True):
+ async def user_syncing(
+ self, user_id: str, affect_presence: bool = True
+ ) -> ContextManager[None]:
"""Returns a context manager that should surround any stream requests
from the user.
@@ -415,11 +415,11 @@ class PresenceHandler(object):
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
- prev_state = yield self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
if prev_state.state == PresenceState.OFFLINE:
# If they're currently offline then bring them online, otherwise
# just update the last sync times.
- yield self._update_states(
+ await self._update_states(
[
prev_state.copy_and_replace(
state=PresenceState.ONLINE,
@@ -429,7 +429,7 @@ class PresenceHandler(object):
]
)
else:
- yield self._update_states(
+ await self._update_states(
[
prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec()
@@ -437,13 +437,12 @@ class PresenceHandler(object):
]
)
- @defer.inlineCallbacks
- def _end():
+ async def _end():
try:
self.user_to_num_current_syncs[user_id] -= 1
- prev_state = yield self.current_state_for_user(user_id)
- yield self._update_states(
+ prev_state = await self.current_state_for_user(user_id)
+ await self._update_states(
[
prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec()
@@ -480,8 +479,7 @@ class PresenceHandler(object):
else:
return set()
- @defer.inlineCallbacks
- def update_external_syncs_row(
+ async def update_external_syncs_row(
self, process_id, user_id, is_syncing, sync_time_msec
):
"""Update the syncing users for an external process as a delta.
@@ -494,8 +492,8 @@ class PresenceHandler(object):
is_syncing (bool): Whether or not the user is now syncing
sync_time_msec(int): Time in ms when the user was last syncing
"""
- with (yield self.external_sync_linearizer.queue(process_id)):
- prev_state = yield self.current_state_for_user(user_id)
+ with (await self.external_sync_linearizer.queue(process_id)):
+ prev_state = await self.current_state_for_user(user_id)
process_presence = self.external_process_to_current_syncs.setdefault(
process_id, set()
@@ -525,25 +523,24 @@ class PresenceHandler(object):
process_presence.discard(user_id)
if updates:
- yield self._update_states(updates)
+ await self._update_states(updates)
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
- @defer.inlineCallbacks
- def update_external_syncs_clear(self, process_id):
+ async def update_external_syncs_clear(self, process_id):
"""Marks all users that had been marked as syncing by a given process
as offline.
Used when the process has stopped/disappeared.
"""
- with (yield self.external_sync_linearizer.queue(process_id)):
+ with (await self.external_sync_linearizer.queue(process_id)):
process_presence = self.external_process_to_current_syncs.pop(
process_id, set()
)
- prev_states = yield self.current_state_for_users(process_presence)
+ prev_states = await self.current_state_for_users(process_presence)
time_now_ms = self.clock.time_msec()
- yield self._update_states(
+ await self._update_states(
[
prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
for prev_state in itervalues(prev_states)
@@ -551,15 +548,13 @@ class PresenceHandler(object):
)
self.external_process_last_updated_ms.pop(process_id, None)
- @defer.inlineCallbacks
- def current_state_for_user(self, user_id):
+ async def current_state_for_user(self, user_id):
"""Get the current presence state for a user.
"""
- res = yield self.current_state_for_users([user_id])
+ res = await self.current_state_for_users([user_id])
return res[user_id]
- @defer.inlineCallbacks
- def current_state_for_users(self, user_ids):
+ async def current_state_for_users(self, user_ids):
"""Get the current presence state for multiple users.
Returns:
@@ -574,7 +569,7 @@ class PresenceHandler(object):
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
- res = yield self.store.get_presence_for_users(missing)
+ res = await self.store.get_presence_for_users(missing)
states.update(res)
missing = [user_id for user_id, state in iteritems(states) if not state]
@@ -587,14 +582,13 @@ class PresenceHandler(object):
return states
- @defer.inlineCallbacks
- def _persist_and_notify(self, states):
+ async def _persist_and_notify(self, states):
"""Persist states in the database, poke the notifier and send to
interested remote servers
"""
- stream_id, max_token = yield self.store.update_presence(states)
+ stream_id, max_token = await self.store.update_presence(states)
- parties = yield get_interested_parties(self.store, states)
+ parties = await get_interested_parties(self.store, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@@ -606,9 +600,8 @@ class PresenceHandler(object):
self._push_to_remotes(states)
- @defer.inlineCallbacks
- def notify_for_states(self, state, stream_id):
- parties = yield get_interested_parties(self.store, [state])
+ async def notify_for_states(self, state, stream_id):
+ parties = await get_interested_parties(self.store, [state])
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@@ -626,8 +619,7 @@ class PresenceHandler(object):
"""
self.federation.send_presence(states)
- @defer.inlineCallbacks
- def incoming_presence(self, origin, content):
+ async def incoming_presence(self, origin, content):
"""Called when we receive a `m.presence` EDU from a remote server.
"""
now = self.clock.time_msec()
@@ -670,21 +662,19 @@ class PresenceHandler(object):
new_fields["status_msg"] = push.get("status_msg", None)
new_fields["currently_active"] = push.get("currently_active", False)
- prev_state = yield self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
updates.append(prev_state.copy_and_replace(**new_fields))
if updates:
federation_presence_counter.inc(len(updates))
- yield self._update_states(updates)
+ await self._update_states(updates)
- @defer.inlineCallbacks
- def get_state(self, target_user, as_event=False):
- results = yield self.get_states([target_user.to_string()], as_event=as_event)
+ async def get_state(self, target_user, as_event=False):
+ results = await self.get_states([target_user.to_string()], as_event=as_event)
return results[0]
- @defer.inlineCallbacks
- def get_states(self, target_user_ids, as_event=False):
+ async def get_states(self, target_user_ids, as_event=False):
"""Get the presence state for users.
Args:
@@ -695,10 +685,10 @@ class PresenceHandler(object):
list
"""
- updates = yield self.current_state_for_users(target_user_ids)
+ updates = await self.current_state_for_users(target_user_ids)
updates = list(updates.values())
- for user_id in set(target_user_ids) - set(u.user_id for u in updates):
+ for user_id in set(target_user_ids) - {u.user_id for u in updates}:
updates.append(UserPresenceState.default(user_id))
now = self.clock.time_msec()
@@ -713,8 +703,7 @@ class PresenceHandler(object):
else:
return updates
- @defer.inlineCallbacks
- def set_state(self, target_user, state, ignore_status_msg=False):
+ async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
@@ -730,7 +719,7 @@ class PresenceHandler(object):
user_id = target_user.to_string()
- prev_state = yield self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
new_fields = {"state": presence}
@@ -741,16 +730,15 @@ class PresenceHandler(object):
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()
- yield self._update_states([prev_state.copy_and_replace(**new_fields)])
+ await self._update_states([prev_state.copy_and_replace(**new_fields)])
- @defer.inlineCallbacks
- def is_visible(self, observed_user, observer_user):
+ async def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence.
"""
- observer_room_ids = yield self.store.get_rooms_for_user(
+ observer_room_ids = await self.store.get_rooms_for_user(
observer_user.to_string()
)
- observed_room_ids = yield self.store.get_rooms_for_user(
+ observed_room_ids = await self.store.get_rooms_for_user(
observed_user.to_string()
)
@@ -759,8 +747,7 @@ class PresenceHandler(object):
return False
- @defer.inlineCallbacks
- def get_all_presence_updates(self, last_id, current_id):
+ async def get_all_presence_updates(self, last_id, current_id):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -775,7 +762,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
- rows = yield self.store.get_all_presence_updates(last_id, current_id)
+ rows = await self.store.get_all_presence_updates(last_id, current_id)
return rows
def notify_new_event(self):
@@ -786,20 +773,18 @@ class PresenceHandler(object):
if self._event_processing:
return
- @defer.inlineCallbacks
- def _process_presence():
+ async def _process_presence():
assert not self._event_processing
self._event_processing = True
try:
- yield self._unsafe_process()
+ await self._unsafe_process()
finally:
self._event_processing = False
run_as_background_process("presence.notify_new_event", _process_presence)
- @defer.inlineCallbacks
- def _unsafe_process(self):
+ async def _unsafe_process(self):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "presence_delta"):
@@ -812,10 +797,10 @@ class PresenceHandler(object):
self._event_pos,
room_max_stream_ordering,
)
- max_pos, deltas = yield self.store.get_current_state_deltas(
+ max_pos, deltas = await self.store.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
- yield self._handle_state_delta(deltas)
+ await self._handle_state_delta(deltas)
self._event_pos = max_pos
@@ -824,8 +809,7 @@ class PresenceHandler(object):
max_pos
)
- @defer.inlineCallbacks
- def _handle_state_delta(self, deltas):
+ async def _handle_state_delta(self, deltas):
"""Process current state deltas to find new joins that need to be
handled.
"""
@@ -846,13 +830,13 @@ class PresenceHandler(object):
# joins.
continue
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if not event or event.content.get("membership") != Membership.JOIN:
# We only care about joins
continue
if prev_event_id:
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if (
prev_event
and prev_event.content.get("membership") == Membership.JOIN
@@ -860,10 +844,9 @@ class PresenceHandler(object):
# Ignore changes to join events.
continue
- yield self._on_user_joined_room(room_id, state_key)
+ await self._on_user_joined_room(room_id, state_key)
- @defer.inlineCallbacks
- def _on_user_joined_room(self, room_id, user_id):
+ async def _on_user_joined_room(self, room_id, user_id):
"""Called when we detect a user joining the room via the current state
delta stream.
@@ -882,11 +865,11 @@ class PresenceHandler(object):
# TODO: We should be able to filter the hosts down to those that
# haven't previously seen the user
- state = yield self.current_state_for_user(user_id)
- hosts = yield self.state.get_current_hosts_in_room(room_id)
+ state = await self.current_state_for_user(user_id)
+ hosts = await self.state.get_current_hosts_in_room(room_id)
# Filter out ourselves.
- hosts = set(host for host in hosts if host != self.server_name)
+ hosts = {host for host in hosts if host != self.server_name}
self.federation.send_presence_to_destinations(
states=[state], destinations=hosts
@@ -903,10 +886,10 @@ class PresenceHandler(object):
# TODO: Check that this is actually a new server joining the
# room.
- user_ids = yield self.state.get_current_users_in_room(room_id)
+ user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
- states = yield self.current_state_for_users(user_ids)
+ states = await self.current_state_for_users(user_ids)
# Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this
@@ -996,9 +979,8 @@ class PresenceEventSource(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- @defer.inlineCallbacks
@log_function
- def get_new_events(
+ async def get_new_events(
self,
user,
from_key,
@@ -1045,7 +1027,7 @@ class PresenceEventSource(object):
presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache
- users_interested_in = yield self._get_interested_in(user, explicit_room_id)
+ users_interested_in = await self._get_interested_in(user, explicit_room_id)
user_ids_changed = set()
changed = None
@@ -1071,7 +1053,7 @@ class PresenceEventSource(object):
else:
user_ids_changed = users_interested_in
- updates = yield presence.current_state_for_users(user_ids_changed)
+ updates = await presence.current_state_for_users(user_ids_changed)
if include_offline:
return (list(updates.values()), max_token)
@@ -1084,11 +1066,11 @@ class PresenceEventSource(object):
def get_current_key(self):
return self.store.get_current_presence_token()
- def get_pagination_rows(self, user, pagination_config, key):
- return self.get_new_events(user, from_key=None, include_offline=False)
+ async def get_pagination_rows(self, user, pagination_config, key):
+ return await self.get_new_events(user, from_key=None, include_offline=False)
- @cachedInlineCallbacks(num_args=2, cache_context=True)
- def _get_interested_in(self, user, explicit_room_id, cache_context):
+ @cached(num_args=2, cache_context=True)
+ async def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence
updates for
"""
@@ -1096,13 +1078,13 @@ class PresenceEventSource(object):
users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
- user_ids = yield self.store.get_users_in_room(
+ user_ids = await self.store.get_users_in_room(
explicit_room_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(user_ids)
@@ -1277,8 +1259,8 @@ def get_interested_parties(store, states):
2-tuple: `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]`
"""
- room_ids_to_states = {}
- users_to_states = {}
+ room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
+ users_to_states = {} # type: Dict[str, List[UserPresenceState]]
for state in states:
room_ids = yield store.get_rooms_for_user(state.user_id)
for room_id in room_ids:
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index f9579d69ee..824fadf028 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +17,11 @@
import logging
from six import raise_from
+from six.moves import range
-from twisted.internet import defer
+from signedjson.sign import sign_json
+
+from twisted.internet import defer, reactor
from synapse.api.errors import (
AuthError,
@@ -27,8 +31,9 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import UserID, create_requester, get_domain_from_id
from ._base import BaseHandler
@@ -46,6 +51,8 @@ class BaseProfileHandler(BaseHandler):
subclass MasterProfileHandler
"""
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs):
super(BaseProfileHandler, self).__init__(hs)
@@ -56,6 +63,87 @@ class BaseProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._assign_profile_replication_batches)
+ reactor.callWhenRunning(self._replicate_profiles)
+ # Add a looping call to replicate_profiles: this handles retries
+ # if the replication is unsuccessful when the user updated their
+ # profile.
+ self.clock.looping_call(
+ self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ @defer.inlineCallbacks
+ def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = yield self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ @defer.inlineCallbacks
+ def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = yield self.store.get_replication_hosts()
+ latest_batch = yield self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ yield self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ @defer.inlineCallbacks
+ def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = yield self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ yield self.http_client.post_json_get_json(url, signed_body)
+ yield self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
@@ -154,9 +242,16 @@ class BaseProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
- if not by_admin and target_user != requester.user:
+ if not by_admin and requester and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
+ if not by_admin and self.hs.config.disable_set_displayname:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ if profile.display_name:
+ raise SynapseError(
+ 400, "Changing displayname is disabled on this server"
+ )
+
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -165,7 +260,23 @@ class BaseProfileHandler(BaseHandler):
if new_displayname == "":
new_displayname = None
- yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ # If the admin changes the display name of a user, the requesting user cannot send
+ # the join event to update the displayname in the rooms.
+ # This must be done by the target user himself.
+ if by_admin:
+ requester = create_requester(target_user)
+
+ yield self.store.set_profile_displayname(
+ target_user.localpart, new_displayname, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -173,7 +284,39 @@ class BaseProfileHandler(BaseHandler):
target_user.to_string(), profile
)
- yield self._update_join_states(requester, target_user)
+ if requester:
+ yield self._update_join_states(requester, target_user)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ @defer.inlineCallbacks
+ def set_active(self, target_user, active, hide):
+ """
+ Sets the 'active' flag on a user profile. If set to false, the user
+ account is considered deactivated or hidden.
+
+ If 'hide' is true, then we interpret active=False as a request to try to
+ hide the user rather than deactivating it. This means withholding the
+ profile from replication (and mark it as inactive) rather than clearing
+ the profile from the HS DB. Note that unlike set_displayname and
+ set_avatar_url, this does *not* perform authorization checks! This is
+ because the only place it's used currently is in account deactivation
+ where we've already done these checks anyway.
+ """
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+ yield self.store.set_profile_active(
+ target_user.localpart, active, hide, new_batchnum
+ )
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -212,12 +355,63 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
+ if not by_admin and self.hs.config.disable_set_avatar_url:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ if profile.avatar_url:
+ raise SynapseError(
+ 400, "Changing avatar url is disabled on this server"
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
- yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ # Enforce a max avatar size if one is defined
+ if self.max_avatar_size or self.allowed_avatar_mimetypes:
+ media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url)
+
+ # Check that this media exists locally
+ media_info = yield self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
+ # Same like set_displayname
+ if by_admin:
+ requester = create_requester(target_user)
+
+ yield self.store.set_profile_avatar_url(
+ target_user.localpart, new_avatar_url, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -227,6 +421,23 @@ class BaseProfileHandler(BaseHandler):
yield self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
@defer.inlineCallbacks
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
@@ -282,7 +493,7 @@ class BaseProfileHandler(BaseHandler):
@defer.inlineCallbacks
def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the
- 'require_auth_for_profile_requests' config flag is set to True and a
+ 'limit_profile_requests_to_known_users' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 9283c039e3..8bc100db42 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -94,7 +94,7 @@ class ReceiptsHandler(BaseHandler):
# no new receipts
return False
- affected_room_ids = list(set([r.room_id for r in receipts]))
+ affected_room_ids = list({r.room_id for r in receipts})
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7ffc194f0c..696d90996a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -49,6 +49,7 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
@@ -61,6 +62,8 @@ class RegistrationHandler(BaseHandler):
)
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -203,6 +206,11 @@ class RegistrationHandler(BaseHandler):
address=address,
)
+ if default_display_name:
+ yield self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart)
yield self.user_directory_handler.handle_local_profile_change(
@@ -233,6 +241,10 @@ class RegistrationHandler(BaseHandler):
address=address,
)
+ yield self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
# Successfully registered
break
except SynapseError:
@@ -262,6 +274,14 @@ class RegistrationHandler(BaseHandler):
# Bind email to new account
yield self._register_email_threepid(user_id, threepid_dict, None)
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ yield self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ yield self.profile_handler.set_active(user, False, True)
+
return user_id
@defer.inlineCallbacks
@@ -328,7 +348,9 @@ class RegistrationHandler(BaseHandler):
yield self._auto_join_rooms(user_id)
@defer.inlineCallbacks
- def appservice_register(self, user_localpart, as_token):
+ def appservice_register(self, user_localpart, as_token, password, display_name):
+ # FIXME: this should be factored out and merged with normal register()
+
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -347,12 +369,29 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
+ password_hash = ""
+ if password:
+ password_hash = yield self.auth_handler().hash(password)
+
+ display_name = display_name or user.localpart
+
yield self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
)
+
+ yield self.profile_handler.set_displayname(
+ user, None, display_name, by_admin=True
+ )
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(user_localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
return user_id
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
@@ -380,6 +419,39 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
+ def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+ other registration UI auth flows (e.g. validated 3pids)
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+ # that the displayname is correctly set by the masters erver
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_email": params.get("bind_email"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
+ @defer.inlineCallbacks
def _generate_user_id(self):
if self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 49ec2f48bc..f0dfcb9158 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -64,6 +64,7 @@ class RoomCreationHandler(BaseHandler):
"history_visibility": "shared",
"original_invitees_have_ops": False,
"guest_can_join": True,
+ "encryption_alg": "m.megolm.v1.aes-sha2",
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
@@ -71,6 +72,7 @@ class RoomCreationHandler(BaseHandler):
"history_visibility": "shared",
"original_invitees_have_ops": True,
"guest_can_join": True,
+ "encryption_alg": "m.megolm.v1.aes-sha2",
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.PUBLIC_CHAT: {
@@ -149,7 +151,9 @@ class RoomCreationHandler(BaseHandler):
return ret
@defer.inlineCallbacks
- def _upgrade_room(self, requester, old_room_id, new_version):
+ def _upgrade_room(
+ self, requester: Requester, old_room_id: str, new_version: RoomVersion
+ ):
user_id = requester.user.to_string()
# start by allocating a new room id
@@ -335,7 +339,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -353,7 +369,7 @@ class RoomCreationHandler(BaseHandler):
# If so, mark the new room as non-federatable as well
creation_content["m.federate"] = False
- initial_state = dict()
+ initial_state = {}
# Replicate relevant room events
types_to_copy = (
@@ -448,19 +464,21 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def _move_aliases_to_new_room(
- self, requester, old_room_id, new_room_id, old_room_state
+ self,
+ requester: Requester,
+ old_room_id: str,
+ new_room_id: str,
+ old_room_state: StateMap[str],
):
directory_handler = self.hs.get_handlers().directory_handler
aliases = yield self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
- canonical_alias = None
+ canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
- if canonical_alias_event:
- canonical_alias = canonical_alias_event.content.get("alias", "")
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
@@ -488,19 +506,6 @@ class RoomCreationHandler(BaseHandler):
if not removed_aliases:
return
- try:
- # this can fail if, for some reason, our user doesn't have perms to send
- # m.room.aliases events in the old room (note that we've already checked that
- # they have perms to send a tombstone event, so that's not terribly likely).
- #
- # If that happens, it's regrettable, but we should carry on: it's the same
- # as when you remove an alias from the directory normally - it just means that
- # the aliases event gets out of sync with the directory
- # (cf https://github.com/vector-im/riot-web/issues/2369)
- yield directory_handler.send_room_alias_update_event(requester, old_room_id)
- except AuthError as e:
- logger.warning("Failed to send updated alias event on old room: %s", e)
-
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
@@ -517,8 +522,10 @@ class RoomCreationHandler(BaseHandler):
# checking module decides it shouldn't, or similar.
logger.error("Error adding alias %s to new room: %s", alias, e)
+ # If a canonical alias event existed for the old room, fire a canonical
+ # alias event for the new room with a copy of the information.
try:
- if canonical_alias and (canonical_alias in removed_aliases):
+ if canonical_alias_event:
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
@@ -526,12 +533,10 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
- "content": {"alias": canonical_alias},
+ "content": canonical_alias_event.content,
},
ratelimit=False,
)
-
- yield directory_handler.send_room_alias_update_event(requester, new_room_id)
except SynapseError as e:
# again I'm not really expecting this to fail, but if it does, I'd rather
# we returned the new room to the client at this point.
@@ -587,8 +592,14 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -623,7 +634,6 @@ class RoomCreationHandler(BaseHandler):
else:
room_alias = None
- invite_list = config.get("invite", [])
for i in invite_list:
try:
uid = UserID.from_string(i)
@@ -645,8 +655,6 @@ class RoomCreationHandler(BaseHandler):
% (user_id,),
)
- invite_3pid_list = config.get("invite_3pid", [])
-
visibility = config.get("visibility", None)
is_public = visibility == "public"
@@ -735,6 +743,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -750,6 +759,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
id_access_token=id_access_token,
)
@@ -757,7 +767,6 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
- yield directory_handler.send_room_alias_update_event(requester, room_id)
return result
@@ -807,6 +816,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=False,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
@@ -877,6 +887,13 @@ class RoomCreationHandler(BaseHandler):
for (etype, state_key), content in initial_state.items():
yield send(etype=etype, state_key=state_key, content=content)
+ if "encryption_alg" in config:
+ yield send(
+ etype=EventTypes.Encryption,
+ state_key="",
+ content={"algorithm": config["encryption_alg"]},
+ )
+
@defer.inlineCallbacks
def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index c615206df1..0b7d3da680 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -216,15 +216,6 @@ class RoomListHandler(BaseHandler):
direction_is_forward=False,
).to_token()
- for room in results:
- # populate search result entries with additional fields, namely
- # 'aliases'
- room_id = room["room_id"]
-
- aliases = yield self.store.get_aliases_for_room(room_id)
- if aliases:
- room["aliases"] = aliases
-
response["chunk"] = results
response["total_room_count_estimate"] = yield self.store.count_public_rooms(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4260426369..decef944ff 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -24,13 +24,20 @@ from twisted.internet import defer
from synapse import types
from synapse.api.constants import EventTypes, Membership
+from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ HttpResponseException,
+ SynapseError,
+)
+from synapse.handlers.identity import LookupAlgorithm, create_id_access_token_header
+from synapse.http.client import SimpleHttpClient
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.types import Collection, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
-from ._base import BaseHandler
-
logger = logging.getLogger(__name__)
@@ -60,6 +67,7 @@ class RoomMemberHandler(object):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.identity_handler = hs.get_handlers().identity_handler
self.member_linearizer = Linearizer(name="member")
@@ -67,13 +75,10 @@ class RoomMemberHandler(object):
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
self._server_notices_mxid = self.config.server_notices_mxid
+ self.rewrite_identity_server_urls = self.config.rewrite_identity_server_urls
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
-
- # This is only used to get at ratelimit function, and
- # maybe_kick_guest_users. It's fine there are multiple of these as
- # it doesn't store state.
- self.base_handler = BaseHandler(hs)
+ self.ratelimiter = Ratelimiter()
@abc.abstractmethod
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
@@ -265,8 +270,31 @@ class RoomMemberHandler(object):
third_party_signed=None,
ratelimit=True,
content=None,
+ new_room=False,
require_consent=True,
):
+ """Update a users membership in a room
+
+ Args:
+ requester (Requester)
+ target (UserID)
+ room_id (str)
+ action (str): The "action" the requester is performing against the
+ target. One of join/leave/kick/ban/invite/unban.
+ txn_id (str|None): The transaction ID associated with the request,
+ or None not provided.
+ remote_room_hosts (list[str]|None): List of remote servers to try
+ and join via if server isn't already in the room.
+ third_party_signed (dict|None): The signed object for third party
+ invites.
+ ratelimit (bool): Whether to apply ratelimiting to this request.
+ content (dict|None): Fields to include in the new events content.
+ new_room (bool): Whether these membership changes are happening
+ as part of a room creation (e.g. initial joins and invites)
+
+ Returns:
+ Deferred[FrozenEvent]
+ """
key = (room_id,)
with (yield self.member_linearizer.queue(key)):
@@ -280,6 +308,7 @@ class RoomMemberHandler(object):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
@@ -297,6 +326,7 @@ class RoomMemberHandler(object):
third_party_signed=None,
ratelimit=True,
content=None,
+ new_room=False,
require_consent=True,
):
content_specified = bool(content)
@@ -361,8 +391,15 @@ class RoomMemberHandler(object):
)
block_invite = True
+ is_published = yield self.store.is_room_published(room_id)
+
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(),
+ target.to_string(),
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -434,8 +471,26 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
+
+ inviter = yield self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if not is_host_in_room:
- inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -706,6 +761,7 @@ class RoomMemberHandler(object):
id_server,
requester,
txn_id,
+ new_room=False,
id_access_token=None,
):
if self.config.block_non_admin_invites:
@@ -717,7 +773,23 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
- yield self.base_handler.ratelimit(requester)
+ self.ratelimiter.ratelimit(
+ requester.user.to_string(),
+ time_now_s=self.hs.clock.time(),
+ rate_hz=self.hs.config.rc_third_party_invite.per_second,
+ burst_count=self.hs.config.rc_third_party_invite.burst_count,
+ update=True,
+ )
+
+ can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
@@ -738,6 +810,19 @@ class RoomMemberHandler(object):
id_server, medium, address, id_access_token
)
+ is_published = yield self.store.is_room_published(room_id)
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
yield self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 110097eab9..ec1542d416 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -184,7 +184,7 @@ class SearchHandler(BaseHandler):
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
- room_ids = set(r.room_id for r in rooms)
+ room_ids = {r.room_id for r in rooms}
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
@@ -374,12 +374,12 @@ class SearchHandler(BaseHandler):
).to_string()
if include_profile:
- senders = set(
+ senders = {
ev.sender
for ev in itertools.chain(
res["events_before"], [event], res["events_after"]
)
- )
+ }
if res["events_after"]:
last_event_id = res["events_after"][-1].event_id
@@ -421,7 +421,7 @@ class SearchHandler(BaseHandler):
state_results = {}
if include_state:
- rooms = set(e.room_id for e in allowed_events)
+ rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index d90c9e0108..3f50d6de47 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,12 +31,15 @@ class SetPasswordHandler(BaseHandler):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
+ self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
+ self._password_policy_handler.validate_password(newpassword)
+
password_hash = yield self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4324bc702e..669dbc8a48 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -682,11 +682,9 @@ class SyncHandler(object):
# FIXME: order by stream ordering rather than as returned by SQL
if joined_user_ids or invited_user_ids:
- summary["m.heroes"] = sorted(
- [user_id for user_id in (joined_user_ids + invited_user_ids)]
- )[0:5]
+ summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5]
else:
- summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5]
+ summary["m.heroes"] = sorted(gone_user_ids)[0:5]
if not sync_config.filter_collection.lazy_load_members():
return summary
@@ -697,9 +695,9 @@ class SyncHandler(object):
# track which members the client should already know about via LL:
# Ones which are already in state...
- existing_members = set(
+ existing_members = {
user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member
- )
+ }
# ...or ones which are in the timeline...
for ev in batch.events:
@@ -773,10 +771,10 @@ class SyncHandler(object):
# We only request state for the members needed to display the
# timeline:
- members_to_fetch = set(
+ members_to_fetch = {
event.sender # FIXME: we also care about invite targets etc.
for event in batch.events
- )
+ }
if full_state:
# always make sure we LL ourselves so we know we're in the room
@@ -1993,10 +1991,10 @@ def _calculate_state(
)
}
- c_ids = set(e for e in itervalues(current))
- ts_ids = set(e for e in itervalues(timeline_start))
- p_ids = set(e for e in itervalues(previous))
- tc_ids = set(e for e in itervalues(timeline_contains))
+ c_ids = set(itervalues(current))
+ ts_ids = set(itervalues(timeline_start))
+ p_ids = set(itervalues(previous))
+ tc_ids = set(itervalues(timeline_contains))
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 5406618431..391bceb0c4 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -198,7 +198,7 @@ class TypingHandler(object):
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
- for domain in set(get_domain_from_id(u) for u in users):
+ for domain in {get_domain_from_id(u) for u in users}:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
@@ -231,7 +231,7 @@ class TypingHandler(object):
return
users = yield self.state.get_current_users_in_room(room_id)
- domains = set(get_domain_from_id(u) for u in users)
+ domains = {get_domain_from_id(u) for u in users}
if self.server_name in domains:
logger.info("Got typing update from %s: %r", user_id, content)
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 6073fc2725..0c2527bd86 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -148,7 +148,7 @@ def trace_function(f):
pathname=pathname,
lineno=lineno,
msg=msg,
- args=tuple(),
+ args=(),
exc_info=None,
)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 0b45e1f52a..0dba997a23 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -240,7 +240,7 @@ class BucketCollector(object):
res.append(["+Inf", sum(data.values())])
metric = HistogramMetricFamily(
- self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()])
+ self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items())
)
yield metric
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index c53d2a0d40..b65bcd8806 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -80,13 +80,13 @@ _background_process_db_sched_duration = Counter(
# map from description to a counter, so that we can name our logcontexts
# incrementally. (It actually duplicates _background_process_start_count, but
# it's much simpler to do so than to try to combine them.)
-_background_process_counts = dict() # type: dict[str, int]
+_background_process_counts = {} # type: dict[str, int]
# map from description to the currently running background processes.
#
# it's kept as a dict of sets rather than a big set so that we can keep track
# of process descriptions that no longer have any active processes.
-_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
+_background_processes = {} # type: dict[str, set[_BackgroundProcess]]
# A lock that covers the above dicts
_bg_metrics_lock = threading.Lock()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7d9f5a38d9..433ca2f416 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -400,11 +400,11 @@ class RulesForRoom(object):
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
- interested_in_user_ids = set(
+ interested_in_user_ids = {
user_id
for user_id, membership in itervalues(members)
if membership == Membership.JOIN
- )
+ }
logger.debug("Joined: %r", interested_in_user_ids)
@@ -412,9 +412,9 @@ class RulesForRoom(object):
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
)
- user_ids = set(
+ user_ids = {
uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
- )
+ }
logger.debug("With pushers: %r", user_ids)
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 8c818a86bf..ba4551d619 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -204,7 +204,7 @@ class EmailPusher(object):
yield self.send_notification(unprocessed, reason)
yield self.save_last_stream_ordering_and_success(
- max([ea["stream_ordering"] for ea in unprocessed])
+ max(ea["stream_ordering"] for ea in unprocessed)
)
# we update the throttle on all the possible unprocessed push actions
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index b13b646bfd..4ccaf178ce 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -526,12 +526,10 @@ class Mailer(object):
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
sender_ids = list(
- set(
- [
- notif_events[n["event_id"]].sender
- for n in notifs_by_room[room_id]
- ]
- )
+ {
+ notif_events[n["event_id"]].sender
+ for n in notifs_by_room[room_id]
+ }
)
member_events = yield self.store.get_events(
@@ -558,12 +556,10 @@ class Mailer(object):
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
sender_ids = list(
- set(
- [
- notif_events[n["event_id"]].sender
- for n in notifs_by_room[reason["room_id"]]
- ]
- )
+ {
+ notif_events[n["event_id"]].sender
+ for n in notifs_by_room[reason["room_id"]]
+ }
)
member_events = yield self.store.get_events(
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 16a7e8e31d..0644a13cfc 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -18,6 +18,8 @@ import re
from twisted.internet import defer
+from synapse.api.constants import EventTypes
+
logger = logging.getLogger(__name__)
# intentionally looser than what aliases we allow to be registered since
@@ -50,17 +52,17 @@ def calculate_room_name(
(string or None) A human readable name for the room.
"""
# does it have a name?
- if ("m.room.name", "") in room_state_ids:
+ if (EventTypes.Name, "") in room_state_ids:
m_room_name = yield store.get_event(
- room_state_ids[("m.room.name", "")], allow_none=True
+ room_state_ids[(EventTypes.Name, "")], allow_none=True
)
if m_room_name and m_room_name.content and m_room_name.content["name"]:
return m_room_name.content["name"]
# does it have a canonical alias?
- if ("m.room.canonical_alias", "") in room_state_ids:
+ if (EventTypes.CanonicalAlias, "") in room_state_ids:
canon_alias = yield store.get_event(
- room_state_ids[("m.room.canonical_alias", "")], allow_none=True
+ room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True
)
if (
canon_alias
@@ -74,32 +76,22 @@ def calculate_room_name(
# for an event type, so rearrange the data structure
room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
- # right then, any aliases at all?
- if "m.room.aliases" in room_state_bytype_ids:
- m_room_aliases = room_state_bytype_ids["m.room.aliases"]
- for alias_id in m_room_aliases.values():
- alias_event = yield store.get_event(alias_id, allow_none=True)
- if alias_event and alias_event.content.get("aliases"):
- the_aliases = alias_event.content["aliases"]
- if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
- return the_aliases[0]
-
if not fallback_to_members:
return None
my_member_event = None
- if ("m.room.member", user_id) in room_state_ids:
+ if (EventTypes.Member, user_id) in room_state_ids:
my_member_event = yield store.get_event(
- room_state_ids[("m.room.member", user_id)], allow_none=True
+ room_state_ids[(EventTypes.Member, user_id)], allow_none=True
)
if (
my_member_event is not None
and my_member_event.content["membership"] == "invite"
):
- if ("m.room.member", my_member_event.sender) in room_state_ids:
+ if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = yield store.get_event(
- room_state_ids[("m.room.member", my_member_event.sender)],
+ room_state_ids[(EventTypes.Member, my_member_event.sender)],
allow_none=True,
)
if inviter_member_event:
@@ -114,9 +106,9 @@ def calculate_room_name(
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
- if "m.room.member" in room_state_bytype_ids:
+ if EventTypes.Member in room_state_bytype_ids:
member_events = yield store.get_events(
- list(room_state_bytype_ids["m.room.member"].values())
+ list(room_state_bytype_ids[EventTypes.Member].values())
)
all_members = [
ev
@@ -138,9 +130,9 @@ def calculate_room_name(
# self-chat, peeked room with 1 participant,
# or inbound invite, or outbound 3PID invite.
if all_members[0].sender == user_id:
- if "m.room.third_party_invite" in room_state_bytype_ids:
+ if EventTypes.ThirdPartyInvite in room_state_bytype_ids:
third_party_invites = room_state_bytype_ids[
- "m.room.third_party_invite"
+ EventTypes.ThirdPartyInvite
].values()
if len(third_party_invites) > 0:
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index b9dca5bc63..01789a9fb4 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -191,7 +191,7 @@ class PusherPool:
min_stream_id - 1, max_stream_id
)
# This returns a tuple, user_id is at index 3
- users_affected = set([r[3] for r in updated_receipts])
+ users_affected = {r[3] for r in updated_receipts}
for u in users_affected:
if u in self.pushers:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 444eb7b7f4..1be1ccbdf3 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -44,7 +44,7 @@ class ReplicationEndpoint(object):
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
- (with an `/:txn_id` prefix for cached requests.), where NAME is a name,
+ (with a `/:txn_id` suffix for cached requests), where NAME is a name,
PATH_ARGS are a tuple of parameters to be encoded in the URL.
For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 49a3251372..8794720101 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -17,6 +17,7 @@ import logging
from twisted.internet import defer
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import event_type_from_format_version
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
@@ -211,7 +212,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
Request format:
- POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
+ POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id
{}
"""
@@ -238,8 +239,41 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return 200, {}
+class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
+ """Called to clean up any data in DB for a given room, ready for the
+ server to join the room.
+
+ Request format:
+
+ POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
+
+ {
+ "room_version": "1",
+ }
+ """
+
+ NAME = "store_room_on_invite"
+ PATH_ARGS = ("room_id",)
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+
+ @staticmethod
+ def _serialize_payload(room_id, room_version):
+ return {"room_version": room_version.identifier}
+
+ async def _handle_request(self, request, room_id):
+ content = parse_json_object_from_request(request)
+ room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
+ await self.store.maybe_store_room_on_invite(room_id, room_version)
+ return 200, {}
+
+
def register_servlets(hs, http_server):
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
+ ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 3aa6cb8b96..e73342c657 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -32,6 +32,7 @@ from synapse.storage.data_stores.main.state import StateGroupWorkerStore
from synapse.storage.data_stores.main.stream import StreamWorkerStore
from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
from synapse.storage.database import Database
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -68,6 +69,21 @@ class SlavedEventStore(
super(SlavedEventStore, self).__init__(database, db_conn, hs)
+ events_max = self._stream_id_gen.get_current_token()
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ db_conn,
+ "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache",
+ min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
+
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
@@ -120,6 +136,10 @@ class SlavedEventStore(
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ row.data.room_id, token
+ )
+
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
(data.state_key,)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce60ae2e07..ce9d1fae12 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -323,7 +323,11 @@ class ReplicationStreamer(object):
# We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection.
- self.presence_handler.update_external_syncs_clear(connection.conn_id)
+ run_as_background_process(
+ "update_external_syncs_clear",
+ self.presence_handler.update_external_syncs_clear,
+ connection.conn_id,
+ )
def _batch_updates(updates):
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index a8d568b14a..208e8a667b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -24,7 +24,7 @@ import attr
logger = logging.getLogger(__name__)
-MAX_EVENTS_BEHIND = 10000
+MAX_EVENTS_BEHIND = 500000
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 4a1fc2ec2b..14eca70ba4 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
keys,
notifications,
openid,
+ password_policy,
read_marker,
receipts,
register,
@@ -117,6 +118,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
# moving to /_synapse/admin
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index 459482eb6d..a96f75ce26 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -29,7 +29,7 @@ def historical_admin_path_patterns(path_regex):
Note that this should only be used for existing endpoints: new ones should just
register for the /_synapse/admin path.
"""
- return list(
+ return [
re.compile(prefix + path_regex)
for prefix in (
"^/_synapse/admin/v1",
@@ -37,7 +37,7 @@ def historical_admin_path_patterns(path_regex):
"^/_matrix/client/unstable/admin",
"^/_matrix/client/r0/admin",
)
- )
+ ]
def admin_patterns(path_regex: str):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 064908fbb0..80f959248d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -226,13 +226,16 @@ class UserRestServletV2(RestServlet):
)
if "deactivated" in body:
- deactivate = bool(body["deactivated"])
+ deactivate = body["deactivated"]
+ if not isinstance(deactivate, bool):
+ raise SynapseError(
+ 400, "'deactivated' parameter is not of type boolean"
+ )
+
if deactivate and not user["deactivated"]:
- result = await self.deactivate_account_handler.deactivate_account(
+ await self.deactivate_account_handler.deactivate_account(
target_user.to_string(), False
)
- if not result:
- raise SynapseError(500, "Could not deactivate user")
user = await self.admin_handler.get_user(target_user)
return 200, user
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 1294e080dc..2c99536678 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -599,6 +599,7 @@ class SSOAuthHandler(object):
redirect_url = self._add_login_token_to_redirect_url(
client_redirect_url, login_token
)
+ # Load page
request.redirect(redirect_url)
finish_request(request)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..165313b572 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -63,11 +65,27 @@ class ProfileDisplaynameRestServlet(RestServlet):
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -76,6 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -114,11 +133,27 @@ class ProfileAvatarURLRestServlet(RestServlet):
user, requester, new_avatar_url, is_admin
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 4f74600239..9fd4908136 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -49,7 +49,7 @@ class PushRuleRestServlet(RestServlet):
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
- spec = _rule_spec_from_path([x for x in path.split("/")])
+ spec = _rule_spec_from_path(path.split("/"))
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
@@ -110,7 +110,7 @@ class PushRuleRestServlet(RestServlet):
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
- spec = _rule_spec_from_path([x for x in path.split("/")])
+ spec = _rule_spec_from_path(path.split("/"))
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -138,7 +138,7 @@ class PushRuleRestServlet(RestServlet):
rules = format_push_rules_for_user(requester.user, rules)
- path = [x for x in path.split("/")][1:]
+ path = path.split("/")[1:]
if path == []:
# we're a reference impl: pedantry is our job.
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 6f6b7aed6e..550a2f1b44 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -54,9 +54,9 @@ class PushersRestServlet(RestServlet):
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- filtered_pushers = list(
+ filtered_pushers = [
{k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
- )
+ ]
return 200, {"pushers": filtered_pushers}
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 64f51406fb..52a00a5c7c 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -727,7 +727,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index dc837d6c75..bd1c0efbcb 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,9 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
from six.moves import http_client
+from twisted.internet import defer
+
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour
@@ -28,9 +31,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.types import UserID
from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
-from synapse.util.stringutils import assert_valid_client_secret
+from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -91,7 +95,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -216,6 +220,7 @@ class PasswordRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -233,9 +238,13 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- params = await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
- )
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ params = await self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
+ )
user_id = requester.user.to_string()
else:
requester = None
@@ -268,11 +277,29 @@ class PasswordRestServlet(RestServlet):
await self._set_password_handler.set_password(user_id, new_password, requester)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_password(params, shadow_user.to_string())
+
return 200, {}
def on_OPTIONS(self, _):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -363,13 +390,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existing_user_id = await self.store.get_user_id_by_threepid(
"email", body["email"]
)
@@ -428,13 +457,15 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
@@ -589,7 +620,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -599,10 +631,33 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
async def on_POST(self, request):
+ if self.hs.config.disable_3pid_changes:
+ raise SynapseError(400, "3PID changes disabled on this server")
+
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ # skip validation if this is a shadow 3PID from an AS
+ if requester.app_service:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
+
+ await self.auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+ )
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return 200, {}
+
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
if threepid_creds is None:
raise SynapseError(
@@ -624,12 +679,36 @@ class ThreepidRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$", releases=(), unstable=True)
@@ -666,6 +745,16 @@ class ThreepidAddRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
@@ -701,6 +790,29 @@ class ThreepidBindRestServlet(RestServlet):
return 200, {}
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$", releases=(), unstable=True)
@@ -738,10 +850,15 @@ class ThreepidDeleteRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
+ self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
+ if self.hs.config.disable_3pid_changes:
+ raise SynapseError(400, "3PID changes disabled on this server")
+
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
@@ -759,6 +876,12 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
@@ -766,6 +889,77 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+ @defer.inlineCallbacks
+ def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+
+ defer.returnValue((200, ret))
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidBulkLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.proxy_bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ defer.returnValue((200, ret))
+
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
@@ -794,4 +988,6 @@ def register_servlets(hs, http_server):
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 64eb7fec3b..17495f020b 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -15,8 +15,11 @@
import logging
+from twisted.internet import defer
+
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -38,6 +41,7 @@ class AccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
+ self._profile_handler = hs.get_profile_handler()
async def on_PUT(self, request, user_id, account_data_type):
requester = await self.auth.get_user_by_req(request)
@@ -46,6 +50,11 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ await self._profile_handler.set_active(user, not hide_profile, True)
+
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
new file mode 100644
index 0000000000..968403cca4
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import RestServlet
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyServlet(RestServlet):
+ PATTERNS = client_patterns("/password_policy$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(PasswordPolicyServlet, self).__init__()
+
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ def on_GET(self, request):
+ if not self.enabled or not self.policy:
+ return (200, {})
+
+ policy = {}
+
+ for param in [
+ "minimum_length",
+ "require_digit",
+ "require_symbol",
+ "require_lowercase",
+ "require_uppercase",
+ ]:
+ if param in self.policy:
+ policy["m.%s" % param] = self.policy[param]
+
+ return (200, policy)
+
+
+def register_servlets(hs, http_server):
+ PasswordPolicyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a09189b1b4..7406c13fb4 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
import hmac
import logging
+import re
from typing import List, Union
from six import string_types
@@ -123,10 +125,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
@@ -190,7 +192,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -373,6 +377,7 @@ class RegisterRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
+ self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_flows = _calculate_registration_flows(
@@ -414,12 +419,15 @@ class RegisterRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
+ desired_password = None
if "password" in body:
if (
not isinstance(body["password"], string_types)
or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(body["password"])
+ desired_password = body["password"]
desired_username = None
if "username" in body:
@@ -430,6 +438,8 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Invalid username")
desired_username = body["username"]
+ desired_display_name = body.get("display_name")
+
appservice = None
if self.auth.has_access_token(request):
appservice = await self.auth.get_appservice_by_req(request)
@@ -453,7 +463,11 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, string_types):
result = await self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username,
+ desired_password,
+ desired_display_name,
+ access_token,
+ body,
)
return 200, result # we throw for non 200 responses
@@ -514,7 +528,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (await check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -522,6 +536,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existingUid = await self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case get ready and gen the
+ # desired_username
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
+ while True:
+ try:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+ # something else went wrong.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = self._map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -532,9 +620,16 @@ class RegisterRestServlet(RestServlet):
# NB: This may be from the auth handler and NOT from the POST
assert_params_in_dict(params, ["password"])
- desired_username = params.get("username", None)
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ else:
+ # we keep the original desired_username derived from the 3pid above
+ pass
+
guest_access_token = params.get("guest_access_token", None)
- new_password = params.get("password", None)
+
+ # XXX: don't we need to validate these for length etc like we did on
+ # the ones from the JSON body earlier on in the method?
if desired_username is not None:
desired_username = desired_username.lower()
@@ -567,8 +662,9 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
- password=new_password,
+ password=params.get("password", None),
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
)
@@ -580,6 +676,14 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
+ if self.hs.config.shadow_server:
+ await self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
@@ -604,11 +708,30 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- async def _do_appservice_registration(self, username, as_token, body):
+ async def _do_appservice_registration(
+ self, username, password, display_name, as_token, body
+ ):
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
user_id = await self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password, display_name
)
- return await self._create_registration_details(user_id, body)
+ result = await self._create_registration_details(user_id, body)
+
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ await self._register_email_threepid(
+ user_id, threepid, result["access_token"], body.get("bind_email")
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ await self._register_msisdn_threepid(
+ user_id, threepid, result["access_token"], body.get("bind_msisdn")
+ )
+
+ return result
async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
@@ -659,6 +782,60 @@ class RegisterRestServlet(RestServlet):
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+ # Split the name by whitespace then hyphens, capitalizing each part then
+ # joining it back together.
+ capatilized_name = " ".join(
+ "-".join(part.capitalize() for part in space_part.split("-"))
+ for space_part in name.split()
+ )
+ return capatilized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
+ """
+ # Split the part before and after the @ in the email.
+ # Replace all . with spaces in the first part
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+ # Is this is a ...gouv.fr address, set the org to whatever is before
+ # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their
+ # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index d8292ce29f..8fa68dd37f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -72,7 +72,7 @@ class SyncRestServlet(RestServlet):
"""
PATTERNS = client_patterns("/sync$")
- ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
+ ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..faf9dbdea4 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -15,8 +15,13 @@
import logging
+from signedjson.sign import sign_json
+
+from twisted.internet import defer
+
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -35,6 +40,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
"""Searches for users in directory
@@ -61,6 +67,16 @@ class UserDirectorySearchRestServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = await self.http_client.post_json_get_json(url, signed_body)
+ defer.returnValue((200, resp))
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,5 +92,87 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
+class UserInfoServlet(RestServlet):
+ """
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.notifier = hs.get_notifier()
+ self.clock = hs.get_clock()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ yield self.auth.get_user_by_req(request, allow_guest=False)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = yield self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ defer.returnValue((200, res))
+
+ res = yield self._get_user_info(user_id)
+ defer.returnValue((200, res))
+
+ @defer.inlineCallbacks
+ def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ res = yield self._get_user_info(user_id)
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _get_user_info(self, user_id):
+ """Retrieve information about a given user
+
+ Args:
+ user_id (str): The User ID of a given user on this homeserver
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ # Check whether user is deactivated
+ is_deactivated = yield self.store.get_user_deactivated_status(user_id)
+
+ # Check whether user is expired
+ expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
+ is_expired = (
+ expiration_ts is not None and self.clock.time_msec() >= expiration_ts
+ )
+
+ res = {"expired": is_expired, "deactivated": is_deactivated}
+ defer.returnValue(res)
+
+
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9d6813a047..4b6d030a57 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -149,7 +149,7 @@ class RemoteKey(DirectServeResource):
time_now_ms = self.clock.time_msec()
- cache_misses = dict() # type: Dict[str, Set[str]]
+ cache_misses = {} # type: Dict[str, Set[str]]
for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 65bbf00073..ba28dd089d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -135,27 +135,25 @@ def add_file_headers(request, media_type, file_size, upload_name):
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.
-_FILENAME_SEPARATOR_CHARS = set(
- (
- "(",
- ")",
- "<",
- ">",
- "@",
- ",",
- ";",
- ":",
- "\\",
- '"',
- "/",
- "[",
- "]",
- "?",
- "=",
- "{",
- "}",
- )
-)
+_FILENAME_SEPARATOR_CHARS = {
+ "(",
+ ")",
+ "<",
+ ">",
+ "@",
+ ",",
+ ";",
+ ":",
+ "\\",
+ '"',
+ "/",
+ "[",
+ "]",
+ "?",
+ "=",
+ "{",
+ "}",
+}
def _can_encode_filename_as_token(x):
diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
new file mode 100644
index 0000000000..6f2a1931c5
--- /dev/null
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.config._base import ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DomainRuleChecker(object):
+ """
+ A re-implementation of the SpamChecker that prevents users in one domain from
+ inviting users in other domains to rooms, based on a configuration.
+
+ Takes a config in the format:
+
+ spam_checker:
+ module: "rulecheck.DomainRuleChecker"
+ config:
+ domain_mapping:
+ "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
+ "other_inviter_domain": [ "invitee_domain_permitted" ]
+ default: False
+
+ # Only let local users join rooms if they were explicitly invited.
+ can_only_join_rooms_with_invite: false
+
+ # Only let local users create rooms if they are inviting only one
+ # other user, and that user matches the rules above.
+ can_only_create_one_to_one_rooms: false
+
+ # Only let local users invite during room creation, regardless of the
+ # domain mapping rules above.
+ can_only_invite_during_room_creation: false
+
+ # Prevent local users from inviting users from certain domains to
+ # rooms published in the room directory.
+ domains_prevented_from_being_invited_to_published_rooms: []
+
+ # Allow third party invites
+ can_invite_by_third_party_id: true
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config):
+ self.domain_mapping = config["domain_mapping"] or {}
+ self.default = config["default"]
+
+ self.can_only_join_rooms_with_invite = config.get(
+ "can_only_join_rooms_with_invite", False
+ )
+ self.can_only_create_one_to_one_rooms = config.get(
+ "can_only_create_one_to_one_rooms", False
+ )
+ self.can_only_invite_during_room_creation = config.get(
+ "can_only_invite_during_room_creation", False
+ )
+ self.can_invite_by_third_party_id = config.get(
+ "can_invite_by_third_party_id", True
+ )
+ self.domains_prevented_from_being_invited_to_published_rooms = config.get(
+ "domains_prevented_from_being_invited_to_published_rooms", []
+ )
+
+ def check_event_for_spam(self, event):
+ """Implements synapse.events.SpamChecker.check_event_for_spam
+ """
+ return False
+
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
+ """Implements synapse.events.SpamChecker.user_may_invite
+ """
+ if self.can_only_invite_during_room_creation and not new_room:
+ return False
+
+ if not self.can_invite_by_third_party_id and third_party_invite:
+ return False
+
+ # This is a third party invite (without a bound mxid), so unless we have
+ # banned all third party invites (above) we allow it.
+ if not invitee_userid:
+ return True
+
+ inviter_domain = self._get_domain_from_id(inviter_userid)
+ invitee_domain = self._get_domain_from_id(invitee_userid)
+
+ if inviter_domain not in self.domain_mapping:
+ return self.default
+
+ if (
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
+ ):
+ return False
+
+ return invitee_domain in self.domain_mapping[inviter_domain]
+
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
+ """Implements synapse.events.SpamChecker.user_may_create_room
+ """
+
+ if cloning:
+ return True
+
+ if not self.can_invite_by_third_party_id and third_party_invite_list:
+ return False
+
+ number_of_invites = len(invite_list) + len(third_party_invite_list)
+
+ if self.can_only_create_one_to_one_rooms and number_of_invites != 1:
+ return False
+
+ return True
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias
+ """
+ return True
+
+ def user_may_publish_room(self, userid, room_id):
+ """Implements synapse.events.SpamChecker.user_may_publish_room
+ """
+ return True
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Implements synapse.events.SpamChecker.user_may_join_room
+ """
+ if self.can_only_join_rooms_with_invite and not is_invited:
+ return False
+
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ """Implements synapse.events.SpamChecker.parse_config
+ """
+ if "default" in config:
+ return config
+ else:
+ raise ConfigError("No default set for spam_config DomainRuleChecker")
+
+ @staticmethod
+ def _get_domain_from_id(mxid):
+ """Parses a string and returns the domain part of the mxid.
+
+ Args:
+ mxid (str): a valid mxid
+
+ Returns:
+ str: the domain part of the mxid
+
+ """
+ idx = mxid.find(":")
+ if idx == -1:
+ raise Exception("Invalid ID: %r" % (mxid,))
+ return mxid[idx + 1 :]
diff --git a/synapse/server.py b/synapse/server.py
index fd2f69e928..884028ca77 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -66,6 +66,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerH
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
from synapse.handlers.pagination import PaginationHandler
+from synapse.handlers.password_policy import PasswordPolicyHandler
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
@@ -167,6 +168,7 @@ class HomeServer(object):
"event_builder_factory",
"filtering",
"http_client_context_factory",
+ "proxied_http_client",
"simple_http_client",
"proxied_http_client",
"media_repository",
@@ -199,6 +201,7 @@ class HomeServer(object):
"saml_handler",
"event_client_serializer",
"storage",
+ "password_policy_handler",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -536,6 +539,9 @@ class HomeServer(object):
def build_storage(self) -> Storage:
return Storage(self, self.datastores)
+ def build_password_policy_handler(self):
+ return PasswordPolicyHandler(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 40eabfe5d9..3844f0e12f 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -3,6 +3,7 @@ import twisted.internet
import synapse.api.auth
import synapse.config.homeserver
import synapse.crypto.keyring
+import synapse.federation.federation_server
import synapse.federation.sender
import synapse.federation.transport.client
import synapse.handlers
@@ -107,5 +108,9 @@ class HomeServer(object):
self,
) -> synapse.replication.tcp.client.ReplicationClientHandler:
pass
+ def get_federation_registry(
+ self,
+ ) -> synapse.federation.federation_server.FederationHandlerRegistry:
+ pass
def is_mine_id(self, domain_id: str) -> bool:
pass
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdd6bef6b4..df7a4f6a89 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Set
from six import iteritems, itervalues
@@ -662,7 +662,7 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
- def get_auth_chain(self, event_ids):
+ def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -674,11 +674,16 @@ class StateResolutionStore(object):
presence of rejected events
Args:
- event_ids (list): The event IDs of the events to fetch the auth
- chain for. Must be state events.
+ event_ids: The event IDs of the events to fetch the auth chain for.
+ Must be state events.
+ ignore_events: Set of events to exclude from the returned auth
+ chain.
+
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
"""
- return self.store.get_auth_chain_ids(event_ids, include_given=True)
+ return self.store.get_auth_chain_ids(
+ event_ids, include_given=True, ignore_events=ignore_events,
+ )
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 24b7c0faef..9bf98d06f2 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -69,9 +69,9 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state = _seperate(state_sets)
- needed_events = set(
+ needed_events = {
event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
- )
+ }
needed_event_count = len(needed_events)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
@@ -261,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events):
def _resolve_auth_events(events, auth_events):
- reverse = [i for i in reversed(_ordered_events(events))]
+ reverse = list(reversed(_ordered_events(events)))
- auth_keys = set(
+ auth_keys = {
key for event in events for key in event_auth.auth_types_for_event(event)
- )
+ }
new_auth_events = {}
for key in auth_keys:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 531018c6a5..0ffe6d8c14 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -105,7 +105,7 @@ def resolve_events_with_store(
% (room_id, event.event_id, event.room_id,)
)
- full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
+ full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
@@ -233,7 +233,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
auth_sets = []
for state_set in state_sets:
- auth_ids = set(
+ auth_ids = {
eid
for key, eid in iteritems(state_set)
if (
@@ -246,9 +246,9 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
)
)
and eid not in common
- )
+ }
- auth_chain = yield state_res_store.get_auth_chain(auth_ids)
+ auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
auth_ids.update(auth_chain)
auth_sets.append(auth_ids)
@@ -275,7 +275,7 @@ def _seperate(state_sets):
conflicted_state = {}
for key in set(itertools.chain.from_iterable(state_sets)):
- event_ids = set(state_set.get(key) for state_set in state_sets)
+ event_ids = {state_set.get(key) for state_set in state_sets}
if len(event_ids) == 1:
unconflicted_state[key] = event_ids.pop()
else:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index da3b99f93d..13de5f1f62 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -56,7 +56,7 @@ class SQLBaseStore(metaclass=ABCMeta):
members_changed (iterable[str]): The user_ids of members that have
changed
"""
- for host in set(get_domain_from_id(u) for u in members_changed):
+ for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index bd547f35cf..eb1a7e5002 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -189,7 +189,7 @@ class BackgroundUpdater(object):
keyvalues=None,
retcols=("update_name", "depends_on"),
)
- in_flight = set(update["update_name"] for update in updates)
+ in_flight = {update["update_name"] for update in updates}
for update in updates:
if update["depends_on"] not in in_flight:
self._background_update_queue.append(update["update_name"])
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 2700cca822..acca079f23 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -20,6 +20,7 @@ import logging
import time
from synapse.api.constants import PresenceState
+from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@@ -117,16 +118,6 @@ class DataStore(
self._clock = hs.get_clock()
self.database_engine = database.engine
- all_users_native = are_all_users_on_domain(
- db_conn.cursor(), database.engine, hs.hostname
- )
- if not all_users_native:
- raise Exception(
- "Found users in database not native to %s!\n"
- "You cannot changed a synapse server_name after it's been configured"
- % (hs.hostname,)
- )
-
self._stream_id_gen = StreamIdGenerator(
db_conn,
"events",
@@ -567,13 +558,26 @@ class DataStore(
)
-def are_all_users_on_domain(txn, database_engine, domain):
+def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+ """Called before upgrading an existing database to check that it is broadly sane
+ compared with the configuration.
+ """
+ domain = config.server_name
+
sql = database_engine.convert_param_style(
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
)
pat = "%:" + domain
- txn.execute(sql, (pat,))
- num_not_matching = txn.fetchall()[0][0]
+ cur.execute(sql, (pat,))
+ num_not_matching = cur.fetchall()[0][0]
if num_not_matching == 0:
- return True
- return False
+ return
+
+ raise Exception(
+ "Found users in database not native to %s!\n"
+ "You cannot changed a synapse server_name after it's been configured"
+ % (domain,)
+ )
+
+
+__all__ = ["DataStore", "check_database_before_upgrade"]
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index b2f39649fd..9c52aa5340 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -35,7 +35,7 @@ def _make_exclusive_regex(services_cache):
exclusive_user_regexes = [
regex.pattern
for service in services_cache
- for regex in service.get_exlusive_user_regexes()
+ for regex in service.get_exclusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
@@ -135,7 +135,7 @@ class ApplicationServiceTransactionWorkerStore(
may be empty.
"""
results = yield self.db.simple_select_list(
- "application_services_state", dict(state=state), ["as_id"]
+ "application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@@ -158,7 +158,7 @@ class ApplicationServiceTransactionWorkerStore(
"""
result = yield self.db.simple_select_one(
"application_services_state",
- dict(as_id=service.id),
+ {"as_id": service.id},
["state"],
allow_none=True,
desc="get_appservice_state",
@@ -177,7 +177,7 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves when the state was set successfully.
"""
return self.db.simple_upsert(
- "application_services_state", dict(as_id=service.id), dict(state=state)
+ "application_services_state", {"as_id": service.id}, {"state": state}
)
def create_appservice_txn(self, service, events):
@@ -253,13 +253,15 @@ class ApplicationServiceTransactionWorkerStore(
self.db.simple_upsert_txn(
txn,
"application_services_state",
- dict(as_id=service.id),
- dict(last_txn=txn_id),
+ {"as_id": service.id},
+ {"last_txn": txn_id},
)
# Delete txn
self.db.simple_delete_txn(
- txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
+ txn,
+ "application_services_txns",
+ {"txn_id": txn_id, "as_id": service.id},
)
return self.db.runInteraction(
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 13f4c9c72e..e1ccb27142 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -530,7 +530,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
- return list(
+ return [
{
"access_token": access_token,
"ip": ip,
@@ -538,7 +538,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
- )
+ ]
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index b7617efb80..d55733a4cd 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -137,7 +137,7 @@ class DeviceWorkerStore(SQLBaseStore):
# get the cross-signing keys of the users in the list, so that we can
# determine which of the device changes were cross-signing keys
- users = set(r[0] for r in updates)
+ users = {r[0] for r in updates}
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
@@ -446,7 +446,7 @@ class DeviceWorkerStore(SQLBaseStore):
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
- user_ids = set(user_id for user_id, _ in query_list)
+ user_ids = {user_id for user_id, _ in query_list}
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
# We go and check if any of the users need to have their device lists
@@ -454,10 +454,9 @@ class DeviceWorkerStore(SQLBaseStore):
users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
user_ids
)
- user_ids_in_cache = (
- set(user_id for user_id, stream_id in user_map.items() if stream_id)
- - users_needing_resync
- )
+ user_ids_in_cache = {
+ user_id for user_id, stream_id in user_map.items() if stream_id
+ } - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
@@ -604,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
rows = yield self.db.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
- return set(user for row in rows for user in json.loads(row[0]))
+ return {user for row in rows for user in json.loads(row[0])}
else:
return set()
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index e551606f9d..001a53f9b4 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -680,11 +680,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
'user_signing' for a user-signing key
key (dict): the key data
"""
- # the cross-signing keys need to occupy the same namespace as devices,
- # since signatures are identified by device ID. So add an entry to the
- # device table to make sure that we don't have a collision with device
- # IDs
-
# the 'key' dict will look something like:
# {
# "user_id": "@alice:example.com",
@@ -701,16 +696,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# The "keys" property must only have one entry, which will be the public
# key, so we just grab the first value in there
pubkey = next(iter(key["keys"].values()))
- self.db.simple_insert_txn(
- txn,
- "devices",
- values={
- "user_id": user_id,
- "device_id": pubkey,
- "display_name": key_type + " signing key",
- "hidden": True,
- },
- )
+
+ # The cross-signing keys need to occupy the same namespace as devices,
+ # since signatures are identified by device ID. So add an entry to the
+ # device table to make sure that we don't have a collision with device
+ # IDs.
+ # We only need to do this for local users, since remote servers should be
+ # responsible for checking this for their own users.
+ if self.hs.is_mine_id(user_id):
+ self.db.simple_insert_txn(
+ txn,
+ "devices",
+ values={
+ "user_id": user_id,
+ "device_id": pubkey,
+ "display_name": key_type + " signing key",
+ "hidden": True,
+ },
+ )
# and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id:
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 60c67457b4..49a7b8b433 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,8 +14,8 @@
# limitations under the License.
import itertools
import logging
+from typing import List, Optional, Set
-from six.moves import range
from six.moves.queue import Empty, PriorityQueue
from twisted.internet import defer
@@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.descriptors import cached
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -46,21 +47,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
- def get_auth_chain_ids(self, event_ids, include_given=False):
+ def get_auth_chain_ids(
+ self,
+ event_ids: List[str],
+ include_given: bool = False,
+ ignore_events: Optional[Set[str]] = None,
+ ):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
+ ignore_events: Set of events to exclude from the returned auth
+ chain. This is useful if the caller will just discard the
+ given events anyway, and saves us from figuring out their auth
+ chains if not required.
Returns:
list of event_ids
"""
return self.db.runInteraction(
- "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
+ "get_auth_chain_ids",
+ self._get_auth_chain_ids_txn,
+ event_ids,
+ include_given,
+ ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
+ if ignore_events is None:
+ ignore_events = set()
+
if include_given:
results = set(event_ids)
else:
@@ -71,15 +88,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
front = set(event_ids)
while front:
new_front = set()
- front_list = list(front)
- chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
- for chunk in chunks:
+ for chunk in batch_iter(front, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", chunk
)
- txn.execute(base_sql + clause, list(args))
- new_front.update([r[0] for r in txn])
+ txn.execute(base_sql + clause, args)
+ new_front.update(r[0] for r in txn)
+ new_front -= ignore_events
new_front -= results
front = new_front
@@ -410,7 +426,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
query, (room_id, event_id, False, limit - len(event_results))
)
- new_results = set(t[0] for t in txn) - seen_events
+ new_results = {t[0] for t in txn} - seen_events
new_front |= new_results
seen_events |= new_results
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index c9d0d68c3a..8ae23df00a 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -145,7 +145,7 @@ class EventsStore(
return txn.fetchall()
res = yield self.db.runInteraction("read_forward_extremities", fetch)
- self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
+ self._current_forward_extremities_amount = c_counter([x[0] for x in res])
@_retry_on_integrity_error
@defer.inlineCallbacks
@@ -598,11 +598,11 @@ class EventsStore(
# We find out which membership events we may have deleted
# and which we have added, then we invlidate the caches for all
# those users.
- members_changed = set(
+ members_changed = {
state_key
for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
- )
+ }
for member in members_changed:
txn.call_after(
@@ -1615,7 +1615,7 @@ class EventsStore(
"""
)
- referenced_state_groups = set(sg for sg, in txn)
+ referenced_state_groups = {sg for sg, in txn}
logger.info(
"[purge] found %i referenced state groups", len(referenced_state_groups)
)
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 5177b71016..f54c8b1ee0 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -402,7 +402,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
keyvalues={},
retcols=("room_id",),
)
- room_ids = set(row["room_id"] for row in rows)
+ room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 7251e819f5..47a3a26072 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -494,9 +494,9 @@ class EventsWorkerStore(SQLBaseStore):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
- events_to_fetch = set(
+ events_to_fetch = {
event_id for events, _ in event_list for event_id in events
- )
+ }
row_dict = self.db.new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
@@ -804,7 +804,7 @@ class EventsWorkerStore(SQLBaseStore):
desc="have_events_in_timeline",
)
- return set(r["event_id"] for r in rows)
+ return {r["event_id"] for r in rows}
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index 2b52cf9c1a..3dc4451447 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +17,13 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
+
+from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.roommember import ProfileInfo
+BATCH_SIZE = 100
+
class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
@@ -57,6 +62,54 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.db.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.db.runInteraction("get_latest_profile_replication_batch_number", f)
+
+ def get_profile_batch(self, batchnum):
+ return self.db.simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ def assign_profile_batch(self):
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return self.db.runInteraction("assign_profile_batch", f)
+
+ def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.db.runInteraction("get_replication_hosts", f)
+
+ def update_replication_batch_for_host(self, host, last_synced_batch):
+ return self.db.simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
def get_from_remote_profile_cache(self, user_id):
return self.db.simple_select_one(
table="remote_profile_cache",
@@ -71,24 +124,53 @@ class ProfileWorkerStore(SQLBaseStore):
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db.simple_update_one(
+ def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db.simple_update_one(
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ def set_profile_active(self, user_localpart, active, hide, batchnum):
+ values = {"active": int(active), "batch": batchnum}
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile.
+ values["avatar_url"] = None
+ values["displayname"] = None
+ return self.db.simple_upsert(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ values=values,
+ desc="set_profile_active",
+ lock=False, # we can do this because user_id has a unique index
)
class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
+
+ super(ProfileStore, self).__init__(database, db_conn, hs)
+
+ self.db.updates.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
@@ -107,7 +189,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db.simple_update(
+ return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index e2673ae073..62ac88d9f2 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -276,21 +276,21 @@ class PushRulesWorkerStore(
# We ignore app service users for now. This is so that we don't fill
# up the `get_if_users_have_pushers` cache with AS entries that we
# know don't have pushers, nor even read receipts.
- local_users_in_room = set(
+ local_users_in_room = {
u
for u in users_in_room
if self.hs.is_mine_id(u)
and not self.get_if_app_services_interested_in_user(u)
- )
+ }
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate
)
- user_ids = set(
+ user_ids = {
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- )
+ }
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 6b03233262..547b9d69cb 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -197,6 +197,84 @@ class PusherWorkerStore(SQLBaseStore):
return result
+ @defer.inlineCallbacks
+ def update_pusher_last_stream_ordering(
+ self, app_id, pushkey, user_id, last_stream_ordering
+ ):
+ yield self.db.simple_update_one(
+ "pushers",
+ {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ {"last_stream_ordering": last_stream_ordering},
+ desc="update_pusher_last_stream_ordering",
+ )
+
+ @defer.inlineCallbacks
+ def update_pusher_last_stream_ordering_and_success(
+ self, app_id, pushkey, user_id, last_stream_ordering, last_success
+ ):
+ """Update the last stream ordering position we've processed up to for
+ the given pusher.
+
+ Args:
+ app_id (str)
+ pushkey (str)
+ last_stream_ordering (int)
+ last_success (int)
+
+ Returns:
+ Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+ """
+ updated = yield self.db.simple_update(
+ table="pushers",
+ keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ updatevalues={
+ "last_stream_ordering": last_stream_ordering,
+ "last_success": last_success,
+ },
+ desc="update_pusher_last_stream_ordering_and_success",
+ )
+
+ return bool(updated)
+
+ @defer.inlineCallbacks
+ def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
+ yield self.db.simple_update(
+ table="pushers",
+ keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ updatevalues={"failing_since": failing_since},
+ desc="update_pusher_failing_since",
+ )
+
+ @defer.inlineCallbacks
+ def get_throttle_params_by_room(self, pusher_id):
+ res = yield self.db.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ )
+
+ params_by_room = {}
+ for row in res:
+ params_by_room[row["room_id"]] = {
+ "last_sent_ts": row["last_sent_ts"],
+ "throttle_ms": row["throttle_ms"],
+ }
+
+ return params_by_room
+
+ @defer.inlineCallbacks
+ def set_throttle_params(self, pusher_id, room_id, params):
+ # no need to lock because `pusher_throttle` has a primary key on
+ # (pusher, room_id) so simple_upsert will retry
+ yield self.db.simple_upsert(
+ "pusher_throttle",
+ {"pusher": pusher_id, "room_id": room_id},
+ params,
+ desc="set_throttle_params",
+ lock=False,
+ )
+
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
@@ -282,81 +360,3 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id:
yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
-
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering(
- self, app_id, pushkey, user_id, last_stream_ordering
- ):
- yield self.db.simple_update_one(
- "pushers",
- {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
- {"last_stream_ordering": last_stream_ordering},
- desc="update_pusher_last_stream_ordering",
- )
-
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(
- self, app_id, pushkey, user_id, last_stream_ordering, last_success
- ):
- """Update the last stream ordering position we've processed up to for
- the given pusher.
-
- Args:
- app_id (str)
- pushkey (str)
- last_stream_ordering (int)
- last_success (int)
-
- Returns:
- Deferred[bool]: True if the pusher still exists; False if it has been deleted.
- """
- updated = yield self.db.simple_update(
- table="pushers",
- keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
- updatevalues={
- "last_stream_ordering": last_stream_ordering,
- "last_success": last_success,
- },
- desc="update_pusher_last_stream_ordering_and_success",
- )
-
- return bool(updated)
-
- @defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self.db.simple_update(
- table="pushers",
- keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
- updatevalues={"failing_since": failing_since},
- desc="update_pusher_failing_since",
- )
-
- @defer.inlineCallbacks
- def get_throttle_params_by_room(self, pusher_id):
- res = yield self.db.simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
- )
-
- params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = {
- "last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"],
- }
-
- return params_by_room
-
- @defer.inlineCallbacks
- def set_throttle_params(self, pusher_id, room_id, params):
- # no need to lock because `pusher_throttle` has a primary key on
- # (pusher, room_id) so simple_upsert will retry
- yield self.db.simple_upsert(
- "pusher_throttle",
- {"pusher": pusher_id, "room_id": room_id},
- params,
- desc="set_throttle_params",
- lock=False,
- )
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 96e54d145e..0d932a0672 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
- return set(r["user_id"] for r in receipts)
+ return {r["user_id"] for r in receipts}
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@@ -283,7 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return list(r[0:5] + (json.loads(r[5]),) for r in txn)
+ return [r[0:5] + (json.loads(r[5]),) for r in txn]
return self.db.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 3e53c8568a..035fe348b0 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -158,6 +158,28 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def get_expired_users(self):
+ """Get IDs of all expired users
+
+ Returns:
+ Deferred[list[str]]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ sql = """
+ SELECT user_id from account_validity
+ WHERE expiration_ts_ms <= ?
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+ return [row[0] for row in rows]
+
+ res = yield self.db.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
"""Defines a renewal token for a given user.
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 9a17e336ba..511316938d 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -295,6 +295,24 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @defer.inlineCallbacks
+ def is_room_published(self, room_id):
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id (str)
+ Returns:
+ bool: Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = yield self.get_room(room_id)
+ if not room_info:
+ defer.returnValue(False)
+
+ # Check the is_public value
+ defer.returnValue(room_info.get("is_public", False))
+
async def get_rooms_paginate(
self,
start: int,
@@ -449,6 +467,11 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
"""
+ # If the room retention feature is disabled, return a policy with no minimum nor
+ # maximum, in order not to filter out events we should filter out when sending to
+ # the client.
+ if not self.config.retention_enabled:
+ defer.returnValue({"min_lifetime": None, "max_lifetime": None})
def get_retention_policy_for_room_txn(txn):
txn.execute(
@@ -954,6 +977,23 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.config = hs.config
+ async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion):
+ """Ensure that the room is stored in the table
+
+ Called when we join a room over federation, and overwrites any room version
+ currently in the table.
+ """
+ await self.db.simple_upsert(
+ desc="upsert_room_on_join",
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={"room_version": room_version.identifier},
+ insertion_values={"is_public": False, "creator": ""},
+ # rooms has a unique constraint on room_id, so no need to lock when doing an
+ # emulated upsert.
+ lock=False,
+ )
+
@defer.inlineCallbacks
def store_room(
self,
@@ -1003,6 +1043,26 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
+ async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+ """
+ When we receive an invite over federation, store the version of the room if we
+ don't already know the room version.
+ """
+ await self.db.simple_upsert(
+ desc="maybe_store_room_on_invite",
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={
+ "room_version": room_version.identifier,
+ "is_public": False,
+ "creator": "",
+ },
+ # rooms has a unique constraint on room_id, so no need to lock when doing an
+ # emulated upsert.
+ lock=False,
+ )
+
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index d5ced05701..d5bd0cb5cf 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -465,7 +465,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql % (clause,), args)
- return set(row[0] for row in txn)
+ return {row[0] for row in txn}
return await self.db.runInteraction(
"get_users_server_still_shares_room_with",
@@ -826,7 +826,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
GROUP BY room_id, user_id;
"""
txn.execute(sql, (user_id,))
- return set(row[0] for row in txn if row[1] == 0)
+ return {row[0] for row in txn if row[1] == 0}
return self.db.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
index bbd3f18604..c00f287190 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.md
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -1,13 +1,21 @@
-# Building full schema dumps
+# Synapse Database Schemas
-These schemas need to be made from a database that has had all background updates run.
+These schemas are used as a basis to create brand new Synapse databases, on both
+SQLite3 and Postgres.
-To do so, use `scripts-dev/make_full_schema.sh`. This will produce
-`full.sql.postgres ` and `full.sql.sqlite` files.
+## Building full schema dumps
+
+If you want to recreate these schemas, they need to be made from a database that
+has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
+`full.sql.postgres ` and `full.sql.sqlite` files.
Ensure postgres is installed and your user has the ability to run bash commands
-such as `createdb`.
+such as `createdb`, then call
+
+ ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
-```
-./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
-```
+There are currently two folders with full-schema snapshots. `16` is a snapshot
+from 2015, for historical reference. The other contains the most recent full
+schema snapshot.
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 3d34103e67..3a3b9a8e72 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -321,7 +321,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_referenced_state_groups",
)
- return set(row["state_group"] for row in rows)
+ return {row["state_group"] for row in rows}
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
@@ -367,7 +367,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
"""
txn.execute(sql, (last_room_id, batch_size))
- room_ids = list(row[0] for row in txn)
+ room_ids = [row[0] for row in txn]
if not room_ids:
return True, set()
@@ -384,7 +384,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
- joined_room_ids = set(row[0] for row in txn)
+ joined_room_ids = {row[0] for row in txn}
left_rooms = set(room_ids) - joined_room_ids
@@ -404,7 +404,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
retcols=("state_key",),
)
- potentially_left_users = set(row["state_key"] for row in rows)
+ potentially_left_users = {row["state_key"] for row in rows}
# Now lets actually delete the rooms from the DB.
self.db.simple_delete_many_txn(
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 056b25b13a..ada5cce6c2 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -346,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_key (str): The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
- return set(
+ return {
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
- )
+ }
@defer.inlineCallbacks
def get_room_events_stream_for_room(
@@ -679,11 +679,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
events_before = yield self.get_events_as_list(
- [e for e in results["before"]["event_ids"]], get_prev_content=True
+ list(results["before"]["event_ids"]), get_prev_content=True
)
events_after = yield self.get_events_as_list(
- [e for e in results["after"]["event_ids"]], get_prev_content=True
+ list(results["after"]["event_ids"]), get_prev_content=True
)
return {
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index af8025bc17..ec6b8a4ffd 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -63,9 +63,9 @@ class UserErasureWorkerStore(SQLBaseStore):
retcols=("user_id",),
desc="are_users_erased",
)
- erased_users = set(row["user_id"] for row in rows)
+ erased_users = {row["user_id"] for row in rows}
- res = dict((u, u in erased_users) for u in user_ids)
+ res = {u: u in erased_users for u in user_ids}
return res
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index c4ee9b7ccb..57a5267663 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -520,11 +520,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
retcols=("state_group",),
)
- remaining_state_groups = set(
+ remaining_state_groups = {
row["state_group"]
for row in rows
if row["state_group"] not in state_groups_to_delete
- )
+ }
logger.info(
"[purge] de-delta-ing %i remaining state groups",
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3eeb2f7c04..609db40616 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -15,9 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import sys
import time
-from typing import Iterable, Tuple
+from time import monotonic as monotonic_time
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
@@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Connection, Cursor
from synapse.util.stringutils import exception_to_unicode
-# import a function which will return a monotonic time, in seconds
-try:
- # on python 3, use time.monotonic, since time.clock can go backwards
- from time import monotonic as monotonic_time
-except ImportError:
- # ... but python 2 doesn't have it
- from time import clock as monotonic_time
-
logger = logging.getLogger(__name__)
-try:
- MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
- # python 3 does not have a maximum int value
- MAX_TXN_ID = 2 ** 63 - 1
+# python 3 does not have a maximum int value
+MAX_TXN_ID = 2 ** 63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
- reactor, db_config: DatabaseConnectionConfig, engine
+ reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database.
"""
@@ -90,7 +80,9 @@ def make_pool(
)
-def make_conn(db_config: DatabaseConnectionConfig, engine):
+def make_conn(
+ db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> Connection:
"""Make a new connection to the database and return it.
Returns:
@@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
return db_conn
-class LoggingTransaction(object):
+# The type of entry which goes on our after_callbacks and exception_callbacks lists.
+#
+# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
+# that mypy sees the type but the runtime python doesn't.
+_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+
+
+class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method.
Args:
txn: The database transcation object to wrap.
- name (str): The name of this transactions for logging.
- database_engine (Sqlite3Engine|PostgresEngine)
- after_callbacks(list|None): A list that callbacks will be appended to
+ name: The name of this transactions for logging.
+ database_engine
+ after_callbacks: A list that callbacks will be appended to
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
- exception_callbacks(list|None): A list that callbacks will be appended
+ exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
should be allowed to be scheduled to run.
@@ -135,46 +134,67 @@ class LoggingTransaction(object):
]
def __init__(
- self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
+ self,
+ txn: Cursor,
+ name: str,
+ database_engine: BaseDatabaseEngine,
+ after_callbacks: Optional[List[_CallbackListEntry]] = None,
+ exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
- object.__setattr__(self, "txn", txn)
- object.__setattr__(self, "name", name)
- object.__setattr__(self, "database_engine", database_engine)
- object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "exception_callbacks", exception_callbacks)
+ self.txn = txn
+ self.name = name
+ self.database_engine = database_engine
+ self.after_callbacks = after_callbacks
+ self.exception_callbacks = exception_callbacks
- def call_after(self, callback, *args, **kwargs):
+ def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
+ # if self.after_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
- def call_on_exception(self, callback, *args, **kwargs):
+ def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ # if self.exception_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
- def __getattr__(self, name):
- return getattr(self.txn, name)
+ def fetchall(self) -> List[Tuple]:
+ return self.txn.fetchall()
- def __setattr__(self, name, value):
- setattr(self.txn, name, value)
+ def fetchone(self) -> Tuple:
+ return self.txn.fetchone()
- def __iter__(self):
+ def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
+ @property
+ def rowcount(self) -> int:
+ return self.txn.rowcount
+
+ @property
+ def description(self) -> Any:
+ return self.txn.description
+
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch
+ from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
self.execute(sql, val)
- def execute(self, sql, *args):
+ def execute(self, sql: str, *args: Any):
self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql, *args):
+ def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql):
@@ -207,6 +227,9 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
+ def close(self):
+ self.txn.close()
+
class PerformanceCounters(object):
def __init__(self):
@@ -251,7 +274,9 @@ class Database(object):
_TXN_ID = 0
- def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
+ def __init__(
+ self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ ):
self.hs = hs
self._clock = hs.get_clock()
self._database_config = database_config
@@ -259,9 +284,9 @@ class Database(object):
self.updates = BackgroundUpdater(hs, self)
- self._previous_txn_total_time = 0
- self._current_txn_total_time = 0
- self._previous_loop_ts = 0
+ self._previous_txn_total_time = 0.0
+ self._current_txn_total_time = 0.0
+ self._previous_loop_ts = 0.0
# TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +488,23 @@ class Database(object):
sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
- def runInteraction(self, desc, func, *args, **kwargs):
+ def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
"""Starts a transaction on the database and runs a given function
Arguments:
- desc (str): description of the transaction, for logging and metrics
- func (func): callback function, which will be called with a
+ desc: description of the transaction, for logging and metrics
+ func: callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
"""
- after_callbacks = []
- exception_callbacks = []
+ after_callbacks = [] # type: List[_CallbackListEntry]
+ exception_callbacks = [] # type: List[_CallbackListEntry]
if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -505,15 +530,15 @@ class Database(object):
return result
@defer.inlineCallbacks
- def runWithConnection(self, func, *args, **kwargs):
+ def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
- func (func): callback function, which will be called with a
+ func: callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
@@ -554,8 +579,8 @@ class Database(object):
Returns:
A list of dicts where the key is the column header.
"""
- col_headers = list(intern(str(column[0])) for column in cursor.description)
- results = list(dict(zip(col_headers, row)) for row in cursor)
+ col_headers = [intern(str(column[0])) for column in cursor.description]
+ results = [dict(zip(col_headers, row)) for row in cursor]
return results
def execute(self, desc, decoder, query, *args):
@@ -800,7 +825,7 @@ class Database(object):
return False
# We didn't find any existing rows, so insert a new one
- allvalues = {}
+ allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
@@ -829,7 +854,7 @@ class Database(object):
Returns:
None
"""
- allvalues = {}
+ allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(insertion_values)
@@ -916,7 +941,7 @@ class Database(object):
Returns:
None
"""
- allnames = []
+ allnames = [] # type: List[str]
allnames.extend(key_names)
allnames.extend(value_names)
@@ -1100,7 +1125,7 @@ class Database(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
- results = []
+ results = [] # type: List[Dict[str, Any]]
if not iterable:
return results
@@ -1439,7 +1464,7 @@ class Database(object):
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues else ""
- arg_list = []
+ arg_list = [] # type: List[Any]
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())
@@ -1504,7 +1529,7 @@ class Database(object):
def make_in_list_sql_clause(
database_engine, column: str, iterable: Iterable
-) -> Tuple[str, Iterable]:
+) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import importlib
import platform
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
- engine_class = SUPPORTED_MODULE.get(name, None)
- if engine_class:
+ if name == "sqlite3":
+ import sqlite3
+
+ return Sqlite3Engine(sqlite3, database_config)
+
+ if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
- if name == "psycopg2" and platform.python_implementation() == "PyPy":
- name = "psycopg2cffi"
- module = importlib.import_module(name)
- return engine_class(module, database_config)
+ if platform.python_implementation() == "PyPy":
+ import psycopg2cffi as psycopg2 # type: ignore
+ else:
+ import psycopg2 # type: ignore
+
+ return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
class IncorrectDatabaseSetup(RuntimeError):
pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+ def __init__(self, module, database_config: dict):
+ self.module = module
+
+ @property
+ @abc.abstractmethod
+ def single_threaded(self) -> bool:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def can_native_upsert(self) -> bool:
+ """
+ Do we support native UPSERTs?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_tuple_comparison(self) -> bool:
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_using_any_list(self) -> bool:
+ """
+ Do we support using `a = ANY(?)` and passing a list
+ """
+ ...
+
+ @abc.abstractmethod
+ def check_database(
+ self, db_conn: ConnectionType, allow_outdated_version: bool = False
+ ) -> None:
+ ...
+
+ @abc.abstractmethod
+ def check_new_database(self, txn) -> None:
+ """Gets called when setting up a brand new database. This allows us to
+ apply stricter checks on new databases versus existing database.
+ """
+ ...
+
+ @abc.abstractmethod
+ def convert_param_style(self, sql: str) -> str:
+ ...
+
+ @abc.abstractmethod
+ def on_new_connection(self, db_conn: ConnectionType) -> None:
+ ...
+
+ @abc.abstractmethod
+ def is_deadlock(self, error: Exception) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def is_connection_closed(self, conn: ConnectionType) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def lock_table(self, txn, table: str) -> None:
+ ...
+
+ @abc.abstractmethod
+ def get_next_state_group_id(self, txn) -> int:
+ """Returns an int that can be used as a new state_group ID
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def server_version(self) -> str:
+ """Gets a string giving the server version. For example: '3.22.0'
+ """
+ ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a077345960..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,16 +15,14 @@
import logging
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
logger = logging.getLogger(__name__)
-class PostgresEngine(object):
- single_threaded = False
-
+class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@@ -36,6 +34,10 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
+ @property
+ def single_threaded(self) -> bool:
+ return False
+
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
@@ -53,7 +55,7 @@ class PostgresEngine(object):
if rows and rows[0][0] != "UTF8":
raise IncorrectDatabaseSetup(
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
- "See docs/postgres.rst for more information." % (rows[0][0],)
+ "See docs/postgres.md for more information." % (rows[0][0],)
)
txn.execute(
@@ -62,12 +64,16 @@ class PostgresEngine(object):
collation, ctype = txn.fetchone()
if collation != "C":
logger.warning(
- "Database has incorrect collation of %r. Should be 'C'", collation
+ "Database has incorrect collation of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information.",
+ collation,
)
if ctype != "C":
logger.warning(
- "Database has incorrect ctype of %r. Should be 'C'", ctype
+ "Database has incorrect ctype of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information.",
+ ctype,
)
def check_new_database(self, txn):
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 641e490697..2bfeefd54e 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import sqlite3
import struct
import threading
+from synapse.storage.engines import BaseDatabaseEngine
-class Sqlite3Engine(object):
- single_threaded = True
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (None, ":memory:",)
@@ -32,6 +32,10 @@ class Sqlite3Engine(object):
self._current_state_group_id_lock = threading.Lock()
@property
+ def single_threaded(self) -> bool:
+ return True
+
+ @property
def can_native_upsert(self):
"""
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -68,7 +72,6 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
-
# We need to import here to avoid an import loop.
from synapse.storage.prepare_database import prepare_database
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index b950550f23..0f9ac1cf09 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -602,14 +602,14 @@ class EventsPersistenceStorage(object):
event_id_to_state_group.update(event_to_groups)
# State groups of old_latest_event_ids
- old_state_groups = set(
+ old_state_groups = {
event_id_to_state_group[evid] for evid in old_latest_event_ids
- )
+ }
# State groups of new_latest_event_ids
- new_state_groups = set(
+ new_state_groups = {
event_id_to_state_group[evid] for evid in new_latest_event_ids
- )
+ }
# If they old and new groups are the same then we don't need to do
# anything.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c285ef52a0..6cb7d4b922 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -278,13 +278,17 @@ def _upgrade_existing_database(
the current_version wasn't generated by applying those delta files.
database_engine (DatabaseEngine)
config (synapse.config.homeserver.HomeServerConfig|None):
- application config, or None if we are connecting to an existing
- database which we expect to be configured already
+ None if we are initialising a blank database, otherwise the application
+ config
data_stores (list[str]): The names of the data stores to instantiate
on the given database.
is_empty (bool): Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts.
"""
+ if is_empty:
+ assert not applied_delta_files
+ else:
+ assert config
if current_version > SCHEMA_VERSION:
raise ValueError(
@@ -292,6 +296,13 @@ def _upgrade_existing_database(
+ "new for the server to understand"
)
+ # some of the deltas assume that config.server_name is set correctly, so now
+ # is a good time to run the sanity check.
+ if not is_empty and "main" in data_stores:
+ from synapse.storage.data_stores.main import check_database_before_upgrade
+
+ check_database_before_upgrade(cur, database_engine, config)
+
start_ver = current_version
if not upgraded:
start_ver += 1
@@ -345,9 +356,9 @@ def _upgrade_existing_database(
"Could not open delta dir for version %d: %s" % (v, directory)
)
- duplicates = set(
+ duplicates = {
file_name for file_name, count in file_name_counter.items() if count > 1
- )
+ }
if duplicates:
# We don't support using the same file name in the same delta version.
raise PrepareDatabaseException(
@@ -454,7 +465,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
),
(modname,),
)
- applied_deltas = set(d for d, in cur)
+ applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
if name in applied_deltas:
continue
diff --git a/synapse/storage/schema/delta/48/profiles_batch.sql b/synapse/storage/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..c8893ecbe8
--- /dev/null
+++ b/synapse/storage/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * realy ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
diff --git a/synapse/storage/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..18a0f7e10c
--- /dev/null
+++ b/synapse/storage/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('profile_replication_status_host_index', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/schema/delta/55/room_retention.sql
index ee6cdf7a14..ee6cdf7a14 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
+++ b/synapse/storage/schema/delta/55/room_retention.sql
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
new file mode 100644
index 0000000000..daff81c5ee
--- /dev/null
+++ b/synapse/storage/types.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Iterable, Iterator, List, Tuple
+
+from typing_extensions import Protocol
+
+
+"""
+Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
+"""
+
+
+class Cursor(Protocol):
+ def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ ...
+
+ def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ ...
+
+ def fetchall(self) -> List[Tuple]:
+ ...
+
+ def fetchone(self) -> Tuple:
+ ...
+
+ @property
+ def description(self) -> Any:
+ return None
+
+ @property
+ def rowcount(self) -> int:
+ return 0
+
+ def __iter__(self) -> Iterator[Tuple]:
+ ...
+
+ def close(self) -> None:
+ ...
+
+
+class Connection(Protocol):
+ def cursor(self) -> Cursor:
+ ...
+
+ def close(self) -> None:
+ ...
+
+ def commit(self) -> None:
+ ...
+
+ def rollback(self, *args, **kwargs) -> None:
+ ...
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
new file mode 100644
index 0000000000..2c9155d15c
--- /dev/null
+++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,588 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.config._base import ConfigError
+from synapse.types import get_domain_from_id
+
+ACCESS_RULES_TYPE = "im.vector.room.access_rules"
+ACCESS_RULE_RESTRICTED = "restricted"
+ACCESS_RULE_UNRESTRICTED = "unrestricted"
+ACCESS_RULE_DIRECT = "direct"
+
+VALID_ACCESS_RULES = (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+)
+
+# Rules to which we need to apply the power levels restrictions.
+#
+# These are all of the rules that neither:
+# * forbid users from joining based on a server blacklist (which means that there
+# is no need to apply power level restrictions), nor
+# * target direct chats (since we allow both users to be room admins in this case).
+#
+# The power-level restrictions, when they are applied, prevent the following:
+# * the default power level for users (users_default) being set to anything other than 0.
+# * a non-default power level being assigned to any user which would be forbidden from
+# joining a restricted room.
+RULES_WITH_RESTRICTED_POWER_LEVELS = (ACCESS_RULE_UNRESTRICTED,)
+
+
+class RoomAccessRules(object):
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config, http_client):
+ self.http_client = http_client
+
+ self.id_server = config["id_server"]
+
+ self.domains_forbidden_when_restricted = config.get(
+ "domains_forbidden_when_restricted", []
+ )
+
+ @staticmethod
+ def parse_config(config):
+ if "id_server" in config:
+ return config
+ else:
+ raise ConfigError("No IS for event rules TchapEventRules")
+
+ def on_create_room(self, requester, config, is_requester_admin) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room
+
+ Checks if a im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+ """
+ is_direct = config.get("is_direct")
+ preset = config.get("preset")
+ access_rule = None
+ join_rule = None
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and deny the request if not.
+ for event in config.get("initial_state", []):
+ if event["type"] == ACCESS_RULES_TYPE:
+ access_rule = event["content"].get("rule")
+
+ # Make sure the event has a valid content.
+ if access_rule is None:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule name is valid.
+ if access_rule not in VALID_ACCESS_RULES:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule is "direct" if the room is a direct chat.
+ if (is_direct and access_rule != ACCESS_RULE_DIRECT) or (
+ access_rule == ACCESS_RULE_DIRECT and not is_direct
+ ):
+ raise SynapseError(400, "Invalid access rule")
+
+ if event["type"] == EventTypes.JoinRules:
+ join_rule = event["content"].get("join_rule")
+
+ if access_rule is None:
+ # If there's no access rules event in the initial state, create one with the
+ # default setting.
+ if is_direct:
+ default_rule = ACCESS_RULE_DIRECT
+ else:
+ # If the default value for non-direct chat changes, we should make another
+ # case here for rooms created with either a "public" join_rule or the
+ # "public_chat" preset to make sure those keep defaulting to "restricted"
+ default_rule = ACCESS_RULE_RESTRICTED
+
+ if not config.get("initial_state"):
+ config["initial_state"] = []
+
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
+ }
+ )
+
+ access_rule = default_rule
+
+ # Check that the preset or the join rule in use is compatible with the access
+ # rule, whether it's a user-defined one or the default one (i.e. if it involves
+ # a "public" join rule, the access rule must be "restricted").
+ if (
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule != ACCESS_RULE_RESTRICTED:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed(
+ config.get("power_level_content_override", {}), access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content override")
+
+ # Second loop for events we need to know the current rule to process.
+ for event in config.get("initial_state", []):
+ if event["type"] == EventTypes.PowerLevels:
+ allowed = self._is_power_level_content_allowed(
+ event["content"], access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content")
+
+ return True
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(self, medium, address, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+ """
+ rule = self._get_rule_from_state(state_events)
+
+ if medium != "email":
+ defer.returnValue(False)
+
+ if rule != ACCESS_RULE_RESTRICTED:
+ # Only "restricted" requires filtering 3PID invites. We don't need to do
+ # anything for "direct" here, because only "restricted" requires filtering
+ # based on the HS the address is mapped to.
+ defer.returnValue(True)
+
+ parsed_address = email.utils.parseaddr(address)[1]
+ if parsed_address != address:
+ # Avoid reproducing the security issue described here:
+ # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2
+ # It's probably not worth it but let's just be overly safe here.
+ defer.returnValue(False)
+
+ # Get the HS this address belongs to from the identity server.
+ res = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
+ {"medium": medium, "address": address},
+ )
+
+ # Look for a domain that's not forbidden from being invited.
+ if not res.get("hs"):
+ defer.returnValue(False)
+ if res.get("hs") in self.domains_forbidden_when_restricted:
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
+ def check_event_allowed(self, event, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+ """
+ if event.type == ACCESS_RULES_TYPE:
+ return self._on_rules_change(event, state_events)
+
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ if event.type == EventTypes.PowerLevels:
+ return self._is_power_level_content_allowed(event.content, rule)
+
+ if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
+ return self._on_membership_or_invite(event, rule, state_events)
+
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
+ if event.type == EventTypes.RoomAvatar:
+ return self._on_room_avatar_change(event, rule)
+
+ if event.type == EventTypes.Name:
+ return self._on_room_name_change(event, rule)
+
+ if event.type == EventTypes.Topic:
+ return self._on_room_topic_change(event, rule)
+
+ return True
+
+ def _on_rules_change(self, event, state_events):
+ """Implement the checks and behaviour specified on allowing or forbidding a new
+ im.vector.room.access_rules event.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ new_rule = event.content.get("rule")
+
+ # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES:
+ return False
+
+ # We must not allow rooms with the "public" join rule to be given any other access
+ # rule than "restricted".
+ join_rule = self._get_join_rule_from_state(state_events)
+ if join_rule == JoinRules.PUBLIC and new_rule != ACCESS_RULE_RESTRICTED:
+ return False
+
+ # Make sure we don't apply "direct" if the room has more than two members.
+ if new_rule == ACCESS_RULE_DIRECT:
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ if len(existing_members) > 2 or len(threepid_tokens) > 1:
+ return False
+
+ prev_rules_event = state_events.get((ACCESS_RULES_TYPE, ""))
+
+ # Now that we know the new rule doesn't break the "direct" case, we can allow any
+ # new rule in rooms that had none before.
+ if prev_rules_event is None:
+ return True
+
+ prev_rule = prev_rules_event.content.get("rule")
+
+ # Currently, we can only go from "restricted" to "unrestricted".
+ if prev_rule == ACCESS_RULE_RESTRICTED and new_rule == ACCESS_RULE_UNRESTRICTED:
+ return True
+
+ return False
+
+ def _on_membership_or_invite(self, event, rule, state_events):
+ """Applies the correct rule for incoming m.room.member and
+ m.room.third_party_invite events.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if rule == ACCESS_RULE_RESTRICTED:
+ ret = self._on_membership_or_invite_restricted(event)
+ elif rule == ACCESS_RULE_UNRESTRICTED:
+ ret = self._on_membership_or_invite_unrestricted()
+ elif rule == ACCESS_RULE_DIRECT:
+ ret = self._on_membership_or_invite_direct(event, state_events)
+ else:
+ # We currently apply the default (restricted) if we don't know the rule, we
+ # might want to change that in the future.
+ ret = self._on_membership_or_invite_restricted(event)
+
+ return ret
+
+ def _on_membership_or_invite_restricted(self, event):
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ "restricted" currently means that users can only invite users if their server is
+ included in a limited list of domains.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # We're not applying the rules on m.room.third_party_member events here because
+ # the filtering on threepids is done in check_threepid_can_be_invited, which is
+ # called before check_event_allowed.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return True
+
+ # We only need to process "join" and "invite" memberships, in order to be backward
+ # compatible, e.g. if a user from a blacklisted server joined a restricted room
+ # before the rules started being enforced on the server, that user must be able to
+ # leave it.
+ if event.membership not in [Membership.JOIN, Membership.INVITE]:
+ return True
+
+ invitee_domain = get_domain_from_id(event.state_key)
+ return invitee_domain not in self.domains_forbidden_when_restricted
+
+ def _on_membership_or_invite_unrestricted(self):
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ "unrestricted" currently means that every event is allowed.
+
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return True
+
+ def _on_membership_or_invite_direct(self, event, state_events):
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ "direct" currently means that no member is allowed apart from the two initial
+ members the room was created for (i.e. the room's creator and their first
+ invitee).
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Get the room memberships and 3PID invite tokens from the room's state.
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ # There should never be more than one 3PID invite in the room state: if the second
+ # original user came and left, and we're inviting them using their email address,
+ # given we know they have a Matrix account binded to the address (so they could
+ # join the first time), Synapse will successfully look it up before attempting to
+ # store an invite on the IS.
+ if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite:
+ # If we already have a 3PID invite in flight, don't accept another one, unless
+ # the new one has the same invite token as its state key. This is because 3PID
+ # invite revocations must be allowed, and a revocation is basically a new 3PID
+ # invite event with an empty content and the same token as the invite it
+ # revokes.
+ return event.state_key in threepid_tokens
+
+ if len(existing_members) == 2:
+ # If the user was within the two initial user of the room, Synapse would have
+ # looked it up successfully and thus sent a m.room.member here instead of
+ # m.room.third_party_invite.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return False
+
+ # We can only have m.room.member events here. The rule in this case is to only
+ # allow the event if its target is one of the initial two members in the room,
+ # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members
+
+ # We're alone in the room (and always have been) and there's one 3PID invite in
+ # flight.
+ if len(existing_members) == 1 and len(threepid_tokens) == 1:
+ # We can only have m.room.member events here. In this case, we can only allow
+ # the event if it's either a m.room.member from the joined user (we can assume
+ # that the only m.room.member event is a join otherwise we wouldn't be able to
+ # send an event to the room) or an an invite event which target is the invited
+ # user.
+ target = event.state_key
+ is_from_threepid_invite = self._is_invite_from_threepid(
+ event, threepid_tokens[0]
+ )
+ if is_from_threepid_invite or target == existing_members[0]:
+ return True
+
+ return False
+
+ return True
+
+ def _is_power_level_content_allowed(self, content, access_rule):
+ """Check if a given power levels event is permitted under the given access rule.
+
+ It shouldn't be allowed if it either changes the default PL to a non-0 value or
+ gives a non-0 PL to a user that would have been forbidden from joining the room
+ under a more restrictive access rule.
+
+ Args:
+ content (dict[]): The content of the m.room.power_levels event to check.
+ access_rule (str): The access rule in place in this room.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Check if we need to apply the restrictions with the current rule.
+ if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
+ return True
+
+ # If users_default is explicitly set to a non-0 value, deny the event.
+ users_default = content.get("users_default", 0)
+ if users_default:
+ return False
+
+ users = content.get("users", {})
+ for user_id, power_level in users.items():
+ server_name = get_domain_from_id(user_id)
+ # Check the domain against the blacklist. If found, and the PL isn't 0, deny
+ # the event.
+ if (
+ server_name in self.domains_forbidden_when_restricted
+ and power_level != 0
+ ):
+ return False
+
+ return True
+
+ def _on_join_rule_change(self, event, rule):
+ """Check whether a join rule change is allowed. A join rule change is always
+ allowed unless the new join rule is "public" and the current access rule isn't
+ "restricted".
+ The rationale is that external users (those whose server would be denied access
+ to rooms enforcing the "restricted" access rule) should always rely on non-
+ external users for access to rooms, therefore they shouldn't be able to access
+ rooms that don't require an invite to be joined.
+
+ Note that we currently rely on the default access rule being "restricted": during
+ room creation, the m.room.join_rules event will be sent *before* the
+ im.vector.room.access_rules one, so the access rule that will be considered here
+ in this case will be the default "restricted" one. This is fine since the
+ "restricted" access rule allows any value for the join rule, but we should keep
+ that in mind if we need to change the default access rule in the future.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
+ return rule == ACCESS_RULE_RESTRICTED
+
+ return True
+
+ def _on_room_avatar_change(self, event, rule):
+ """Check whether a change of room avatar is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_name_change(self, event, rule):
+ """Check whether a change of room name is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_topic_change(self, event, rule):
+ """Check whether a change of room topic is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ @staticmethod
+ def _get_rule_from_state(state_events):
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the rule (either "direct", "restricted" or "unrestricted")
+ """
+ access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
+ if access_rules is None:
+ rule = ACCESS_RULE_RESTRICTED
+ else:
+ rule = access_rules.content.get("rule")
+ return rule
+
+ @staticmethod
+ def _get_join_rule_from_state(state_events):
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the join rule (either "public", or "invite")
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return None
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
+ def _get_members_and_tokens_from_state(state_events):
+ """Retrieves from a list of state events the list of users that have a
+ m.room.member event in the room, and the tokens of 3PID invites in the room.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ existing_members (list[str]): List of targets of the m.room.member events in
+ the state.
+ threepid_invite_tokens (list[str]): List of tokens of the 3PID invites in the
+ state.
+ """
+ existing_members = []
+ threepid_invite_tokens = []
+ for key, state_event in state_events.items():
+ if key[0] == EventTypes.Member and state_event.content:
+ existing_members.append(state_event.state_key)
+ if key[0] == EventTypes.ThirdPartyInvite and state_event.content:
+ # Don't include revoked invites.
+ threepid_invite_tokens.append(state_event.state_key)
+
+ return existing_members, threepid_invite_tokens
+
+ @staticmethod
+ def _is_invite_from_threepid(invite, threepid_invite_token):
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite (EventBase): The m.room.member event with "invite" membership.
+ threepid_invite_token (str): The state key from the 3PID invite.
+ """
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
+
+ return token == threepid_invite_token
diff --git a/synapse/types.py b/synapse/types.py
index f3cd465735..16a7f87011 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -19,6 +19,8 @@ import sys
from collections import namedtuple
from typing import Any, Dict, Tuple, TypeVar
+from six.moves import filter
+
import attr
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
@@ -262,6 +264,19 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+def strip_invalid_mxid_characters(localpart):
+ """Removes any invalid characters from an mxid
+
+ Args:
+ localpart (basestring): the localpart to be stripped
+
+ Returns:
+ localpart (basestring): the localpart having been stripped
+ """
+ filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
+ return "".join(filtered)
+
+
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 635b897d6c..f2ccd5e7c6 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -30,7 +30,7 @@ def freeze(o):
return o
try:
- return tuple([freeze(i) for i in o])
+ return tuple(freeze(i) for i in o)
except TypeError:
pass
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ defered bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,8 +73,8 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index d0abd8f04f..e60d9756b7 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -75,7 +75,7 @@ def filter_events_for_client(
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
- events = list(e for e in events if not e.internal_metadata.is_soft_failed())
+ events = [e for e in events if not e.internal_metadata.is_soft_failed()]
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
event_id_to_state = yield storage.state.get_state_for_events(
@@ -97,7 +97,7 @@ def filter_events_for_client(
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
if apply_retention_policies:
- room_ids = set(e.room_id for e in events)
+ room_ids = {e.room_id for e in events}
retention_policies = {}
for room_id in room_ids:
diff --git a/sytest-blacklist b/sytest-blacklist
index 79b2d4402a..fd50197b13 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -36,3 +36,24 @@ Inbound federation of state requires event_id as a mandatory paramater
# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands
Can upload self-signing keys
+
+# flaky test
+If remote user leaves room we no longer receive device updates
+
+# flaky test
+Can re-join room if re-invited
+
+# flaky test
+Forgotten room messages cannot be paginated
+
+# flaky test
+Local device key changes get to remote servers
+
+# flaky test
+Old leaves are present in gapped incremental syncs
+
+# flaky test on workers
+Old members are included in gappy incr LL sync if they start speaking
+
+# flaky test on workers
+Presence changes to UNAVAILABLE are reported to remote room members
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 8bdbc608a9..d3feafa1b7 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.app.frontend_proxy import FrontendProxyServer
+from synapse.app.generic_worker import GenericWorkerServer
from tests.unittest import HomeserverTestCase
@@ -22,11 +22,16 @@ class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=FrontendProxyServer
+ http_client=None, homeserverToUse=GenericWorkerServer
)
return hs
+ def default_config(self, name="test"):
+ c = super().default_config(name)
+ c["worker_app"] = "synapse.app.frontend_proxy"
+ return c
+
def test_listen_http_with_presence_enabled(self):
"""
When presence is on, the stub servlet will not register.
@@ -46,9 +51,7 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = (
- site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
- )
+ self.resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
@@ -76,9 +79,7 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = (
- site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
- )
+ self.resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 48792d1480..89fcc3889a 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -16,7 +16,7 @@ from mock import Mock, patch
from parameterized import parameterized
-from synapse.app.federation_reader import FederationReaderServer
+from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
from tests.unittest import HomeserverTestCase
@@ -25,10 +25,18 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=FederationReaderServer
+ http_client=None, homeserverToUse=GenericWorkerServer
)
return hs
+ def default_config(self, name="test"):
+ conf = super().default_config(name)
+ # we're using FederationReaderServer, which uses a SlavedStore, so we
+ # have to tell the FederationHandler not to try to access stuff that is only
+ # in the primary store.
+ conf["worker_app"] = "yes"
+ return conf
+
@parameterized.expand(
[
(["federation"], "auth_fail"),
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 2684e662de..463855ecc8 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -48,7 +48,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
)
self.assertSetEqual(
- set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]),
+ {"homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"},
set(os.listdir(self.dir)),
)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index e7d8699040..296dc887be 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -83,7 +83,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
)
)
- self.assertEqual(members, set(["@user:other.example.com", u1]))
+ self.assertEqual(members, {"@user:other.example.com", u1})
self.assertEqual(len(channel.json_body["pdus"]), 6)
def test_needs_to_be_in_room(self):
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..34f6bfb422
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.address = "test@test"
+ self.is_server_name = "testis"
+ self.rewritten_is_url = "int.testis"
+
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
+ config["rewrite_identity_server_urls"] = {
+ self.is_server_name: self.rewritten_is_url
+ }
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.side_effect = defer.succeed({})
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_blacklisting_http_client.get_json.side_effect = defer.succeed({})
+ mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+ mock_blacklisting_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+
+ def test_rewritten_id_server(self):
+ """
+ Tests that, when validating a 3PID association while rewriting the IS's server
+ name:
+ * the bind request is done against the rewritten hostname
+ * the original, non-rewritten, server name is stored in the database
+ """
+ handler = self.hs.get_handlers().identity_handler
+ post_json_get_json = handler.blacklisting_http_client.post_json_get_json
+ store = self.hs.get_datastore()
+
+ creds = {"sid": "123", "client_secret": "some_secret"}
+
+ # Make sure processing the mocked response goes through.
+ data = self.get_success(
+ handler.bind_threepid(
+ client_secret=creds["client_secret"],
+ sid=creds["sid"],
+ mxid=self.user_id,
+ id_server=self.is_server_name,
+ use_v2=False,
+ )
+ )
+ self.assertEqual(data.get("address"), self.address)
+
+ # Check that the request was done against the rewritten server name.
+ post_json_get_json.assert_called_once_with(
+ "https://%s/_matrix/identity/api/v1/3pid/bind" % self.rewritten_is_url,
+ {
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ headers={},
+ )
+
+ # Check that the original server name is saved in the database instead of the
+ # rewritten one.
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
+ self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index c171038df8..05ea40a7de 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -338,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, syncing_user_ids=set([user_id]), now=now
+ state, is_mine=True, syncing_user_ids={user_id}, now=now
)
self.assertIsNotNone(new_state)
@@ -494,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, "@test2:server")
# Mark test2 as online, test will be offline with a last_active of 0
- self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ )
)
self.reactor.pump([0]) # Wait for presence updates to be handled
@@ -543,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
room_id = self.helper.create_room_as(self.user_id)
# Mark test as online
- self.presence_handler.set_state(
- UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+ )
)
# Mark test2 as online, test will be offline with a last_active of 0.
# Note we don't join them to the room yet
- self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ )
)
# Add servers to the room
@@ -579,7 +585,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
- destinations=set(("server2", "server3")), states=[expected_state]
+ destinations={"server2", "server3"}, states=[expected_state]
)
def _add_new_user(self, room_id, user_id):
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d60c124eec..2311040201 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -67,13 +67,11 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
-
self.handler = hs.get_profile_handler()
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
displayname = yield self.handler.get_displayname(self.frank)
@@ -116,8 +114,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield self.store.set_profile_displayname("caroline", "Caroline", 1)
response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
@@ -128,7 +125,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
avatar_url = yield self.handler.get_avatar_url(self.frank)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e2915eb7b1..5e7f14a3d5 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.rest.client.v2_alpha.register import _map_email_to_displayname
from synapse.types import RoomAlias, UserID, create_requester
from .. import unittest
@@ -256,6 +257,26 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_email_to_displayname_mapping(self):
+ """Test that custom emails are mapped to new user displaynames correctly"""
+ self._check_mapping(
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+ )
+
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+ # Multibyte unicode characters
+ self._check_mapping(
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
+ )
+
+ def _check_mapping(self, i, expected):
+ result = _map_email_to_displayname(i)
+ self.assertEqual(result, expected)
+
@defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
"""Creates a new user if the user does not exist,
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..8e6b0b7536 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -21,8 +21,12 @@ from tests import unittest
# The expected number of state events in a fresh public room.
EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+
# The expected number of state events in a fresh private room.
-EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
+#
+# Note: we increase this by 1 on the dinsic branch as we send
+# a "im.vector.room.access_rules" state event into new private rooms
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
class StatsRoomTests(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 140cc0a3c2..51e2b37218 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -74,6 +74,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"set_received_txn_response",
"get_destination_retry_timings",
"get_devices_by_remote",
+ "maybe_store_room_on_invite",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
@@ -129,12 +130,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs.get_auth().check_user_in_room = check_user_in_room
def get_joined_hosts_for_room(room_id):
- return set(member.domain for member in self.room_members)
+ return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
def get_current_users_in_room(room_id):
- return set(str(u) for u in self.room_members)
+ return {str(u) for u in self.room_members}
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
@@ -257,7 +258,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
member = RoomMember(ROOM_ID, U_APPLE.to_string())
self.handler._member_typing_until[member] = 1002000
- self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()])
+ self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()}
self.assertEquals(self.event_source.get_current_key(), 0)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 0a4765fff4..7b92bdbc47 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
public_users = self.get_users_in_public_rooms()
self.assertEqual(
- self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
)
self.assertEqual(public_users, [])
@@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
public_users = self.get_users_in_public_rooms()
self.assertEqual(
- self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
)
self.assertEqual(public_users, [])
@@ -226,7 +226,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
public_users = self.get_users_in_public_rooms()
self.assertEqual(
- self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
)
self.assertEqual(public_users, [])
@@ -358,12 +358,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
public_users = self.get_users_in_public_rooms()
# User 1 and User 2 are in the same public room
- self.assertEqual(set(public_users), set([(u1, room), (u2, room)]))
+ self.assertEqual(set(public_users), {(u1, room), (u2, room)})
# User 1 and User 3 share private rooms
self.assertEqual(
self._compress_shared(shares_private),
- set([(u1, u3, private_room), (u3, u1, private_room)]),
+ {(u1, u3, private_room), (u3, u1, private_room)},
)
def test_initial_share_all_users(self):
@@ -398,7 +398,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# No users share rooms
self.assertEqual(public_users, [])
- self.assertEqual(self._compress_shared(shares_private), set([]))
+ self.assertEqual(self._compress_shared(shares_private), set())
# Despite not sharing a room, search_all_users means we get a search
# result.
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 80187406bc..83032cc9ea 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -163,7 +163,7 @@ class EmailPusherTests(HomeserverTestCase):
# Get the stream ordering before it gets sent
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
@@ -174,7 +174,7 @@ class EmailPusherTests(HomeserverTestCase):
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
@@ -192,7 +192,7 @@ class EmailPusherTests(HomeserverTestCase):
# The stream ordering has increased
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index fe3441f081..baf9c785f4 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -102,7 +102,7 @@ class HTTPPusherTests(HomeserverTestCase):
# Get the stream ordering before it gets sent
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": user_id})
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
@@ -113,7 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": user_id})
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
@@ -132,7 +132,7 @@ class HTTPPusherTests(HomeserverTestCase):
# The stream ordering has increased
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": user_id})
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
@@ -152,7 +152,7 @@ class HTTPPusherTests(HomeserverTestCase):
# The stream ordering has increased, again
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": user_id})
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..e163a46f6b 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
import json
+from mock import Mock
+
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,24 +39,113 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ request, channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ params = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_data = json.dumps(params)
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ request, channel = self.make_request(
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed((200, "{}"))
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.http_client = (
+ mock_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_enabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
+ # Replace the blacklisting SimpleHttpClient with our mock
+ self.hs.get_room_member_handler().simple_http_client = Mock(
+ spec=["get_json", "post_json_get_json"]
+ )
+ self.hs.get_room_member_handler().simple_http_client.get_json.return_value = (
+ defer.succeed((200, "{}"))
+ )
+
params = {
"id_server": "testis",
"medium": "email",
@@ -58,7 +154,44 @@ class IdentityTestCase(unittest.HomeserverTestCase):
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ b"POST", request_url, request_data, access_token=self.tok
)
self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ get_json = self.hs.get_handlers().identity_handler.http_client.get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 95475bb651..9e549d8a91 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
"default_policy": {
@@ -203,6 +204,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
}
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..f10ae0adeb
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,726 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import random
+import string
+
+from mock import Mock
+
+from twisted.internet import defer
+from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+ ACCESS_RULES_TYPE,
+)
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+ mock_http_client = Mock(
+ spec=["get_json", "post_json_get_json"],
+ )
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = mock_http_client
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default value."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default value."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right value."""
+ room_id = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=ACCESS_RULE_DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=ACCESS_RULE_RESTRICTED, expected_code=400)
+
+ def test_public_room(self):
+ """Tests that it's not possible to have a room with the public join rule and an
+ access rule that's not restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED
+ )
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, ACCESS_RULE_DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403
+ )
+
+ # Changing join rule to public in an unrestricted room should fail.
+ self.change_join_rule_in_room(
+ self.unrestricted_room, JoinRules.PUBLIC, expected_code=403
+ )
+ # Changing join rule to public in an direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule that isn't
+ # restricted should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
+ )
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
+ )
+
+ # Creating a room with the public join rule in its initial state and an access
+ # rule that isn't restricted should fail.
+ self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
+ )
+ self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can invite a user which HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+ Also tests that a user from a HS that's in the list of forbidden domains (to use
+ in restricted mode) can be invited.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # Disable the 3pid invite ratelimiter
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to a room in which there's a 3PID invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+ """
+ # We can invite
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a power level event that redefines the default PL and doesn't set
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't send a power level event that doesn't redefines the default PL but sets
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can change the rule from restricted to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=ACCESS_RULE_DIRECT, expected_code=403
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=403,
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ confuse the revokation as a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ expected_code=200,
+ ):
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ json.dumps(content),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+ data = {"rule": new_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
+ data = {"join_rule": new_join_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
new file mode 100644
index 0000000000..37f970c6b0
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, password_policy, register
+
+from tests import unittest
+
+
+class PasswordPolicyTestCase(unittest.HomeserverTestCase):
+ """Tests the password policy feature and its compliance with MSC2000.
+
+ When validating a password, Synapse does the necessary checks in this order:
+
+ 1. Password is long enough
+ 2. Password contains digit(s)
+ 3. Password contains symbol(s)
+ 4. Password contains uppercase letter(s)
+ 5. Password contains lowercase letter(s)
+
+ Therefore, each test in this test case that tests whether a password triggers the
+ right error code to be returned provides a password good enough to pass the previous
+ steps but not the one it's testing (nor any step that comes after).
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ password_policy.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.register_url = "/_matrix/client/r0/register"
+ self.policy = {
+ "enabled": True,
+ "minimum_length": 10,
+ "require_digit": True,
+ "require_symbol": True,
+ "require_lowercase": True,
+ "require_uppercase": True,
+ }
+
+ config = self.default_config()
+ config["password_config"] = {"policy": self.policy}
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_get_policy(self):
+ """Tests if the /password_policy endpoint returns the configured policy."""
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/password_policy"
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.minimum_length": 10,
+ "m.require_digit": True,
+ "m.require_symbol": True,
+ "m.require_lowercase": True,
+ "m.require_uppercase": True,
+ },
+ channel.result,
+ )
+
+ def test_password_too_short(self):
+ request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result
+ )
+
+ def test_password_no_digit(self):
+ request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result
+ )
+
+ def test_password_no_symbol(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result
+ )
+
+ def test_password_no_uppercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result
+ )
+
+ def test_password_no_lowercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result
+ )
+
+ def test_password_compliant(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ # Getting a 401 here means the password has passed validation and the server has
+ # responded with a list of registration flows.
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_password_change(self):
+ """This doesn't test every possible use case, only that hitting /account/password
+ triggers the password validation code.
+ """
+ compliant_password = "C0mpl!antpassword"
+ not_compliant_password = "notcompliantpassword"
+
+ user_id = self.register_user("kermit", compliant_password)
+ tok = self.login("kermit", compliant_password)
+
+ request_data = json.dumps(
+ {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
+ )
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/account/password",
+ request_data,
+ access_token=tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index d0c997e385..d99b100d0f 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime
import json
import os
+from mock import Mock
+
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -261,6 +265,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+ self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -269,6 +314,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
login.register_servlets,
sync.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -361,6 +407,138 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Create a user to expire
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+
+ self.pump(1000)
+ self.reactor.advance(1000)
+ self.pump()
+
+ # Expire the user
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.pump(60 * 60 * 1000)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -511,7 +689,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(request.code, 200, channel.result)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 9c13a13786..fa3a3ec1bd 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -40,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertTrue(
- set(
- [
- "next_batch",
- "rooms",
- "presence",
- "account_data",
- "to_device",
- "device_lists",
- ]
- ).issubset(set(channel.json_body.keys()))
+ {
+ "next_batch",
+ "rooms",
+ "presence",
+ "account_data",
+ "to_device",
+ "device_lists",
+ }.issubset(set(channel.json_body.keys()))
)
def test_sync_presence_disabled(self):
@@ -63,9 +61,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertTrue(
- set(
- ["next_batch", "rooms", "account_data", "to_device", "device_lists"]
- ).issubset(set(channel.json_body.keys()))
+ {
+ "next_batch",
+ "rooms",
+ "account_data",
+ "to_device",
+ "device_lists",
+ }.issubset(set(channel.json_body.keys()))
)
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..1accc70dc9
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+from tests.server import make_request, render
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+ def test_allowed(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_one", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_two", "test:target_two", None, "room", False
+ )
+ )
+
+ # User can invite internal user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test1:target_one", None, "room", False, True
+ )
+ )
+
+ # User can invite external user to a non-published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, False
+ )
+ )
+
+ def test_disallowed(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ "source_four": [],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_one", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_one", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_four", "test:target_one", None, "room", False
+ )
+ )
+
+ # User cannot invite external user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, True
+ )
+ )
+
+ def test_default_allow(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_default_deny(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_config_parse(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ self.assertEquals(config, DomainRuleChecker.parse_config(config))
+
+ def test_config_parse_failure(self):
+ config = {
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ }
+ }
+ self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["localhost"]
+
+ config["spam_checker"] = {
+ "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+ "config": {
+ "default": True,
+ "domain_mapping": {},
+ "can_only_join_rooms_with_invite": True,
+ "can_only_create_one_to_one_rooms": True,
+ "can_only_invite_during_room_creation": True,
+ "can_invite_by_third_party_id": False,
+ },
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+ self.admin_access_token = self.login("admin_user", "pass")
+
+ self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+ self.normal_access_token = self.login("normal_user", "pass")
+
+ self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+ def test_admin_can_create_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_normal_user_cannot_create_empty_room(self):
+ channel = self._create_room(self.normal_access_token)
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_cannot_create_room_with_multiple_invites(self):
+ channel = self._create_room(
+ self.normal_access_token,
+ content={"invite": [self.other_user_id, self.admin_user_id]},
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly counts both normal and third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [self.other_user_id],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly rejects third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_can_room_with_single_invites(self):
+ channel = self._create_room(
+ self.normal_access_token, content={"invite": [self.other_user_id]}
+ )
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_cannot_join_public_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+ )
+
+ def test_can_join_invited_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ def test_cannot_invite(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ def test_cannot_3pid_invite(self):
+ """Test that unbound 3pid invites get rejected.
+ """
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/invite" % (room_id),
+ {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+ access_token=self.normal_access_token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result["body"])
+
+ def _create_room(self, token, content={}):
+ path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+ request, channel = make_request(
+ self.hs.get_reactor(),
+ "POST",
+ path,
+ content=json.dumps(content).encode("utf8"),
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ return channel
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 5bafad9f19..5059ade850 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -603,7 +603,7 @@ class TestStateResolutionStore(object):
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
- def get_auth_chain(self, event_ids):
+ def get_auth_chain(self, event_ids, ignore_events):
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -617,6 +617,8 @@ class TestStateResolutionStore(object):
Args:
event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events.
+ ignore_events: Set of events to exclude from the returned auth
+ chain.
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
@@ -627,7 +629,7 @@ class TestStateResolutionStore(object):
stack = list(event_ids)
while stack:
event_id = stack.pop()
- if event_id in result:
+ if event_id in result or event_id in ignore_events:
continue
result.add(event_id)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index d491ea2924..e37260a820 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -373,7 +373,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
self.assertEqual(
set(self._dump_to_tuple(res)),
- set([(1, "user1", "hello"), (2, "user2", "there")]),
+ {(1, "user1", "hello"), (2, "user2", "there")},
)
# Update only user2
@@ -400,5 +400,5 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
self.assertEqual(
set(self._dump_to_tuple(res)),
- set([(1, "user1", "hello"), (2, "user2", "bleb")]),
+ {(1, "user1", "hello"), (2, "user2", "bleb")},
)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index fd52512696..31710949a8 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -69,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
pass
def _add_appservice(self, as_token, id, url, hs_token, sender):
- as_yaml = dict(
- url=url,
- as_token=as_token,
- hs_token=hs_token,
- id=id,
- sender_localpart=sender,
- namespaces={},
- )
+ as_yaml = {
+ "url": url,
+ "as_token": as_token,
+ "hs_token": hs_token,
+ "id": id,
+ "sender_localpart": sender,
+ "namespaces": {},
+ }
# use the token as the filename
with open(as_token, "w") as outfile:
outfile.write(yaml.dump(as_yaml))
@@ -135,14 +135,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
def _add_service(self, url, as_token, id):
- as_yaml = dict(
- url=url,
- as_token=as_token,
- hs_token="something",
- id=id,
- sender_localpart="a_sender",
- namespaces={},
- )
+ as_yaml = {
+ "url": url,
+ "as_token": as_token,
+ "hs_token": "something",
+ "id": id,
+ "sender_localpart": "a_sender",
+ "namespaces": {},
+ }
# use the token as the filename
with open(as_token, "w") as outfile:
outfile.write(yaml.dump(as_yaml))
@@ -384,8 +384,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
self.assertEquals(2, len(services))
self.assertEquals(
- set([self.as_list[2]["id"], self.as_list[0]["id"]]),
- set([services[0].id, services[1].id]),
+ {self.as_list[2]["id"], self.as_list[0]["id"]},
+ {services[0].id, services[1].id},
)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 029ac26454..0e04b2cf92 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -134,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+ self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -172,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+ self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -227,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(
- set(latest_event_ids), set((event_id_a, event_id_b, event_id_c))
- )
+ self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -237,7 +235,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))
+ self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index f26ff57a18..a7b7fd36d3 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -59,24 +59,22 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
)
)
- expected = set(
- [
- b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
- b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
- b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
- b'synapse_forward_extremities_bucket{le="5.0"} 2.0',
- b'synapse_forward_extremities_bucket{le="7.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="10.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="15.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="20.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="50.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
- b"synapse_forward_extremities_count 3.0",
- b"synapse_forward_extremities_sum 10.0",
- ]
- )
+ expected = {
+ b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
+ b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="5.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="7.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="10.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="15.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="20.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="50.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
+ b"synapse_forward_extremities_count 3.0",
+ b"synapse_forward_extremities_sum 10.0",
+ }
self.assertEqual(items, expected)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..7458a37e54 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,9 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
@@ -43,10 +41,8 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.localpart, "http://my.site/here", 1
)
self.assertEquals(
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 04d58fbf24..0b88308ff4 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -394,7 +394,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, False)
- self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
+ self.assertEqual(known_absent, {(e1.type, e1.state_key)})
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################
diff --git a/tests/test_state.py b/tests/test_state.py
index d1578fe581..66f22f6813 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -254,9 +254,7 @@ class StateTestCase(unittest.TestCase):
ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids()
- self.assertSetEqual(
- {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
- )
+ self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
@@ -313,9 +311,7 @@ class StateTestCase(unittest.TestCase):
ctx_e = context_store["E"]
prev_state_ids = yield ctx_e.get_prev_state_ids()
- self.assertSetEqual(
- {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
- )
+ self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -388,9 +384,7 @@ class StateTestCase(unittest.TestCase):
ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids()
- self.assertSetEqual(
- {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
- )
+ self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
@@ -482,7 +476,7 @@ class StateTestCase(unittest.TestCase):
current_state_ids = yield context.get_current_state_ids()
self.assertEqual(
- set([e.event_id for e in old_state]), set(current_state_ids.values())
+ {e.event_id for e in old_state}, set(current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@@ -513,9 +507,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context.get_prev_state_ids()
- self.assertEqual(
- set([e.event_id for e in old_state]), set(prev_state_ids.values())
- )
+ self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
self.assertIsNotNone(context.state_group)
diff --git a/tests/test_types.py b/tests/test_types.py
index 8d97c751ea..7390a1ce62 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ GroupID,
+ RoomAlias,
+ UserID,
+ map_username_to_mxid_localpart,
+ strip_invalid_mxid_characters,
+)
from tests import unittest
@@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase):
self.assertEqual(
map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
)
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+ def test_return_type(self):
+ unstripped = strip_invalid_mxid_characters("test")
+ stripped = strip_invalid_mxid_characters("test@")
+
+ self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+ self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+ def test_strip(self):
+ stripped = strip_invalid_mxid_characters("test@")
+ self.assertEqual(stripped, "test", stripped)
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index f2be63706b..72a9de5370 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -67,7 +67,7 @@ class StreamChangeCacheTests(unittest.TestCase):
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
self.assertEqual(
- set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
+ {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
def test_get_all_entities_changed(self):
@@ -137,7 +137,7 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.get_entities_changed(
["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
),
- set(["bar@baz.net", "user@elsewhere.org"]),
+ {"bar@baz.net", "user@elsewhere.org"},
)
# Query all the entries mid-way through the stream, but include one
@@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=2,
),
- set(["bar@baz.net", "user@elsewhere.org"]),
+ {"bar@baz.net", "user@elsewhere.org"},
)
# Query all the entries, but before the first known point. We will get
@@ -168,21 +168,13 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=0,
),
- set(
- [
- "user@foo.com",
- "bar@baz.net",
- "user@elsewhere.org",
- "not@here.website",
- ]
- ),
+ {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"},
)
# Query a subset of the entries mid-way through the stream. We should
# only get back the subset.
self.assertEqual(
- cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
- set(["bar@baz.net"]),
+ cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"},
)
def test_max_pos(self):
diff --git a/tox.ini b/tox.ini
index b9132a3177..00e095ef80 100644
--- a/tox.ini
+++ b/tox.ini
@@ -123,6 +123,7 @@ skip_install = True
basepython = python3.6
deps =
flake8
+ flake8-comprehensions
black==19.10b0 # We pin so that our tests don't start failing on new releases of black.
commands =
python -m black --check --diff .
@@ -138,7 +139,7 @@ commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev
skip_install = True
deps = towncrier>=18.6.0rc1
commands =
- python -m towncrier.check --compare-with=origin/develop
+ python -m towncrier.check --compare-with=origin/dinsic
basepython = python3.6
[testenv:check-sampleconfig]
@@ -167,7 +168,6 @@ commands=
coverage html
[testenv:mypy]
-basepython = python3.7
skip_install = True
deps =
{[base]deps}
@@ -178,10 +178,14 @@ env =
extras = all
commands = mypy \
synapse/api \
- synapse/config/ \
+ synapse/appservice \
+ synapse/config \
synapse/events/spamcheck.py \
+ synapse/federation/federation_base.py \
+ synapse/federation/federation_client.py \
synapse/federation/sender \
synapse/federation/transport \
+ synapse/handlers/presence.py \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \
synapse/logging/ \
@@ -190,6 +194,7 @@ commands = mypy \
synapse/rest \
synapse/spam_checker_api \
synapse/storage/engines \
+ synapse/storage/database.py \
synapse/streams
# To find all folders that pass mypy you run:
|