diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 7f42fad909..ee76954185 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -47,7 +47,6 @@ steps:
- wait
-
- command:
- "python -m pip install tox"
- "tox -e py35-old,codecov"
@@ -117,8 +116,10 @@ steps:
limit: 2
- label: ":python: 3.5 / :postgres: 9.5"
+ agents:
+ queue: "medium"
env:
- TRIAL_FLAGS: "-j 4"
+ TRIAL_FLAGS: "-j 8"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,codecov'"
plugins:
@@ -134,8 +135,10 @@ steps:
limit: 2
- label: ":python: 3.7 / :postgres: 9.5"
+ agents:
+ queue: "medium"
env:
- TRIAL_FLAGS: "-j 4"
+ TRIAL_FLAGS: "-j 8"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,codecov'"
plugins:
@@ -151,8 +154,10 @@ steps:
limit: 2
- label: ":python: 3.7 / :postgres: 11"
+ agents:
+ queue: "medium"
env:
- TRIAL_FLAGS: "-j 4"
+ TRIAL_FLAGS: "-j 8"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,codecov'"
plugins:
@@ -189,7 +194,7 @@ steps:
- label: "SyTest - :python: 3.5 / :postgres: 9.6 / Monolith"
agents:
- queue: "medium"
+ queue: "xlarge"
env:
POSTGRES: "1"
command:
@@ -197,7 +202,7 @@ steps:
- "bash /synapse_sytest.sh"
plugins:
- docker#v3.0.1:
- image: "matrixdotorg/sytest-synapse:py35"
+ image: "matrixdotorg/sytest-synapse:dinsic-py3"
propagate-environment: true
always-pull: true
workdir: "/src"
@@ -208,22 +213,23 @@ steps:
- exit_status: 2
limit: 2
- - label: "SyTest - :python: 3.5 / :postgres: 9.6 / Workers"
+ - label: "SyTest - :python: 3 / :postgres: 9.6 / Workers"
agents:
queue: "medium"
env:
POSTGRES: "1"
WORKERS: "1"
+ BLACKLIST: "synapse-blacklist-with-workers"
command:
- "bash .buildkite/merge_base_branch.sh"
+ - "bash -c 'cat /src/sytest-blacklist /src/.buildkite/worker-blacklist > /src/synapse-blacklist-with-workers'"
- "bash /synapse_sytest.sh"
plugins:
- docker#v3.0.1:
- image: "matrixdotorg/sytest-synapse:py35"
+ image: "matrixdotorg/sytest-synapse:dinsic-py3"
propagate-environment: true
always-pull: true
workdir: "/src"
- soft_fail: true
retry:
automatic:
- exit_status: -1
diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist
new file mode 100644
index 0000000000..8ed8eef1a3
--- /dev/null
+++ b/.buildkite/worker-blacklist
@@ -0,0 +1,34 @@
+# This file serves as a blacklist for SyTest tests that we expect will fail in
+# Synapse when run under worker mode. For more details, see sytest-blacklist.
+
+Message history can be paginated
+
+m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users
+
+m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users
+
+Can re-join room if re-invited
+
+/upgrade creates a new room
+
+The only membership state included in an initial sync is for all the senders in the timeline
+
+Local device key changes get to remote servers
+
+If remote user leaves room we no longer receive device updates
+
+Forgotten room messages cannot be paginated
+
+Inbound federation can get public room list
+
+Members from the gap are included in gappy incr LL sync
+
+Leaves are present in non-gapped incremental syncs
+
+Old leaves are present in gapped incremental syncs
+
+User sees updates to presence from other users in the incremental sync.
+
+Gapped incremental syncs include all state changes
+
+Old members are included in gappy incr LL sync if they start speaking
diff --git a/MANIFEST.in b/MANIFEST.in
index 834ddfad39..8855e74d08 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,5 @@
include synctl
+include sytest-blacklist
include LICENSE
include VERSION
include *.rst
@@ -7,7 +8,6 @@ include demo/README
include demo/demo.tls.dh
include demo/*.py
include demo/*.sh
-include sytest-blacklist
recursive-include synapse/storage/schema *.sql
recursive-include synapse/storage/schema *.sql.postgres
@@ -34,6 +34,7 @@ exclude Dockerfile
exclude .dockerignore
exclude test_postgresql.sh
exclude .editorconfig
+exclude sytest-blacklist
include pyproject.toml
recursive-include changelog.d *
@@ -49,3 +50,8 @@ prune .buildkite
exclude jenkins*
recursive-exclude jenkins *.sh
+
+# FIXME: we shouldn't have these templates here
+recursive-include res/templates-dinsic *.css
+recursive-include res/templates-dinsic *.html
+recursive-include res/templates-dinsic *.txt
diff --git a/changelog.d/1.feature b/changelog.d/1.feature
new file mode 100644
index 0000000000..845642e445
--- /dev/null
+++ b/changelog.d/1.feature
@@ -0,0 +1 @@
+Forbid changing the name, avatar or topic of a direct room.
diff --git a/changelog.d/10.bugfix b/changelog.d/10.bugfix
new file mode 100644
index 0000000000..51f89f46dd
--- /dev/null
+++ b/changelog.d/10.bugfix
@@ -0,0 +1 @@
+Don't apply retention policy based filtering on state events.
diff --git a/changelog.d/11.feature b/changelog.d/11.feature
new file mode 100644
index 0000000000..362e4b1efd
--- /dev/null
+++ b/changelog.d/11.feature
@@ -0,0 +1 @@
+Allow server admins to configure custom global rate-limiting for third party invites.
\ No newline at end of file
diff --git a/changelog.d/12.feature b/changelog.d/12.feature
new file mode 100644
index 0000000000..8e6e7a28af
--- /dev/null
+++ b/changelog.d/12.feature
@@ -0,0 +1 @@
+Add a `/user/:user_id/info` CS servlet to give a user's deactivated/expired information.
\ No newline at end of file
diff --git a/changelog.d/13.feature b/changelog.d/13.feature
new file mode 100644
index 0000000000..c2d2e93abf
--- /dev/null
+++ b/changelog.d/13.feature
@@ -0,0 +1 @@
+Hide expired users from the user directory, and optionally re-add them on renewal.
\ No newline at end of file
diff --git a/changelog.d/14.feature b/changelog.d/14.feature
new file mode 100644
index 0000000000..020d0bac1e
--- /dev/null
+++ b/changelog.d/14.feature
@@ -0,0 +1 @@
+User display names now have capitalised letters after - symbols.
\ No newline at end of file
diff --git a/changelog.d/15.misc b/changelog.d/15.misc
new file mode 100644
index 0000000000..4cc4a5175f
--- /dev/null
+++ b/changelog.d/15.misc
@@ -0,0 +1 @@
+Fix the ordering on `scripts/generate_signing_key.py`'s import statement.
diff --git a/changelog.d/17.misc b/changelog.d/17.misc
new file mode 100644
index 0000000000..58120ab5c7
--- /dev/null
+++ b/changelog.d/17.misc
@@ -0,0 +1 @@
+Blacklist some flaky sytests until they're fixed.
\ No newline at end of file
diff --git a/changelog.d/18.feature b/changelog.d/18.feature
new file mode 100644
index 0000000000..f5aa29a6e8
--- /dev/null
+++ b/changelog.d/18.feature
@@ -0,0 +1 @@
+Add option `limit_profile_requests_to_known_users` to require that a user share a room with another user in order to query their profile information.
\ No newline at end of file
diff --git a/changelog.d/19.feature b/changelog.d/19.feature
new file mode 100644
index 0000000000..95a44a4a89
--- /dev/null
+++ b/changelog.d/19.feature
@@ -0,0 +1 @@
+Add `max_avatar_size` and `allowed_avatar_mimetypes` to restrict the size of user avatars and their file type respectively.
\ No newline at end of file
diff --git a/changelog.d/2.bugfix b/changelog.d/2.bugfix
new file mode 100644
index 0000000000..4fe5691468
--- /dev/null
+++ b/changelog.d/2.bugfix
@@ -0,0 +1 @@
+Don't treat 3PID revocation as a new 3PID invite.
diff --git a/changelog.d/20.bugfix b/changelog.d/20.bugfix
new file mode 100644
index 0000000000..8ba53c28f9
--- /dev/null
+++ b/changelog.d/20.bugfix
@@ -0,0 +1 @@
+Validate `client_secret` parameter against the regex provided by the C-S spec.
\ No newline at end of file
diff --git a/changelog.d/21.bugfix b/changelog.d/21.bugfix
new file mode 100644
index 0000000000..630d7812f7
--- /dev/null
+++ b/changelog.d/21.bugfix
@@ -0,0 +1 @@
+Fix resetting user passwords via a phone number.
diff --git a/changelog.d/3.bugfix b/changelog.d/3.bugfix
new file mode 100644
index 0000000000..cc4bcefa80
--- /dev/null
+++ b/changelog.d/3.bugfix
@@ -0,0 +1 @@
+Fix encoding on password reset HTML responses in Python 2.
diff --git a/changelog.d/4.bugfix b/changelog.d/4.bugfix
new file mode 100644
index 0000000000..fe717920a6
--- /dev/null
+++ b/changelog.d/4.bugfix
@@ -0,0 +1 @@
+Fix handling of filtered strings in Python 3.
diff --git a/changelog.d/5.bugfix b/changelog.d/5.bugfix
new file mode 100644
index 0000000000..53f57f46ca
--- /dev/null
+++ b/changelog.d/5.bugfix
@@ -0,0 +1 @@
+Fix room retention policy management in worker mode.
diff --git a/changelog.d/5083.feature b/changelog.d/5083.feature
new file mode 100644
index 0000000000..2ffdd37eef
--- /dev/null
+++ b/changelog.d/5083.feature
@@ -0,0 +1 @@
+Add the auth_profile_reqs option to require an access_token for GET /profile endpoints on the CS API.
diff --git a/changelog.d/5098.misc b/changelog.d/5098.misc
new file mode 100644
index 0000000000..9cd83bf226
--- /dev/null
+++ b/changelog.d/5098.misc
@@ -0,0 +1 @@
+Add workarounds for pep-517 install errors.
diff --git a/changelog.d/5214.feature b/changelog.d/5214.feature
new file mode 100644
index 0000000000..6c0f15c901
--- /dev/null
+++ b/changelog.d/5214.feature
@@ -0,0 +1 @@
+Allow server admins to define and enforce a password policy (MSC2000).
diff --git a/changelog.d/5416.misc b/changelog.d/5416.misc
new file mode 100644
index 0000000000..155e8c7cd3
--- /dev/null
+++ b/changelog.d/5416.misc
@@ -0,0 +1 @@
+Add unique index to the profile_replication_status table.
diff --git a/changelog.d/5420.feature b/changelog.d/5420.feature
new file mode 100644
index 0000000000..745864b903
--- /dev/null
+++ b/changelog.d/5420.feature
@@ -0,0 +1 @@
+Add configuration option to hide new users from the user directory.
diff --git a/changelog.d/5610.feature b/changelog.d/5610.feature
new file mode 100644
index 0000000000..b99514f97e
--- /dev/null
+++ b/changelog.d/5610.feature
@@ -0,0 +1 @@
+Implement new custom event rules for power levels.
diff --git a/changelog.d/5678.removal b/changelog.d/5678.removal
new file mode 100644
index 0000000000..085b84fda6
--- /dev/null
+++ b/changelog.d/5678.removal
@@ -0,0 +1 @@
+Synapse now no longer accepts the `-v`/`--verbose`, `-f`/`--log-file`, or `--log-config` command line flags, and removes the deprecated `verbose` and `log_file` configuration file options. Users of these options should migrate their options into the dedicated log configuration.
diff --git a/changelog.d/5694.misc b/changelog.d/5694.misc
new file mode 100644
index 0000000000..3b12dcc849
--- /dev/null
+++ b/changelog.d/5694.misc
@@ -0,0 +1 @@
+Make Jaeger fully configurable.
diff --git a/changelog.d/5695.misc b/changelog.d/5695.misc
new file mode 100644
index 0000000000..4741d32e25
--- /dev/null
+++ b/changelog.d/5695.misc
@@ -0,0 +1 @@
+Add precautionary measures to prevent future abuse of `window.opener` in default welcome page.
diff --git a/changelog.d/5702.bugfix b/changelog.d/5702.bugfix
new file mode 100644
index 0000000000..43b6e39b13
--- /dev/null
+++ b/changelog.d/5702.bugfix
@@ -0,0 +1 @@
+Fix 3PID invite to invite association detection in the Tchap room access rules.
diff --git a/changelog.d/5706.misc b/changelog.d/5706.misc
new file mode 100644
index 0000000000..5e15dfd5fa
--- /dev/null
+++ b/changelog.d/5706.misc
@@ -0,0 +1 @@
+Reduce database IO usage by optimising queries for current membership.
diff --git a/changelog.d/5713.misc b/changelog.d/5713.misc
new file mode 100644
index 0000000000..01ea1cf8d7
--- /dev/null
+++ b/changelog.d/5713.misc
@@ -0,0 +1 @@
+Improve caching when fetching `get_filtered_current_state_ids`.
diff --git a/changelog.d/5715.misc b/changelog.d/5715.misc
new file mode 100644
index 0000000000..a77366e0c0
--- /dev/null
+++ b/changelog.d/5715.misc
@@ -0,0 +1 @@
+Don't accept opentracing data from clients.
diff --git a/changelog.d/5717.misc b/changelog.d/5717.misc
new file mode 100644
index 0000000000..07dc3bca94
--- /dev/null
+++ b/changelog.d/5717.misc
@@ -0,0 +1 @@
+Speed up PostgreSQL unit tests in CI.
diff --git a/changelog.d/5719.misc b/changelog.d/5719.misc
new file mode 100644
index 0000000000..6d5294724c
--- /dev/null
+++ b/changelog.d/5719.misc
@@ -0,0 +1 @@
+Update the coding style document.
diff --git a/changelog.d/5720.misc b/changelog.d/5720.misc
new file mode 100644
index 0000000000..590f64f19d
--- /dev/null
+++ b/changelog.d/5720.misc
@@ -0,0 +1 @@
+Improve database query performance when recording retry intervals for remote hosts.
diff --git a/changelog.d/5722.misc b/changelog.d/5722.misc
new file mode 100644
index 0000000000..f2d236188d
--- /dev/null
+++ b/changelog.d/5722.misc
@@ -0,0 +1 @@
+Add a set of opentracing utils.
diff --git a/changelog.d/5724.bugfix b/changelog.d/5724.bugfix
new file mode 100644
index 0000000000..1b3683daf6
--- /dev/null
+++ b/changelog.d/5724.bugfix
@@ -0,0 +1 @@
+Fix stack overflow in server key lookup code.
\ No newline at end of file
diff --git a/changelog.d/5725.bugfix b/changelog.d/5725.bugfix
new file mode 100644
index 0000000000..73ef419727
--- /dev/null
+++ b/changelog.d/5725.bugfix
@@ -0,0 +1 @@
+start.sh no longer uses a deprecated CLI option.
diff --git a/changelog.d/5729.removal b/changelog.d/5729.removal
new file mode 100644
index 0000000000..3af5198e6b
--- /dev/null
+++ b/changelog.d/5729.removal
@@ -0,0 +1 @@
+Synapse now no longer accepts the `-v`/`--verbose`, `-f`/`--log-file`, or `--log-config` command line flags, and removes the deprecated `verbose` and `log_file` configuration file options. Users of these options should migrate their options into the dedicated log configuration.
diff --git a/changelog.d/5730.misc b/changelog.d/5730.misc
new file mode 100644
index 0000000000..a99677f5e7
--- /dev/null
+++ b/changelog.d/5730.misc
@@ -0,0 +1 @@
+Cache result of get_version_string to reduce overhead of `/version` federation requests.
diff --git a/changelog.d/5731.misc b/changelog.d/5731.misc
new file mode 100644
index 0000000000..dffae5d874
--- /dev/null
+++ b/changelog.d/5731.misc
@@ -0,0 +1 @@
+Return 'user_type' in admin API user endpoint results.
diff --git a/changelog.d/5732.feature b/changelog.d/5732.feature
new file mode 100644
index 0000000000..9021864350
--- /dev/null
+++ b/changelog.d/5732.feature
@@ -0,0 +1 @@
+Add sd_notify hooks to ease systemd integration and allow usage of Type=Notify.
diff --git a/changelog.d/5733.misc b/changelog.d/5733.misc
new file mode 100644
index 0000000000..a2a8c26383
--- /dev/null
+++ b/changelog.d/5733.misc
@@ -0,0 +1 @@
+Don't package the sytest test blacklist file.
diff --git a/changelog.d/5736.misc b/changelog.d/5736.misc
new file mode 100644
index 0000000000..5713b8b32d
--- /dev/null
+++ b/changelog.d/5736.misc
@@ -0,0 +1 @@
+Replace uses of returnValue with plain return, as returnValue is not needed on Python 3.
diff --git a/changelog.d/5738.misc b/changelog.d/5738.misc
new file mode 100644
index 0000000000..5e15dfd5fa
--- /dev/null
+++ b/changelog.d/5738.misc
@@ -0,0 +1 @@
+Reduce database IO usage by optimising queries for current membership.
diff --git a/changelog.d/5740.misc b/changelog.d/5740.misc
new file mode 100644
index 0000000000..97a476bef5
--- /dev/null
+++ b/changelog.d/5740.misc
@@ -0,0 +1 @@
+Blacklist some flaky tests in worker mode.
diff --git a/changelog.d/5743.bugfix b/changelog.d/5743.bugfix
new file mode 100644
index 0000000000..65728ff079
--- /dev/null
+++ b/changelog.d/5743.bugfix
@@ -0,0 +1 @@
+Log when we receive an event receipt from an unexpected origin.
diff --git a/changelog.d/5749.misc b/changelog.d/5749.misc
new file mode 100644
index 0000000000..48dd61f461
--- /dev/null
+++ b/changelog.d/5749.misc
@@ -0,0 +1 @@
+Fix some error cases in the caching layer.
diff --git a/changelog.d/5750.misc b/changelog.d/5750.misc
new file mode 100644
index 0000000000..6beaa460a5
--- /dev/null
+++ b/changelog.d/5750.misc
@@ -0,0 +1 @@
+Add a prometheus metric for pending cache lookups.
\ No newline at end of file
diff --git a/changelog.d/5753.misc b/changelog.d/5753.misc
new file mode 100644
index 0000000000..22bba9ce3c
--- /dev/null
+++ b/changelog.d/5753.misc
@@ -0,0 +1 @@
+Stop trying to fetch events with event_id=None.
diff --git a/changelog.d/5760.feature b/changelog.d/5760.feature
new file mode 100644
index 0000000000..90302d793e
--- /dev/null
+++ b/changelog.d/5760.feature
@@ -0,0 +1 @@
+Force the access rule to be "restricted" if the join rule is "public".
diff --git a/changelog.d/5768.misc b/changelog.d/5768.misc
new file mode 100644
index 0000000000..7a9c88b4c2
--- /dev/null
+++ b/changelog.d/5768.misc
@@ -0,0 +1 @@
+Convert RedactionTestCase to modern test style.
diff --git a/changelog.d/5780.misc b/changelog.d/5780.misc
new file mode 100644
index 0000000000..b7eb56e625
--- /dev/null
+++ b/changelog.d/5780.misc
@@ -0,0 +1 @@
+Allow looping calls to be given arguments.
diff --git a/changelog.d/5807.feature b/changelog.d/5807.feature
new file mode 100644
index 0000000000..8b7d29a23c
--- /dev/null
+++ b/changelog.d/5807.feature
@@ -0,0 +1 @@
+Allow defining HTML templates to serve the user on account renewal attempt when using the account validity feature.
diff --git a/changelog.d/5815.feature b/changelog.d/5815.feature
new file mode 100644
index 0000000000..ca4df4e7f6
--- /dev/null
+++ b/changelog.d/5815.feature
@@ -0,0 +1 @@
+Implement per-room message retention policies.
diff --git a/changelog.d/6.bugfix b/changelog.d/6.bugfix
new file mode 100644
index 0000000000..43ab65cc95
--- /dev/null
+++ b/changelog.d/6.bugfix
@@ -0,0 +1 @@
+Don't forbid membership events whose membership isn't 'join' or 'invite' in restricted rooms, so that users who got into these rooms before the access rules started being enforced can leave them.
diff --git a/changelog.d/6125.feature b/changelog.d/6125.feature
new file mode 100644
index 0000000000..cbe5f8d3c8
--- /dev/null
+++ b/changelog.d/6125.feature
@@ -0,0 +1 @@
+Reject all pending invites for a user during deactivation.
diff --git a/changelog.d/6147.bugfix b/changelog.d/6147.bugfix
new file mode 100644
index 0000000000..b0f936d280
--- /dev/null
+++ b/changelog.d/6147.bugfix
@@ -0,0 +1 @@
+Don't 500 when trying to exchange a revoked 3PID invite.
diff --git a/changelog.d/6238.feature b/changelog.d/6238.feature
new file mode 100644
index 0000000000..d225ac33b6
--- /dev/null
+++ b/changelog.d/6238.feature
@@ -0,0 +1 @@
+Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars.
diff --git a/changelog.d/6436.bugfix b/changelog.d/6436.bugfix
new file mode 100644
index 0000000000..954a4e1d84
--- /dev/null
+++ b/changelog.d/6436.bugfix
@@ -0,0 +1 @@
+Fix a bug where a room could become unusable with a low retention policy and low activity.
diff --git a/changelog.d/9.misc b/changelog.d/9.misc
new file mode 100644
index 0000000000..24fd12c978
--- /dev/null
+++ b/changelog.d/9.misc
@@ -0,0 +1 @@
+Add SyTest to the BuildKite CI.
diff --git a/contrib/systemd-with-workers/system/matrix-synapse-worker@.service b/contrib/systemd-with-workers/system/matrix-synapse-worker@.service
index 9d980d5168..3507e2e989 100644
--- a/contrib/systemd-with-workers/system/matrix-synapse-worker@.service
+++ b/contrib/systemd-with-workers/system/matrix-synapse-worker@.service
@@ -4,7 +4,8 @@ After=matrix-synapse.service
BindsTo=matrix-synapse.service
[Service]
-Type=simple
+Type=notify
+NotifyAccess=main
User=matrix-synapse
WorkingDirectory=/var/lib/matrix-synapse
EnvironmentFile=/etc/default/matrix-synapse
diff --git a/contrib/systemd-with-workers/system/matrix-synapse.service b/contrib/systemd-with-workers/system/matrix-synapse.service
index 3aae19034c..68e8991f18 100644
--- a/contrib/systemd-with-workers/system/matrix-synapse.service
+++ b/contrib/systemd-with-workers/system/matrix-synapse.service
@@ -2,7 +2,8 @@
Description=Synapse Matrix Homeserver
[Service]
-Type=simple
+Type=notify
+NotifyAccess=main
User=matrix-synapse
WorkingDirectory=/var/lib/matrix-synapse
EnvironmentFile=/etc/default/matrix-synapse
diff --git a/contrib/systemd/matrix-synapse.service b/contrib/systemd/matrix-synapse.service
index 595b69916c..38d369ea3d 100644
--- a/contrib/systemd/matrix-synapse.service
+++ b/contrib/systemd/matrix-synapse.service
@@ -14,7 +14,9 @@
Description=Synapse Matrix homeserver
[Service]
-Type=simple
+Type=notify
+NotifyAccess=main
+ExecReload=/bin/kill -HUP $MAINPID
Restart=on-abort
User=synapse
diff --git a/demo/start.sh b/demo/start.sh
index 1c4f12d0bb..eccaa2abeb 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -29,7 +29,7 @@ for port in 8080 8081 8082; do
if ! grep -F "Customisation made by demo/start.sh" -q $DIR/etc/$port.config; then
printf '\n\n# Customisation made by demo/start.sh\n' >> $DIR/etc/$port.config
-
+
echo 'enable_registration: true' >> $DIR/etc/$port.config
# Warning, this heredoc depends on the interaction of tabs and spaces. Please don't
@@ -43,7 +43,7 @@ for port in 8080 8081 8082; do
tls: true
resources:
- names: [client, federation]
-
+
- port: $port
tls: false
bind_addresses: ['::1', '127.0.0.1']
@@ -68,7 +68,7 @@ for port in 8080 8081 8082; do
# Generate tls keys
openssl req -x509 -newkey rsa:4096 -keyout $DIR/etc/localhost\:$https_port.tls.key -out $DIR/etc/localhost\:$https_port.tls.crt -days 365 -nodes -subj "/O=matrix"
-
+
# Ignore keys from the trusted keys server
echo '# Ignore keys from the trusted keys server' >> $DIR/etc/$port.config
echo 'trusted_key_servers:' >> $DIR/etc/$port.config
@@ -120,7 +120,6 @@ for port in 8080 8081 8082; do
python3 -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \
-D \
- -vv \
popd
done
diff --git a/docs/code_style.rst b/docs/code_style.rst
index e3ca626bfd..39ac4ebedc 100644
--- a/docs/code_style.rst
+++ b/docs/code_style.rst
@@ -1,4 +1,8 @@
-# Code Style
+Code Style
+==========
+
+Formatting tools
+----------------
The Synapse codebase uses a number of code formatting tools in order to
quickly and automatically check for formatting (and sometimes logical) errors
@@ -6,20 +10,20 @@ in code.
The necessary tools are detailed below.
-## Formatting tools
+- **black**
-The Synapse codebase uses [black](https://pypi.org/project/black/) as an
-opinionated code formatter, ensuring all comitted code is properly
-formatted.
+ The Synapse codebase uses `black <https://pypi.org/project/black/>`_ as an
+ opinionated code formatter, ensuring all committed code is properly
+ formatted.
-First install ``black`` with::
+ First install ``black`` with::
- pip install --upgrade black
+ pip install --upgrade black
-Have ``black`` auto-format your code (it shouldn't change any
-functionality) with::
+ Have ``black`` auto-format your code (it shouldn't change any functionality)
+ with::
- black . --exclude="\.tox|build|env"
+ black . --exclude="\.tox|build|env"
- **flake8**
@@ -54,17 +58,16 @@ functionality is supported in your editor for a more convenient development
workflow. It is not, however, recommended to run ``flake8`` on save as it
takes a while and is very resource intensive.
-## General rules
+General rules
+-------------
- **Naming**:
- Use camel case for class and type names
- Use underscores for functions and variables.
-- Use double quotes ``"foo"`` rather than single quotes ``'foo'``.
-
-- **Comments**: should follow the `google code style
- <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
+- **Docstrings**: should follow the `google code style
+ <https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings>`_.
This is so that we can generate documentation with `sphinx
<http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
`examples
@@ -73,6 +76,8 @@ takes a while and is very resource intensive.
- **Imports**:
+ - Imports should be sorted by ``isort`` as described above.
+
- Prefer to import classes and functions rather than packages or modules.
Example::
@@ -92,25 +97,84 @@ takes a while and is very resource intensive.
This goes against the advice in the Google style guide, but it means that
errors in the name are caught early (at import time).
- - Multiple imports from the same package can be combined onto one line::
+ - Avoid wildcard imports (``from synapse.types import *``) and relative
+ imports (``from .types import UserID``).
- from synapse.types import GroupID, RoomID, UserID
+Configuration file format
+-------------------------
- An effort should be made to keep the individual imports in alphabetical
- order.
+The `sample configuration file <./sample_config.yaml>`_ acts as a reference to
+Synapse's configuration options for server administrators. Remember that many
+readers will be unfamiliar with YAML and server administration in general, so
+it is important that the file be as easy to understand as possible, which
+includes following a consistent format.
- If the list becomes long, wrap it with parentheses and split it over
- multiple lines.
+Some guidelines follow:
- - As per `PEP-8 <https://www.python.org/dev/peps/pep-0008/#imports>`_,
- imports should be grouped in the following order, with a blank line between
- each group:
+* Sections should be separated with a heading consisting of a single line
+ prefixed and suffixed with ``##``. There should be **two** blank lines
+ before the section header, and **one** after.
- 1. standard library imports
- 2. related third party imports
- 3. local application/library specific imports
+* Each option should be listed in the file with the following format:
- - Imports within each group should be sorted alphabetically by module name.
+ * A comment describing the setting. Each line of this comment should be
+ prefixed with a hash (``#``) and a space.
- - Avoid wildcard imports (``from synapse.types import *``) and relative
- imports (``from .types import UserID``).
+ The comment should describe the default behaviour (ie, what happens if
+ the setting is omitted), as well as what the effect will be if the
+ setting is changed.
+
+ Often, the comment ends with something like "uncomment the
+ following to \<do action>".
+
+ * A line consisting of only ``#``.
+
+ * A commented-out example setting, prefixed with only ``#``.
+
+ For boolean (on/off) options, convention is that this example should be
+ the *opposite* to the default (so the comment will end with "Uncomment
+ the following to enable [or disable] \<feature\>.") For other options,
+ the example should give some non-default value which is likely to be
+ useful to the reader.
+
+* There should be a blank line between each option.
+
+* Where several settings are grouped into a single dict, *avoid* the
+ convention where the whole block is commented out, resulting in comment
+ lines starting ``# #``, as this is hard to read and confusing to
+ edit. Instead, leave the top-level config option uncommented, and follow
+ the conventions above for sub-options. Ensure that your code correctly
+ handles the top-level option being set to ``None`` (as it will be if no
+ sub-options are enabled).
+
+* Lines should be wrapped at 80 characters.
+
+Example::
+
+ ## Frobnication ##
+
+ # The frobnicator will ensure that all requests are fully frobnicated.
+ # To enable it, uncomment the following.
+ #
+ #frobnicator_enabled: true
+
+ # By default, the frobnicator will frobnicate with the default frobber.
+ # The following will make it use an alternative frobber.
+ #
+ #frobnicator_frobber: special_frobber
+
+ # Settings for the frobber
+ #
+ frobber:
+ # frobbing speed. Defaults to 1.
+ #
+ #speed: 10
+
+ # frobbing distance. Defaults to 1000.
+ #
+ #distance: 100
+
+Note that the sample configuration is generated from the synapse code and is
+maintained by a script, ``scripts-dev/generate_sample_config``. Making sure
+that the output from this script matches the desired format is left as an
+exercise for the reader!
diff --git a/docs/log_contexts.rst b/docs/log_contexts.rst
index f5cd5de8ab..4502cd9454 100644
--- a/docs/log_contexts.rst
+++ b/docs/log_contexts.rst
@@ -148,7 +148,7 @@ call any other functions.
d = more_stuff()
result = yield d # also fine, of course
- defer.returnValue(result)
+ return result
def nonInlineCallbacksFun():
logger.debug("just a wrapper really")
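The `defer.returnValue` -> `return` conversions in this hunk, and throughout the synapse/ changes below, rely on Python 3 generators being able to return values directly, so `inlineCallbacks` functions no longer need the Python 2 workaround. A minimal, self-contained sketch of the pattern (the function name is illustrative):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def get_answer():
        value = yield defer.succeed(41)
        # On Python 2 a generator could not "return <value>", so Twisted
        # provided defer.returnValue(value + 1); on Python 3 a plain
        # return inside the generator does the same thing.
        return value + 1

    # Calling get_answer() returns a Deferred that fires with 42.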
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 0a96197ca6..972c212f3d 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -54,6 +54,13 @@ pid_file: DATADIR/homeserver.pid
#
#require_auth_for_profile_requests: true
+# Whether to require a user to share a room with another user in order
+# to retrieve their profile information. Only checked on Client-Server
+# requests. Profile requests from other servers should be checked by the
+# requesting server. Defaults to 'false'.
+#
+#limit_profile_requests_to_known_users: true
+
# If set to 'false', requires authentication to access the server's public rooms
# directory through the client API. Defaults to 'true'.
#
@@ -289,6 +296,74 @@ listeners:
#
#allow_per_room_profiles: false
+# Whether to show the users on this homeserver in the user directory. Defaults to
+# 'true'.
+#
+#show_users_in_user_directory: false
+
+# Message retention policy at the server level.
+#
+# Room admins and mods can define a retention period for their rooms using the
+# 'm.room.retention' state event, and server admins can cap this period by setting
+# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+#
+# If this feature is enabled, Synapse will regularly look for and purge events
+# which are older than the room's maximum retention period. Synapse will also
+# filter events received over federation so that events that should have been
+# purged are ignored and not stored again.
+#
+retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+# events whose lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+# whose 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+# of outdated messages on a very frequent basis (e.g. every 5min), without
+# having that purge performed by a job that iterates over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
## TLS ##
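The purge-job range semantics described above (a job matches rooms whose 'max_lifetime' is strictly greater than its 'shortest_max_lifetime' and lower than or equal to its 'longest_max_lifetime', with both bounds optional) can be expressed as a small predicate. This is an illustration of the documented behaviour, not Synapse's actual implementation, and it assumes lifetimes are already converted to milliseconds:

    def job_handles_room(max_lifetime_ms, shortest_ms=None, longest_ms=None):
        # Lower bound is exclusive: a '2d'..'3d' job handles lifetimes above 2 days.
        if shortest_ms is not None and max_lifetime_ms <= shortest_ms:
            return False
        # Upper bound is inclusive: lifetimes up to and including 3 days.
        if longest_ms is not None and max_lifetime_ms > longest_ms:
            return False
        return True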
@@ -493,6 +568,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+# - one that ratelimits third-party invite requests based on the account
+# that's making the requests.
#
# The defaults are as shown below.
#
@@ -514,6 +591,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# failed_attempts:
# per_second: 0.17
# burst_count: 3
+#
+#rc_third_party_invite:
+# per_second: 0.2
+# burst_count: 10
# Ratelimiting settings for incoming federation
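All of these `per_second`/`burst_count` pairs describe token-bucket style limits: up to `burst_count` actions may happen in a burst, with capacity refilling at `per_second`. A rough standalone sketch of that idea (illustrative only, not Synapse's actual `Ratelimiter`):

    import time

    class SimpleRatelimiter:
        def __init__(self, per_second, burst_count):
            self.rate = per_second
            self.burst = burst_count
            self.tokens = float(burst_count)
            self.last = time.monotonic()

        def allow(self):
            now = time.monotonic()
            # Refill for the elapsed time, capped at the burst size.
            self.tokens = min(self.burst, self.tokens + (now - self.last) * self.rate)
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False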
@@ -575,6 +656,30 @@ uploads_path: "DATADIR/uploads"
#
#max_upload_size: 10M
+# The largest allowed size for a user avatar. If not defined, no
+# restriction will be imposed.
+#
+# Note that this only applies when an avatar is changed globally.
+# Per-room avatar changes are not affected. See allow_per_room_profiles
+# for disabling that functionality.
+#
+# Note that user avatar changes will not work if this is set without
+# using Synapse's local media repo.
+#
+#max_avatar_size: 10M
+
+# Allowed mimetypes for a user avatar. If not defined, no restriction will
+# be imposed.
+#
+# Note that this only applies when an avatar is changed globally.
+# Per-room avatar changes are not affected. See allow_per_room_profiles
+# for disabling that functionality.
+#
+# Note that user avatar changes will not work if this is set without
+# using Synapse's local media repo.
+#
+#allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
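A hypothetical sketch of how the two avatar options above might be enforced when a user sets a global avatar; the `media_length` and `media_type` field names are assumptions for illustration, not the actual enforcement code:

    def avatar_allowed(media_info, max_avatar_size, allowed_avatar_mimetypes):
        # Reject avatars larger than the configured size limit, if any.
        if max_avatar_size and media_info["media_length"] > max_avatar_size:
            return False
        # Reject avatars whose mimetype is not in the allowed list, if any.
        if (
            allowed_avatar_mimetypes
            and media_info["media_type"] not in allowed_avatar_mimetypes
        ):
            return False
        return True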
@@ -785,6 +890,16 @@ uploads_path: "DATADIR/uploads"
# period: 6w
# renew_at: 1w
# renew_email_subject: "Renew your %(app)s account"
+# # Directory in which Synapse will try to find the HTML files to serve to the
+# # user when trying to renew an account. Optional, defaults to
+# # synapse/res/templates.
+# template_dir: "res/templates"
+# # HTML to be displayed to the user after they successfully renewed their
+# # account. Optional.
+# account_renewed_html_path: "account_renewed.html"
+# # HTML to be displayed when the user tries to renew an account with an invalid
+# # renewal token. Optional.
+# invalid_token_html_path: "invalid_token.html"
# Time that a user's session remains valid for, after they log in.
#
@@ -808,9 +923,32 @@ uploads_path: "DATADIR/uploads"
#
#disable_msisdn_registration: true
+# Derive the user's matrix ID from a type of 3PID used when registering.
+# This overrides any matrix ID the user proposes when calling /register.
+# The 3PID type should be present in registrations_require_3pid to avoid
+# users failing to register if they don't specify the right kind of 3pid.
+#
+#register_mxid_from_3pid: email
+
+# Uncomment to set the display name of new users to their email address,
+# rather than using the default heuristic.
+#
+#register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+# Use an Identity Server to establish which 3PIDs are allowed to register?
+# Overrides allowed_local_3pids below.
+#
+#check_is_for_allowed_local_3pids: matrix.org
+#
+# If you are using an IS you can also check whether that IS registers
+# pending invites for the given 3PID (and then allow it to sign up on
+# the platform):
+#
+#allow_invited_3pids: false
+#
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\.org'
@@ -819,6 +957,11 @@ uploads_path: "DATADIR/uploads"
# - medium: msisdn
# pattern: '\+44'
+# If true, stop users from trying to change the 3PIDs associated with
+# their accounts.
+#
+#disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -860,6 +1003,30 @@ uploads_path: "DATADIR/uploads"
# - matrix.org
# - vector.im
+# If enabled, user IDs, display names and avatar URLs will be replicated
+# to this server whenever they change.
+# This is an experimental API currently implemented by sydent to support
+# cross-homeserver user directories.
+#
+#replicate_user_profiles_to: example.com
+
+# If specified, attempt to replay registrations, profile changes & 3pid
+# bindings on the given target homeserver via the AS API. The HS is authed
+# via a given AS token.
+#
+#shadow_server:
+# hs_url: https://shadow.example.com
+# hs: shadow.example.com
+# as_token: 12u394refgbdhivsia
+
+# If enabled, don't let users set their own display names/avatars
+# other than for the very first time (unless they are a server admin).
+# Useful when provisioning users based on the contents of a 3rd party
+# directory and to avoid ambiguities.
+#
+#disable_set_displayname: false
+#disable_set_avatar_url: false
+
# Users who register on this homeserver will automatically be joined
# to these rooms
#
@@ -1097,6 +1264,36 @@ password_config:
#
#pepper: "EVEN_MORE_SECRET"
+ # Define and enforce a password policy. Each parameter is optional, boolean
+ # parameters default to 'false' and integer parameters default to 0.
+ # This is an early implementation of MSC2000.
+ #
+ #policy:
+ # Whether to enforce the password policy.
+ #
+ #enabled: true
+
+ # Minimum accepted length for a password.
+ #
+ #minimum_length: 15
+
+ # Whether a password must contain at least one digit.
+ #
+ #require_digit: true
+
+ # Whether a password must contain at least one symbol.
+ # A symbol is any character that's not a number or a letter.
+ #
+ #require_symbol: true
+
+ # Whether a password must contain at least one lowercase letter.
+ #
+ #require_lowercase: true
+
+ # Whether a password must contain at least one uppercase letter.
+ #
+ #require_uppercase: true
+
# Enable sending emails for password resets, notification events or
@@ -1241,6 +1438,11 @@ password_config:
#user_directory:
# enabled: true
# search_all_users: false
+#
+# # If this is set, user search will be delegated to this ID server instead
+# # of synapse performing the search itself.
+# # This is an experimental API.
+# defer_to_id_server: https://id.example.com
# User Consent configuration
@@ -1430,3 +1632,19 @@ opentracing:
#
#homeserver_whitelist:
# - ".*"
+
+ # Jaeger can be configured to sample traces at different rates.
+ # All configuration options provided by Jaeger can be set here.
+ # Jaeger's configuration is mostly related to trace sampling, which
+ # is documented here:
+ # https://www.jaegertracing.io/docs/1.13/sampling/.
+ #
+ #jaeger_config:
+ # sampler:
+ # type: const
+ # param: 1
+
+ # Logging whether spans were started and reported
+ #
+ # logging: false
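For reference, the `jaeger_config` block above is shaped like the configuration dict accepted by the `jaeger_client` library; a sketch of the mapping, assuming Synapse passes the block through largely unchanged:

    from jaeger_client import Config

    tracer = Config(
        config={
            # A 'const' sampler with param 1 samples every trace.
            "sampler": {"type": "const", "param": 1},
            # Log when spans are started and reported.
            "logging": False,
        },
        service_name="synapse",
    ).initialize_tracer()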
diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py
index ca4b879526..5c5a115ca9 100644
--- a/docs/sphinx/conf.py
+++ b/docs/sphinx/conf.py
@@ -12,8 +12,8 @@
# All configuration values have a default; values that are commented out
# serve to show the default.
-import sys
import os
+import sys
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
@@ -191,11 +191,11 @@ htmlhelp_basename = "Synapsedoc"
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
- #'papersize': 'letterpaper',
+ # 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
- #'pointsize': '10pt',
+ # 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
- #'preamble': '',
+ # 'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
diff --git a/res/templates-dinsic/mail-Vector.css b/res/templates-dinsic/mail-Vector.css
new file mode 100644
index 0000000000..6a3e36eda1
--- /dev/null
+++ b/res/templates-dinsic/mail-Vector.css
@@ -0,0 +1,7 @@
+.header {
+ border-bottom: 4px solid #e4f7ed !important;
+}
+
+.notif_link a, .footer a {
+ color: #76CFA6 !important;
+}
diff --git a/res/templates-dinsic/mail.css b/res/templates-dinsic/mail.css
new file mode 100644
index 0000000000..5ab3e1b06d
--- /dev/null
+++ b/res/templates-dinsic/mail.css
@@ -0,0 +1,156 @@
+body {
+ margin: 0px;
+}
+
+pre, code {
+ word-break: break-word;
+ white-space: pre-wrap;
+}
+
+#page {
+ font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
+ color: #454545;
+ font-size: 12pt;
+ width: 100%;
+ padding: 20px;
+}
+
+#inner {
+ width: 640px;
+}
+
+.header {
+ width: 100%;
+ height: 87px;
+ color: #454545;
+ border-bottom: 4px solid #e5e5e5;
+}
+
+.logo {
+ text-align: right;
+ margin-left: 20px;
+}
+
+.salutation {
+ padding-top: 10px;
+ font-weight: bold;
+}
+
+.summarytext {
+}
+
+.room {
+ width: 100%;
+ color: #454545;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.room_header td {
+ padding-top: 38px;
+ padding-bottom: 10px;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.room_name {
+ vertical-align: middle;
+ font-size: 18px;
+ font-weight: bold;
+}
+
+.room_header h2 {
+ margin-top: 0px;
+ margin-left: 75px;
+ font-size: 20px;
+}
+
+.room_avatar {
+ width: 56px;
+ line-height: 0px;
+ text-align: center;
+ vertical-align: middle;
+}
+
+.room_avatar img {
+ width: 48px;
+ height: 48px;
+ object-fit: cover;
+ border-radius: 24px;
+}
+
+.notif {
+ border-bottom: 1px solid #e5e5e5;
+ margin-top: 16px;
+ padding-bottom: 16px;
+}
+
+.historical_message .sender_avatar {
+ opacity: 0.3;
+}
+
+/* spell out opacity and historical_message class names for Outlook aka Word */
+.historical_message .sender_name {
+ color: #e3e3e3;
+}
+
+.historical_message .message_time {
+ color: #e3e3e3;
+}
+
+.historical_message .message_body {
+ color: #c7c7c7;
+}
+
+.historical_message td,
+.message td {
+ padding-top: 10px;
+}
+
+.sender_avatar {
+ width: 56px;
+ text-align: center;
+ vertical-align: top;
+}
+
+.sender_avatar img {
+ margin-top: -2px;
+ width: 32px;
+ height: 32px;
+ border-radius: 16px;
+}
+
+.sender_name {
+ display: inline;
+ font-size: 13px;
+ color: #a2a2a2;
+}
+
+.message_time {
+ text-align: right;
+ width: 100px;
+ font-size: 11px;
+ color: #a2a2a2;
+}
+
+.message_body {
+}
+
+.notif_link td {
+ padding-top: 10px;
+ padding-bottom: 10px;
+ font-weight: bold;
+}
+
+.notif_link a, .footer a {
+ color: #454545;
+ text-decoration: none;
+}
+
+.debug {
+ font-size: 10px;
+ color: #888;
+}
+
+.footer {
+ margin-top: 20px;
+ text-align: center;
+}
\ No newline at end of file
diff --git a/res/templates-dinsic/notif.html b/res/templates-dinsic/notif.html
new file mode 100644
index 0000000000..bcdfeea9da
--- /dev/null
+++ b/res/templates-dinsic/notif.html
@@ -0,0 +1,45 @@
+{% for message in notif.messages %}
+ <tr class="{{ "historical_message" if message.is_historical else "message" }}">
+ <td class="sender_avatar">
+ {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ {% if message.sender_avatar_url %}
+ <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
+ {% else %}
+ {% if message.sender_hash % 3 == 0 %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
+ {% elif message.sender_hash % 3 == 1 %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
+ {% else %}
+ <img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
+ {% endif %}
+ {% endif %}
+ {% endif %}
+ </td>
+ <td class="message_contents">
+ {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
+ {% endif %}
+ <div class="message_body">
+ {% if message.msgtype == "m.text" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.emote" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.notice" %}
+ {{ message.body_text_html }}
+ {% elif message.msgtype == "m.image" %}
+ <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
+ {% elif message.msgtype == "m.file" %}
+ <span class="filename">{{ message.body_text_plain }}</span>
+ {% endif %}
+ </div>
+ </td>
+ <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
+ </tr>
+{% endfor %}
+<tr class="notif_link">
+ <td></td>
+ <td>
+ <a href="{{ notif.link }}">Voir {{ room.title }}</a>
+ </td>
+ <td></td>
+</tr>
diff --git a/res/templates-dinsic/notif.txt b/res/templates-dinsic/notif.txt
new file mode 100644
index 0000000000..3dff1bb570
--- /dev/null
+++ b/res/templates-dinsic/notif.txt
@@ -0,0 +1,16 @@
+{% for message in notif.messages %}
+{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
+{% if message.msgtype == "m.text" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.emote" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.notice" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.image" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.file" %}
+{{ message.body_text_plain }}
+{% endif %}
+{% endfor %}
+
+Voir {{ room.title }} à {{ notif.link }}
diff --git a/res/templates-dinsic/notif_mail.html b/res/templates-dinsic/notif_mail.html
new file mode 100644
index 0000000000..1e1efa74b2
--- /dev/null
+++ b/res/templates-dinsic/notif_mail.html
@@ -0,0 +1,55 @@
+<!doctype html>
+<html lang="en">
+ <head>
+ <style type="text/css">
+ {% include 'mail.css' without context %}
+ {% include "mail-%s.css" % app_name ignore missing without context %}
+ </style>
+ </head>
+ <body>
+ <table id="page">
+ <tr>
+ <td> </td>
+ <td id="inner">
+ <table class="header">
+ <tr>
+ <td>
+ <div class="salutation">Bonjour {{ user_display_name }},</div>
+ <div class="summarytext">{{ summary_text }}</div>
+ </td>
+ <td class="logo">
+ {% if app_name == "Riot" %}
+ <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
+ {% elif app_name == "Vector" %}
+ <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
+ {% else %}
+ <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
+ {% endif %}
+ </td>
+ </tr>
+ </table>
+ {% for room in rooms %}
+ {% include 'room.html' with context %}
+ {% endfor %}
+ <div class="footer">
+ <a href="{{ unsubscribe_link }}">Se désinscrire</a>
+ <br/>
+ <br/>
+ <div class="debug">
+ Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
+ an event was received at {{ reason.received_at|format_ts("%c") }}
+ which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
+ {% if reason.last_sent_ts %}
+ and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
+ which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
+ {% else %}
+ and we don't have a last time we sent a mail for this room.
+ {% endif %}
+ </div>
+ </div>
+ </td>
+ <td> </td>
+ </tr>
+ </table>
+ </body>
+</html>
diff --git a/res/templates-dinsic/notif_mail.txt b/res/templates-dinsic/notif_mail.txt
new file mode 100644
index 0000000000..fae877426f
--- /dev/null
+++ b/res/templates-dinsic/notif_mail.txt
@@ -0,0 +1,10 @@
+Bonjour {{ user_display_name }},
+
+{{ summary_text }}
+
+{% for room in rooms %}
+{% include 'room.txt' with context %}
+{% endfor %}
+
+Vous pouvez désactiver ces notifications en cliquant ici {{ unsubscribe_link }}
+
diff --git a/res/templates-dinsic/room.html b/res/templates-dinsic/room.html
new file mode 100644
index 0000000000..0487b1b11c
--- /dev/null
+++ b/res/templates-dinsic/room.html
@@ -0,0 +1,33 @@
+<table class="room">
+ <tr class="room_header">
+ <td class="room_avatar">
+ {% if room.avatar_url %}
+ <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
+ {% else %}
+ {% if room.hash % 3 == 0 %}
+ <img alt="" src="https://vector.im/beta/img/76cfa6.png" />
+ {% elif room.hash % 3 == 1 %}
+ <img alt="" src="https://vector.im/beta/img/50e2c2.png" />
+ {% else %}
+ <img alt="" src="https://vector.im/beta/img/f4c371.png" />
+ {% endif %}
+ {% endif %}
+ </td>
+ <td class="room_name" colspan="2">
+ {{ room.title }}
+ </td>
+ </tr>
+ {% if room.invite %}
+ <tr>
+ <td></td>
+ <td>
+ <a href="{{ room.link }}">Rejoindre la conversation.</a>
+ </td>
+ <td></td>
+ </tr>
+ {% else %}
+ {% for notif in room.notifs %}
+ {% include 'notif.html' with context %}
+ {% endfor %}
+ {% endif %}
+</table>
diff --git a/res/templates-dinsic/room.txt b/res/templates-dinsic/room.txt
new file mode 100644
index 0000000000..dd36d01d21
--- /dev/null
+++ b/res/templates-dinsic/room.txt
@@ -0,0 +1,9 @@
+{{ room.title }}
+
+{% if room.invite %}
+ Vous avez été invité, rejoignez la conversation en cliquant sur le lien suivant {{ room.link }}
+{% else %}
+ {% for notif in room.notifs %}
+ {% include 'notif.txt' with context %}
+ {% endfor %}
+{% endif %}
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index 0ec5075e79..b8a85abe18 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -5,9 +5,9 @@
set -e
-# make sure that origin/develop is up to date
-git remote set-branches --add origin develop
-git fetch origin develop
+# make sure that origin/dinsic is up to date
+git remote set-branches --add origin dinsic
+git fetch origin dinsic
# if there are changes in the debian directory, check that the debian changelog
# has been updated
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 7ce6540bdd..943b5a2c86 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -128,7 +128,7 @@ class Auth(object):
)
self._check_joined_room(member, user_id, room_id)
- defer.returnValue(member)
+ return member
@defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id):
@@ -156,13 +156,13 @@ class Auth(object):
if forgot:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
- defer.returnValue(member)
+ return member
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.is_host_joined(room_id, host)
- defer.returnValue(latest_event_ids)
+ return latest_event_ids
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
@@ -207,6 +207,7 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
+
if user_id:
request.authenticated_entity = user_id
@@ -219,9 +220,7 @@ class Auth(object):
device_id="dummy-device", # stubbed
)
- defer.returnValue(
- synapse.types.create_requester(user_id, app_service=app_service)
- )
+ return synapse.types.create_requester(user_id, app_service=app_service)
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
@@ -262,39 +261,41 @@ class Auth(object):
request.authenticated_entity = user.to_string()
- defer.returnValue(
- synapse.types.create_requester(
- user, token_id, is_guest, device_id, app_service=app_service
- )
+ return synapse.types.create_requester(
+ user, token_id, is_guest, device_id, app_service=app_service
)
except KeyError:
raise MissingClientTokenError()
- @defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
+
if app_service is None:
- defer.returnValue((None, None))
+ return (None, None)
if app_service.ip_range_whitelist:
ip_address = IPAddress(self.hs.get_ip_from_request(request))
if ip_address not in app_service.ip_range_whitelist:
- defer.returnValue((None, None))
+ return (None, None)
if b"user_id" not in request.args:
- defer.returnValue((app_service.sender, app_service))
+ return (app_service.sender, app_service)
user_id = request.args[b"user_id"][0].decode("utf8")
if app_service.sender == user_id:
- defer.returnValue((app_service.sender, app_service))
+ return (app_service.sender, app_service)
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (yield self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
- defer.returnValue((user_id, app_service))
+ # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+ # if not (yield self.store.get_user_by_id(user_id)):
+ # raise AuthError(
+ # 403,
+ # "Application service has not registered this user"
+ # )
+ return (user_id, app_service)
@defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"):
@@ -330,7 +331,7 @@ class Auth(object):
msg="Access token has expired", soft_logout=True
)
- defer.returnValue(r)
+ return r
# otherwise it needs to be a valid macaroon
try:
@@ -378,7 +379,7 @@ class Auth(object):
}
else:
raise RuntimeError("Unknown rights setting %s", rights)
- defer.returnValue(ret)
+ return ret
except (
_InvalidMacaroonException,
pymacaroons.exceptions.MacaroonException,
@@ -506,7 +507,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
- defer.returnValue(None)
+ return None
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
@@ -518,7 +519,7 @@ class Auth(object):
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
- defer.returnValue(user_info)
+ return user_info
def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request)
@@ -543,7 +544,7 @@ class Auth(object):
@defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create:
- defer.returnValue([])
+ return []
auth_ids = []
@@ -604,7 +605,7 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
- defer.returnValue(auth_ids)
+ return auth_ids
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
@@ -618,7 +619,7 @@ class Auth(object):
is_admin = yield self.is_server_admin(user)
if is_admin:
- defer.returnValue(True)
+ return True
user_id = user.to_string()
yield self.check_joined_room(room_id, user_id)
@@ -712,7 +713,7 @@ class Auth(object):
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
+ return (member_event.membership, member_event.event_id)
except AuthError:
visibility = yield self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
@@ -721,7 +722,7 @@ class Auth(object):
visibility
and visibility.content["history_visibility"] == "world_readable"
):
- defer.returnValue((Membership.JOIN, None))
+ return (Membership.JOIN, None)
return
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 3ffde0d7fc..c7cae9768f 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -85,6 +85,7 @@ class EventTypes(object):
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
+ Encryption = "m.room.encryption"
# These are used for validation
Message = "m.room.message"
@@ -94,6 +95,8 @@ class EventTypes(object):
ServerACL = "m.room.server_acl"
Pinned = "m.room.pinned_events"
+ Retention = "m.room.retention"
+
class RejectedReason(object):
AUTH_ERROR = "auth_error"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index ad3e262041..c293135b51 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,6 +62,13 @@ class Codes(object):
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
+ PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
+ PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
+ PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
+ PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
+ PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
+ PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
+ WEAK_PASSWORD = "M_WEAK_PASSWORD"
class CodeMessageException(RuntimeError):
@@ -418,6 +426,18 @@ class IncompatibleRoomVersionError(SynapseError):
return cs_error(self.msg, self.errcode, room_version=self._room_version)
+class PasswordRefusedError(SynapseError):
+ """A password has been refused, either during password reset/change or registration.
+ """
+
+ def __init__(
+ self,
+ msg="This password doesn't comply with the server's policy",
+ errcode=Codes.WEAK_PASSWORD,
+ ):
+ super(PasswordRefusedError, self).__init__(code=400, msg=msg, errcode=errcode)
+
+
class RequestSendFailed(RuntimeError):
"""Sending a HTTP request over federation failed due to not being able to
talk to the remote server for some reason.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 9b3daca29b..9f06556bd2 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -132,7 +132,7 @@ class Filtering(object):
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
result = yield self.store.get_user_filter(user_localpart, filter_id)
- defer.returnValue(FilterCollection(result))
+ return FilterCollection(result)
def add_user_filter(self, user_localpart, user_filter):
self.check_valid_filter(user_filter)
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 540dbd9236..c010e70955 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -15,10 +15,12 @@
import gc
import logging
+import os
import signal
import sys
import traceback
+import sdnotify
from daemonize import Daemonize
from twisted.internet import defer, error, reactor
@@ -242,9 +244,16 @@ def start(hs, listeners=None):
if hasattr(signal, "SIGHUP"):
def handle_sighup(*args, **kwargs):
+ # Tell systemd our state, if we're using it. This will silently fail if
+ # we're not using systemd.
+ sd_channel = sdnotify.SystemdNotifier()
+ sd_channel.notify("RELOADING=1")
+
for i in _sighup_callbacks:
i(hs)
+ sd_channel.notify("READY=1")
+
signal.signal(signal.SIGHUP, handle_sighup)
register_sighup(refresh_certificate)
@@ -260,6 +269,7 @@ def start(hs, listeners=None):
hs.get_datastore().start_profiling()
setup_sentry(hs)
+ setup_sdnotify(hs)
except Exception:
traceback.print_exc(file=sys.stderr)
reactor = hs.get_reactor()
@@ -292,6 +302,25 @@ def setup_sentry(hs):
scope.set_tag("worker_name", name)
+def setup_sdnotify(hs):
+ """Adds process state hooks to tell systemd what we are up to.
+ """
+
+ # Tell systemd our state, if we're using it. This will silently fail if
+ # we're not using systemd.
+ sd_channel = sdnotify.SystemdNotifier()
+
+ hs.get_reactor().addSystemEventTrigger(
+ "after",
+ "startup",
+ lambda: sd_channel.notify("READY=1\nMAINPID=%s" % (os.getpid())),
+ )
+
+ hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", lambda: sd_channel.notify("STOPPING=1")
+ )
+
+
def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
"""Replaces the resolver with one that limits the number of in flight DNS
requests.
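
The sdnotify package used above writes state strings to the socket named in the NOTIFY_SOCKET environment variable; when Synapse is not run under systemd there is no such socket and the notification is silently dropped, which is why the calls can be made unconditionally. A minimal sketch of the protocol as setup_sdnotify uses it:

    import os

    import sdnotify

    notifier = sdnotify.SystemdNotifier()
    # Report that startup has finished and which PID systemd should watch.
    notifier.notify("READY=1\nMAINPID=%s" % os.getpid())
    # ...and later, when shutting down:
    notifier.notify("STOPPING=1")

For any of this to take effect, the systemd unit needs Type=notify; that is a deployment assumption rather than something this diff configures.
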
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index e01f3e5f3b..54bb114dec 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -168,7 +168,9 @@ def start(config_options):
)
ps.setup()
- reactor.callWhenRunning(_base.start, ps, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ps, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-appservice", config)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 29bddc4823..721bb5b119 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -194,7 +194,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-client-reader", config)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 042cfd04af..473c8895d0 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -193,7 +193,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-event-creator", config)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 76a97f8f32..5255d9e8cc 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -175,7 +175,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-federation-reader", config)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index fec49d5092..c5a2880e69 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -198,7 +198,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-federation-sender", config)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 1f1f1df78e..e2822ca848 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -70,12 +70,12 @@ class PresenceStatusStubServlet(RestServlet):
except HttpResponseException as e:
raise e.to_synapse_error()
- defer.returnValue((200, result))
+ return (200, result)
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
yield self.auth.get_user_by_req(request)
- defer.returnValue((200, {}))
+ return (200, {})
class KeyUploadServlet(RestServlet):
@@ -126,11 +126,11 @@ class KeyUploadServlet(RestServlet):
self.main_uri + request.uri.decode("ascii"), body, headers=headers
)
- defer.returnValue((200, result))
+ return (200, result)
else:
# Just interested in counts.
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- defer.returnValue((200, {"one_time_key_counts": result}))
+ return (200, {"one_time_key_counts": result})
class FrontendProxySlavedStore(
@@ -247,7 +247,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-frontend-proxy", config)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
old mode 100755
new mode 100644
index 0c075cb3f1..fe4fa20bd9
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -406,7 +406,7 @@ def setup(config_options):
if provision:
yield acme.provision_certificate()
- defer.returnValue(provision)
+ return provision
@defer.inlineCallbacks
def reprovision_acme():
@@ -447,7 +447,7 @@ def setup(config_options):
reactor.stop()
sys.exit(1)
- reactor.callWhenRunning(start)
+ reactor.addSystemEventTrigger("before", "startup", start)
return hs
@@ -563,7 +563,7 @@ def run(hs):
stats["database_server_version"] = hs.get_datastore().get_server_version()
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
- yield hs.get_simple_http_client().put_json(
+ yield hs.get_proxied_http_client().put_json(
"https://matrix.org/report-usage-stats/push", stats
)
except Exception as e:
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index d70780e9d5..ea26f29acb 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -161,7 +161,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-media-repository", config)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 070de7d0b0..692ffa2f04 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -216,7 +216,7 @@ def start(config_options):
_base.start(ps, config.worker_listeners)
ps.get_pusherpool().start()
- reactor.callWhenRunning(start)
+ reactor.addSystemEventTrigger("before", "startup", start)
_base.start_worker_reactor("synapse-pusher", config)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 315c030694..a1c3b162f7 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -451,7 +451,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-synchrotron", config)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 03ef21bd01..cb29a1afab 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -224,7 +224,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-user-dir", config)
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b26a31dd54..65cbff95b9 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -175,21 +175,21 @@ class ApplicationService(object):
@defer.inlineCallbacks
def _matches_user(self, event, store):
if not event:
- defer.returnValue(False)
+ return False
if self.is_interested_in_user(event.sender):
- defer.returnValue(True)
+ return True
# also check m.room.member state key
if event.type == EventTypes.Member and self.is_interested_in_user(
event.state_key
):
- defer.returnValue(True)
+ return True
if not store:
- defer.returnValue(False)
+ return False
does_match = yield self._matches_user_in_member_list(event.room_id, store)
- defer.returnValue(does_match)
+ return does_match
@cachedInlineCallbacks(num_args=1, cache_context=True)
def _matches_user_in_member_list(self, room_id, store, cache_context):
@@ -200,8 +200,8 @@ class ApplicationService(object):
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
@@ -211,13 +211,13 @@ class ApplicationService(object):
@defer.inlineCallbacks
def _matches_aliases(self, event, store):
if not store or not event:
- defer.returnValue(False)
+ return False
alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
@defer.inlineCallbacks
def is_interested(self, event, store=None):
@@ -231,15 +231,15 @@ class ApplicationService(object):
"""
# Do cheap checks first
if self._matches_room_id(event):
- defer.returnValue(True)
+ return True
if (yield self._matches_aliases(event, store)):
- defer.returnValue(True)
+ return True
if (yield self._matches_user(event, store)):
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
def is_interested_in_user(self, user_id):
return (
@@ -268,7 +268,7 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
- def get_exlusive_user_regexes(self):
+ def get_exclusive_user_regexes(self):
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 571881775b..007ca75a94 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -97,40 +97,40 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def query_user(self, service, user_id):
if service.url is None:
- defer.returnValue(False)
+ return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
- defer.returnValue(True)
+ return True
except CodeMessageException as e:
if e.code == 404:
- defer.returnValue(False)
+ return False
return
logger.warning("query_user to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("query_user to %s threw exception %s", uri, ex)
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def query_alias(self, service, alias):
if service.url is None:
- defer.returnValue(False)
+ return False
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
- defer.returnValue(True)
+ return True
except CodeMessageException as e:
logger.warning("query_alias to %s received %s", uri, e.code)
if e.code == 404:
- defer.returnValue(False)
+ return False
return
except Exception as ex:
logger.warning("query_alias to %s threw exception %s", uri, ex)
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
@@ -141,7 +141,7 @@ class ApplicationServiceApi(SimpleHttpClient):
else:
raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind)
if service.url is None:
- defer.returnValue([])
+ return []
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
@@ -155,7 +155,7 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning(
"query_3pe to %s returned an invalid response %r", uri, response
)
- defer.returnValue([])
+ return []
ret = []
for r in response:
@@ -166,14 +166,14 @@ class ApplicationServiceApi(SimpleHttpClient):
"query_3pe to %s returned an invalid result %r", uri, r
)
- defer.returnValue(ret)
+ return ret
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
- defer.returnValue([])
+ return []
def get_3pe_protocol(self, service, protocol):
if service.url is None:
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def _get():
@@ -189,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning(
"query_3pe_protocol to %s did not return a" " valid result", uri
)
- defer.returnValue(None)
+ return None
for instance in info.get("instances", []):
network_id = instance.get("network_id", None)
@@ -198,10 +198,10 @@ class ApplicationServiceApi(SimpleHttpClient):
service.id, network_id
).to_string()
- defer.returnValue(info)
+ return info
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex)
- defer.returnValue(None)
+ return None
key = (service.id, protocol)
return self.protocol_meta_cache.wrap(key, _get)
@@ -209,7 +209,7 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
if service.url is None:
- defer.returnValue(True)
+ return True
events = self._serialize(events)
@@ -229,14 +229,14 @@ class ApplicationServiceApi(SimpleHttpClient):
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
- defer.returnValue(True)
+ return True
return
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("push_bulk to %s threw exception %s", uri, ex)
failed_transactions_counter.labels(service.id).inc()
- defer.returnValue(False)
+ return False
def _serialize(self, events):
time_now = self.clock.time_msec()
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index e5b36494f5..42a350bff8 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -193,7 +193,7 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _is_service_up(self, service):
state = yield self.store.get_appservice_state(service)
- defer.returnValue(state == ApplicationServiceState.UP or state is None)
+ return state == ApplicationServiceState.UP or state is None
class _Recoverer(object):
@@ -208,7 +208,7 @@ class _Recoverer(object):
r.service.id,
)
r.recover()
- defer.returnValue(recoverers)
+ return recoverers
def __init__(self, clock, store, as_api, service, callback):
self.clock = clock
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 6ce5cd07fb..54230342a1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,6 +18,7 @@
import argparse
import errno
import os
+from io import open as io_open
from textwrap import dedent
from six import integer_types
@@ -133,7 +134,7 @@ class Config(object):
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
+ with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
def invoke_all(self, name, *args, **kargs):
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 40502a5798..d321d00b80 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
import logging.config
import os
@@ -75,10 +76,8 @@ root:
class LoggingConfig(Config):
def read_config(self, config, **kwargs):
- self.verbosity = config.get("verbose", 0)
- self.no_redirect_stdio = config.get("no_redirect_stdio", False)
self.log_config = self.abspath(config.get("log_config"))
- self.log_file = self.abspath(config.get("log_file"))
+ self.no_redirect_stdio = config.get("no_redirect_stdio", False)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
log_config = os.path.join(config_dir_path, server_name + ".log.config")
@@ -94,39 +93,13 @@ class LoggingConfig(Config):
)
def read_arguments(self, args):
- if args.verbose is not None:
- self.verbosity = args.verbose
if args.no_redirect_stdio is not None:
self.no_redirect_stdio = args.no_redirect_stdio
- if args.log_config is not None:
- self.log_config = args.log_config
- if args.log_file is not None:
- self.log_file = args.log_file
@staticmethod
def add_arguments(parser):
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
- "-v",
- "--verbose",
- dest="verbose",
- action="count",
- help="The verbosity level. Specify multiple times to increase "
- "verbosity. (Ignored if --log-config is specified.)",
- )
- logging_group.add_argument(
- "-f",
- "--log-file",
- dest="log_file",
- help="File to log to. (Ignored if --log-config is specified.)",
- )
- logging_group.add_argument(
- "--log-config",
- dest="log_config",
- default=None,
- help="Python logging config file",
- )
- logging_group.add_argument(
"-n",
"--no-redirect-stdio",
action="store_true",
@@ -153,58 +126,29 @@ def setup_logging(config, use_worker_options=False):
config (LoggingConfig | synapse.config.workers.WorkerConfig):
configuration data
- use_worker_options (bool): True to use 'worker_log_config' and
- 'worker_log_file' options instead of 'log_config' and 'log_file'.
+ use_worker_options (bool): True to use the 'worker_log_config' option
+ instead of 'log_config'.
register_sighup (func | None): Function to call to register a
sighup handler.
"""
log_config = config.worker_log_config if use_worker_options else config.log_config
- log_file = config.worker_log_file if use_worker_options else config.log_file
-
- log_format = (
- "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
- " - %(message)s"
- )
if log_config is None:
- # We don't have a logfile, so fall back to the 'verbosity' param from
- # the config or cmdline. (Note that we generate a log config for new
- # installs, so this will be an unusual case)
- level = logging.INFO
- level_for_storage = logging.INFO
- if config.verbosity:
- level = logging.DEBUG
- if config.verbosity > 1:
- level_for_storage = logging.DEBUG
+ log_format = (
+ "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
+ " - %(message)s"
+ )
logger = logging.getLogger("")
- logger.setLevel(level)
-
- logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage)
+ logger.setLevel(logging.INFO)
+ logging.getLogger("synapse.storage.SQL").setLevel(logging.INFO)
formatter = logging.Formatter(log_format)
- if log_file:
- # TODO: Customisable file size / backup count
- handler = logging.handlers.RotatingFileHandler(
- log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8"
- )
-
- def sighup(signum, stack):
- logger.info("Closing log file due to SIGHUP")
- handler.doRollover()
- logger.info("Opened new log file due to SIGHUP")
-
- else:
- handler = logging.StreamHandler()
-
- def sighup(*args):
- pass
+ handler = logging.StreamHandler()
handler.setFormatter(formatter)
-
handler.addFilter(LoggingContextFilter(request=""))
-
logger.addHandler(handler)
else:
@@ -218,8 +162,7 @@ def setup_logging(config, use_worker_options=False):
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
load_log_config()
-
- appbase.register_sighup(sighup)
+ appbase.register_sighup(sighup)
# make sure that the first thing we log is a thing we can grep backwards
# for
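
With the verbosity and log_file plumbing removed, the no-log_config branch above reduces to a single INFO-level stream handler using the request-aware format. A self-contained approximation, where RequestFilter stands in for Synapse's LoggingContextFilter (which is what really supplies %(request)s):

    import logging

    class RequestFilter(logging.Filter):
        # Stand-in for LoggingContextFilter: guarantees %(request)s exists.
        def filter(self, record):
            if not hasattr(record, "request"):
                record.request = ""
            return True

    log_format = (
        "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
        " - %(message)s"
    )

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(log_format))
    handler.addFilter(RequestFilter())

    root = logging.getLogger("")
    root.setLevel(logging.INFO)
    root.addHandler(handler)
    root.info("stdio logging configured without a log_config file")
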
diff --git a/synapse/config/password.py b/synapse/config/password.py
index d5b5953f2f..47df98f41a 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,6 +31,10 @@ class PasswordConfig(Config):
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
self.password_pepper = password_config.get("pepper", "")
+ # Password policy
+ self.password_policy = password_config.get("policy", {})
+ self.password_policy_enabled = self.password_policy.pop("enabled", False)
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -46,4 +52,34 @@ class PasswordConfig(Config):
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#
#pepper: "EVEN_MORE_SECRET"
+
+ # Define and enforce a password policy. Each parameter is optional; boolean
+ # parameters default to 'false' and integer parameters default to 0.
+ # This is an early implementation of MSC2000.
+ #
+ #policy:
+ # Whether to enforce the password policy.
+ #
+ #enabled: true
+
+ # Minimum accepted length for a password.
+ #
+ #minimum_length: 15
+
+ # Whether a password must contain at least one digit.
+ #
+ #require_digit: true
+
+ # Whether a password must contain at least one symbol.
+ # A symbol is any character that's not a number or a letter.
+ #
+ #require_symbol: true
+
+ # Whether a password must contain at least one lowercase letter.
+ #
+ #require_lowercase: true
+
+ # Whether a password must contain at least one uppercase letter.
+ #
+ #require_uppercase: true
"""
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 33f31cf213..a1ea4fe02d 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -68,6 +68,9 @@ class RatelimitConfig(Config):
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {})
+ )
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -102,6 +105,8 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+ # - one that ratelimits third-party invites requests based on the account
+ # that's making the requests.
#
# The defaults are as shown below.
#
@@ -123,6 +128,10 @@ class RatelimitConfig(Config):
# failed_attempts:
# per_second: 0.17
# burst_count: 3
+ #
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
# Ratelimiting settings for incoming federation
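
The new rc_third_party_invite block follows the same per_second/burst_count shape as the other rate limits. The pair behaves like a token bucket; a self-contained illustration of the semantics (this is not Synapse's Ratelimiter class, just what the two settings mean):

    import time

    class SimpleRatelimiter:
        def __init__(self, per_second, burst_count):
            self.per_second = per_second
            self.burst_count = burst_count
            self.tokens = float(burst_count)
            self.last = time.monotonic()

        def allow(self):
            now = time.monotonic()
            # Refill at `per_second`, never exceeding `burst_count`.
            self.tokens = min(
                self.burst_count, self.tokens + (now - self.last) * self.per_second
            )
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False

    limiter = SimpleRatelimiter(per_second=0.2, burst_count=10)
    assert all(limiter.allow() for _ in range(10))  # a burst of 10 passes
    assert not limiter.allow()  # the 11th request must wait ~5 seconds
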
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index c3de7a4e32..3240e30f70 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from distutils.util import strtobool
+import pkg_resources
+
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias
from synapse.util.stringutils import random_string_with_symbols
@@ -41,8 +44,36 @@ class AccountValidityConfig(Config):
self.startup_job_max_delta = self.period * 10.0 / 100.0
- if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
- raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+ if self.renew_by_email_enabled:
+ if "public_baseurl" not in synapse_config:
+ raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+
+ template_dir = config.get("template_dir")
+
+ if not template_dir:
+ template_dir = pkg_resources.resource_filename("synapse", "res/templates")
+
+ if "account_renewed_html_path" in config:
+ file_path = os.path.join(template_dir, config["account_renewed_html_path"])
+
+ self.account_renewed_html_content = self.read_file(
+ file_path, "account_validity.account_renewed_html_path"
+ )
+ else:
+ self.account_renewed_html_content = (
+ "<html><body>Your account has been successfully renewed.</body><html>"
+ )
+
+ if "invalid_token_html_path" in config:
+ file_path = os.path.join(template_dir, config["invalid_token_html_path"])
+
+ self.invalid_token_html_content = self.read_file(
+ file_path, "account_validity.invalid_token_html_path"
+ )
+ else:
+ self.invalid_token_html_content = (
+ "<html><body>Invalid renewal token.</body><html>"
+ )
class RegistrationConfig(Config):
@@ -61,8 +92,19 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.check_is_for_allowed_local_3pids = config.get(
+ "check_is_for_allowed_local_3pids", None
+ )
+ self.allow_invited_3pids = config.get("allow_invited_3pids", False)
+
+ self.disable_3pid_changes = config.get("disable_3pid_changes", False)
+
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
+ self.register_just_use_email_for_display_name = config.get(
+ "register_just_use_email_for_display_name", False
+ )
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
@@ -80,6 +122,18 @@ class RegistrationConfig(Config):
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
+ self.disable_set_displayname = config.get("disable_set_displayname", False)
+ self.disable_set_avatar_url = config.get("disable_set_avatar_url", False)
+
+ self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
+ if not isinstance(self.replicate_user_profiles_to, list):
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
+
+ self.shadow_server = config.get("shadow_server", None)
+ self.rewrite_identity_server_urls = config.get(
+ "rewrite_identity_server_urls", {}
+ )
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -145,6 +199,16 @@ class RegistrationConfig(Config):
# period: 6w
# renew_at: 1w
# renew_email_subject: "Renew your %%(app)s account"
+ # # Directory in which Synapse will try to find the HTML files to serve to the
+ # # user when trying to renew an account. Optional, defaults to
+ # # synapse/res/templates.
+ # template_dir: "res/templates"
+ # # HTML to be displayed to the user after they successfully renewed their
+ # # account. Optional.
+ # account_renewed_html_path: "account_renewed.html"
+ # # HTML to be displayed when the user tries to renew an account with an invalid
+ # # renewal token. Optional.
+ # invalid_token_html_path: "invalid_token.html"
# Time that a user's session remains valid for, after they log in.
#
@@ -168,9 +232,32 @@ class RegistrationConfig(Config):
#
#disable_msisdn_registration: true
+ # Derive the user's matrix ID from a type of 3PID used when registering.
+ # This overrides any matrix ID the user proposes when calling /register
+ # The 3PID type should be present in registrations_require_3pid to avoid
+ # users failing to register if they don't specify the right kind of 3pid.
+ #
+ #register_mxid_from_3pid: email
+
+ # Uncomment to set the display name of new users to their email address,
+ # rather than using the default heuristic.
+ #
+ #register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+ # Use an identity server to establish which 3PIDs are allowed to register.
+ # Overrides allowed_local_3pids below.
+ #
+ #check_is_for_allowed_local_3pids: matrix.org
+ #
+ # If you are using an identity server, you can also check whether that
+ # identity server has a pending invite for the given 3PID (and, if so,
+ # allow it to sign up on the platform):
+ #
+ #allow_invited_3pids: false
+ #
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
@@ -179,6 +266,11 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+ # If true, stop users from trying to change the 3PIDs associated with
+ # their accounts.
+ #
+ #disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -220,6 +312,30 @@ class RegistrationConfig(Config):
# - matrix.org
# - vector.im
+ # If enabled, user IDs, display names and avatar URLs will be replicated
+ # to this server whenever they change.
+ # This is an experimental API currently implemented by sydent to support
+ # cross-homeserver user directories.
+ #
+ #replicate_user_profiles_to: example.com
+
+ # If specified, attempt to replay registrations, profile changes & 3pid
+ # bindings on the given target homeserver via the AS API. The HS is authed
+ # via a given AS token.
+ #
+ #shadow_server:
+ # hs_url: https://shadow.example.com
+ # hs: shadow.example.com
+ # as_token: 12u394refgbdhivsia
+
+ # If enabled, don't let users set their own display names/avatars
+ # other than for the very first time (unless they are a server admin).
+ # Useful when provisioning users based on the contents of a 3rd party
+ # directory and to avoid ambiguities.
+ #
+ #disable_set_displayname: false
+ #disable_set_avatar_url: false
+
# Users who register on this homeserver will automatically be joined
# to these rooms
#
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 80a628d9b0..c6b737fb6b 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -91,6 +91,12 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
+
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
@@ -229,6 +235,30 @@ class ContentRepositoryConfig(Config):
#
#max_upload_size: 10M
+ # The largest allowed size for a user avatar. If not defined, no
+ # restriction will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #max_avatar_size: 10M
+
+ # Allowed mimetypes for a user avatar. If not defined, no restriction will
+ # be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
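
A sketch of the checks these two avatar options imply when a user sets a new global avatar. The media_info dict shape ("media_length", "media_type") mirrors what the local media repository stores, but the helper itself is an assumption; real enforcement would live in the profile handler:

    from synapse.api.errors import Codes, SynapseError

    def check_avatar_allowed(media_info, config):
        """Reject avatars violating max_avatar_size/allowed_avatar_mimetypes."""
        if config.max_avatar_size and media_info["media_length"] > config.max_avatar_size:
            raise SynapseError(403, "Avatar is too large", errcode=Codes.FORBIDDEN)
        if (
            config.allowed_avatar_mimetypes
            and media_info["media_type"] not in config.allowed_avatar_mimetypes
        ):
            raise SynapseError(403, "Avatar mimetype not allowed", errcode=Codes.FORBIDDEN)
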
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 00170f1393..d32a2e4b0a 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -82,6 +82,12 @@ class ServerConfig(Config):
"require_auth_for_profile_requests", False
)
+ # Whether to require sharing a room with a user to retrieve their
+ # profile data
+ self.limit_profile_requests_to_known_users = config.get(
+ "limit_profile_requests_to_known_users", False
+ )
+
if "restrict_public_rooms_to_local_users" in config and (
"allow_public_rooms_without_auth" in config
or "allow_public_rooms_over_federation" in config
@@ -209,6 +215,130 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # True.
+ self.show_users_in_user_directory = config.get(
+ "show_users_in_user_directory", True
+ )
+
+ retention_config = config.get("retention")
+ if retention_config is None:
+ retention_config = {}
+
+ self.retention_enabled = retention_config.get("enabled", False)
+
+ retention_default_policy = retention_config.get("default_policy")
+
+ if retention_default_policy is not None:
+ self.retention_default_min_lifetime = retention_default_policy.get(
+ "min_lifetime"
+ )
+ if self.retention_default_min_lifetime is not None:
+ self.retention_default_min_lifetime = self.parse_duration(
+ self.retention_default_min_lifetime
+ )
+
+ self.retention_default_max_lifetime = retention_default_policy.get(
+ "max_lifetime"
+ )
+ if self.retention_default_max_lifetime is not None:
+ self.retention_default_max_lifetime = self.parse_duration(
+ self.retention_default_max_lifetime
+ )
+
+ if (
+ self.retention_default_min_lifetime is not None
+ and self.retention_default_max_lifetime is not None
+ and (
+ self.retention_default_min_lifetime
+ > self.retention_default_max_lifetime
+ )
+ ):
+ raise ConfigError(
+ "The default retention policy's 'min_lifetime' can not be greater"
+ " than its 'max_lifetime'"
+ )
+ else:
+ self.retention_default_min_lifetime = None
+ self.retention_default_max_lifetime = None
+
+ self.retention_allowed_lifetime_min = retention_config.get(
+ "allowed_lifetime_min"
+ )
+ if self.retention_allowed_lifetime_min is not None:
+ self.retention_allowed_lifetime_min = self.parse_duration(
+ self.retention_allowed_lifetime_min
+ )
+
+ self.retention_allowed_lifetime_max = retention_config.get(
+ "allowed_lifetime_max"
+ )
+ if self.retention_allowed_lifetime_max is not None:
+ self.retention_allowed_lifetime_max = self.parse_duration(
+ self.retention_allowed_lifetime_max
+ )
+
+ if (
+ self.retention_allowed_lifetime_min is not None
+ and self.retention_allowed_lifetime_max is not None
+ and self.retention_allowed_lifetime_min
+ > self.retention_allowed_lifetime_max
+ ):
+ raise ConfigError(
+ "Invalid retention policy limits: 'allowed_lifetime_min' can not be"
+ " greater than 'allowed_lifetime_max'"
+ )
+
+ self.retention_purge_jobs = []
+ for purge_job_config in retention_config.get("purge_jobs", []):
+ interval_config = purge_job_config.get("interval")
+
+ if interval_config is None:
+ raise ConfigError(
+ "A retention policy's purge jobs configuration must have the"
+ " 'interval' key set."
+ )
+
+ interval = self.parse_duration(interval_config)
+
+ shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime")
+
+ if shortest_max_lifetime is not None:
+ shortest_max_lifetime = self.parse_duration(shortest_max_lifetime)
+
+ longest_max_lifetime = purge_job_config.get("longest_max_lifetime")
+
+ if longest_max_lifetime is not None:
+ longest_max_lifetime = self.parse_duration(longest_max_lifetime)
+
+ if (
+ shortest_max_lifetime is not None
+ and longest_max_lifetime is not None
+ and shortest_max_lifetime > longest_max_lifetime
+ ):
+ raise ConfigError(
+ "A retention policy's purge jobs configuration's"
+ " 'shortest_max_lifetime' value can not be greater than its"
+ " 'longest_max_lifetime' value."
+ )
+
+ self.retention_purge_jobs.append(
+ {
+ "interval": interval,
+ "shortest_max_lifetime": shortest_max_lifetime,
+ "longest_max_lifetime": longest_max_lifetime,
+ }
+ )
+
+ if not self.retention_purge_jobs:
+ self.retention_purge_jobs = [
+ {
+ "interval": self.parse_duration("1d"),
+ "shortest_max_lifetime": None,
+ "longest_max_lifetime": None,
+ }
+ ]
+
self.listeners = []
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
@@ -395,6 +525,13 @@ class ServerConfig(Config):
#
#require_auth_for_profile_requests: true
+ # Whether to require a user to share a room with another user in order
+ # to retrieve their profile information. Only checked on Client-Server
+ # requests. Profile requests from other servers should be checked by the
+ # requesting server. Defaults to 'false'.
+ #
+ #limit_profile_requests_to_known_users: true
+
# If set to 'false', requires authentication to access the server's public rooms
# directory through the client API. Defaults to 'true'.
#
@@ -627,6 +764,74 @@ class ServerConfig(Config):
# Defaults to 'true'.
#
#allow_per_room_profiles: false
+
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # 'true'.
+ #
+ #show_users_in_user_directory: false
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events whose lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # whose 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but where
+ # we don't want that purge to be performed by a job iterating over every room
+ # it knows of, which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
"""
% locals()
)
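
For reference, the per-room policy that allowed_lifetime_min/max are validated against is carried in an m.room.retention state event (state_key ""); its lifetimes are expressed in milliseconds, matching what parse_duration produces. A sketch of the content:

    # Content of an m.room.retention state event; Synapse compares these
    # values against the allowed_lifetime_min/max bounds above.
    retention_content = {
        "min_lifetime": 24 * 60 * 60 * 1000,  # "1d" in milliseconds
        "max_lifetime": 365 * 24 * 60 * 60 * 1000,  # "1y" in milliseconds
    }
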
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 4479454415..95e7ccb3a3 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -23,6 +23,12 @@ class TracerConfig(Config):
opentracing_config = {}
self.opentracer_enabled = opentracing_config.get("enabled", False)
+
+ self.jaeger_config = opentracing_config.get(
+ "jaeger_config",
+ {"sampler": {"type": "const", "param": 1}, "logging": False},
+ )
+
if not self.opentracer_enabled:
return
@@ -56,4 +62,20 @@ class TracerConfig(Config):
#
#homeserver_whitelist:
# - ".*"
+
+ # Jaeger can be configured to sample traces at different rates.
+ # All configuration options provided by Jaeger can be set here.
+ # Jaeger's configuration mostly related to trace sampling which
+ # is documented here:
+ # https://www.jaegertracing.io/docs/1.13/sampling/.
+ #
+ #jaeger_config:
+ # sampler:
+ # type: const
+ # param: 1
+
+ # Whether to log when spans are started and reported.
+ #
+ #logging: false
"""
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index f6313e17d4..96493a5dcc 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -24,6 +24,7 @@ class UserDirectoryConfig(Config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
+ self.user_directory_defer_to_id_server = None
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
@@ -32,6 +33,9 @@ class UserDirectoryConfig(Config):
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
@@ -50,4 +54,9 @@ class UserDirectoryConfig(Config):
#user_directory:
# enabled: true
# search_all_users: false
+ #
+ # # If this is set, user search will be delegated to this ID server instead
+ # # of synapse performing the search itself.
+ # # This is an experimental API.
+ # defer_to_id_server: https://id.example.com
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 3b75471d85..bc0fc165e3 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -31,7 +31,6 @@ class WorkerConfig(Config):
self.worker_listeners = config.get("worker_listeners", [])
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
- self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
# The host used to connect to the main synapse
@@ -78,9 +77,5 @@ class WorkerConfig(Config):
if args.daemonize is not None:
self.worker_daemonize = args.daemonize
- if args.log_config is not None:
- self.worker_log_config = args.log_config
- if args.log_file is not None:
- self.worker_log_file = args.log_file
if args.manhole is not None:
self.worker_manhole = args.worker_manhole
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 341c863152..6c3e885e72 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -238,27 +238,9 @@ class Keyring(object):
"""
try:
- # create a deferred for each server we're going to look up the keys
- # for; we'll resolve them once we have completed our lookups.
- # These will be passed into wait_for_previous_lookups to block
- # any other lookups until we have finished.
- # The deferreds are called with no logcontext.
- server_to_deferred = {
- rq.server_name: defer.Deferred() for rq in verify_requests
- }
-
- # We want to wait for any previous lookups to complete before
- # proceeding.
- yield self.wait_for_previous_lookups(server_to_deferred)
+ ctx = LoggingContext.current_context()
- # Actually start fetching keys.
- self._get_server_verify_keys(verify_requests)
-
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- #
- # map from server name to a set of request ids
+ # map from server name to a set of outstanding request ids
server_to_request_ids = {}
for verify_request in verify_requests:
@@ -266,40 +248,61 @@ class Keyring(object):
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
- def remove_deferreds(res, verify_request):
+ # Wait for any previous lookups to complete before proceeding.
+ yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+
+ # take out a lock on each of the servers by sticking a Deferred in
+ # key_downloads
+ for server_name in server_to_request_ids.keys():
+ self.key_downloads[server_name] = defer.Deferred()
+ logger.debug("Got key lookup lock on %s", server_name)
+
+ # When we've finished fetching all the keys for a given server_name,
+ # drop the lock by resolving the deferred in key_downloads.
+ def drop_server_lock(server_name):
+ d = self.key_downloads.pop(server_name)
+ d.callback(None)
+
+ def lookup_done(res, verify_request):
server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids[server_name].discard(request_id)
- if not server_to_request_ids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
+ server_requests = server_to_request_ids[server_name]
+ server_requests.remove(id(verify_request))
+
+ # if there are no more requests for this server, we can drop the lock.
+ if not server_requests:
+ with PreserveLoggingContext(ctx):
+ logger.debug("Releasing key lookup lock on %s", server_name)
+
+ # ... but not immediately, as that can cause stack explosions if
+ # we get a long queue of lookups.
+ self.clock.call_later(0, drop_server_lock, server_name)
+
return res
for verify_request in verify_requests:
- verify_request.key_ready.addBoth(remove_deferreds, verify_request)
+ verify_request.key_ready.addBoth(lookup_done, verify_request)
+
+ # Actually start fetching keys.
+ self._get_server_verify_keys(verify_requests)
except Exception:
logger.exception("Error starting key lookups")
@defer.inlineCallbacks
- def wait_for_previous_lookups(self, server_to_deferred):
+ def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
- resolved once we've finished looking up keys for that server.
- The Deferreds should be regular twisted ones which call their
- callbacks with no logcontext.
-
- Returns: a Deferred which resolves once all key lookups for the given
- servers have completed. Follows the synapse rules of logcontext
- preservation.
+ server_names (Iterable[str]): list of servers which we want to look up
+
+ Returns:
+ Deferred[None]: resolves once all key lookups for the given servers have
+ completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
- for server_name in server_to_deferred.keys()
+ for server_name in server_names
if server_name in self.key_downloads
]
if not wait_on:
@@ -314,19 +317,6 @@ class Keyring(object):
loop_count += 1
- ctx = LoggingContext.current_context()
-
- def rm(r, server_name_):
- with PreserveLoggingContext(ctx):
- logger.debug("Releasing key lookup lock on %s", server_name_)
- self.key_downloads.pop(server_name_, None)
- return r
-
- for server_name, deferred in server_to_deferred.items():
- logger.debug("Got key lookup lock on %s", server_name)
- self.key_downloads[server_name] = deferred
- deferred.addBoth(rm, server_name)
-
def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
@@ -472,7 +462,7 @@ class StoreKeyFetcher(KeyFetcher):
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
- defer.returnValue(keys)
+ return keys
class BaseV2KeyFetcher(object):
@@ -576,7 +566,7 @@ class BaseV2KeyFetcher(object):
).addErrback(unwrapFirstError)
)
- defer.returnValue(verify_keys)
+ return verify_keys
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@@ -598,7 +588,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
result = yield self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
- defer.returnValue(result)
+ return result
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,7 +601,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
str(e),
)
- defer.returnValue({})
+ return {}
results = yield make_deferred_yieldable(
defer.gatherResults(
@@ -625,7 +615,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
- defer.returnValue(union_of_keys)
+ return union_of_keys
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
@@ -711,7 +701,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name, time_now_ms, added_keys
)
- defer.returnValue(keys)
+ return keys
def _validate_perspectives_response(self, key_server, response):
"""Optionally check the signature on the result of a /key/query request
@@ -853,7 +843,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
)
keys.update(response_keys)
- defer.returnValue(keys)
+ return keys
@defer.inlineCallbacks
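
The reworked key_downloads logic above is a Deferred-as-mutex pattern: while a lookup for a server is in flight, key_downloads maps the server name to a pending Deferred; waiters yield on it, and the holder releases it via clock.call_later(0, ...) so that a long queue of waiters is woken on fresh reactor ticks rather than in one deep callback chain. A stripped-down sketch of the pattern, without Synapse's logcontext machinery:

    from twisted.internet import defer, reactor

    locks = {}

    @defer.inlineCallbacks
    def with_lock(key, work):
        # Wait until nobody holds the lock for `key`, then take it.
        while key in locks:
            yield locks[key]
        locks[key] = defer.Deferred()
        try:
            result = yield work()
        finally:
            d = locks.pop(key)
            # Release on the next reactor tick: resolving synchronously can
            # recurse through every queued waiter at once.
            reactor.callLater(0, d.callback, None)
        return result
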
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index db011e0407..3997751337 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -144,15 +144,13 @@ class EventBuilder(object):
if self._origin_server_ts is not None:
event_dict["origin_server_ts"] = self._origin_server_ts
- defer.returnValue(
- create_local_event_from_event_dict(
- clock=self._clock,
- hostname=self._hostname,
- signing_key=self._signing_key,
- format_version=self.format_version,
- event_dict=event_dict,
- internal_metadata_dict=self.internal_metadata.get_dict(),
- )
+ return create_local_event_from_event_dict(
+ clock=self._clock,
+ hostname=self._hostname,
+ signing_key=self._signing_key,
+ format_version=self.format_version,
+ event_dict=event_dict,
+ internal_metadata_dict=self.internal_metadata.get_dict(),
)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a9545e6c1b..acbcbeeced 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -133,19 +133,17 @@ class EventContext(object):
else:
prev_state_id = None
- defer.returnValue(
- {
- "prev_state_id": prev_state_id,
- "event_type": event.type,
- "event_state_key": event.state_key if event.is_state() else None,
- "state_group": self.state_group,
- "rejected": self.rejected,
- "prev_group": self.prev_group,
- "delta_ids": _encode_state_dict(self.delta_ids),
- "prev_state_events": self.prev_state_events,
- "app_service_id": self.app_service.id if self.app_service else None,
- }
- )
+ return {
+ "prev_state_id": prev_state_id,
+ "event_type": event.type,
+ "event_state_key": event.state_key if event.is_state() else None,
+ "state_group": self.state_group,
+ "rejected": self.rejected,
+ "prev_group": self.prev_group,
+ "delta_ids": _encode_state_dict(self.delta_ids),
+ "prev_state_events": self.prev_state_events,
+ "app_service_id": self.app_service.id if self.app_service else None,
+ }
@staticmethod
def deserialize(store, input):
@@ -202,7 +200,7 @@ class EventContext(object):
yield make_deferred_yieldable(self._fetching_state_deferred)
- defer.returnValue(self._current_state_ids)
+ return self._current_state_ids
@defer.inlineCallbacks
def get_prev_state_ids(self, store):
@@ -222,7 +220,7 @@ class EventContext(object):
yield make_deferred_yieldable(self._fetching_state_deferred)
- defer.returnValue(self._prev_state_ids)
+ return self._prev_state_ids
def get_cached_current_state_ids(self):
"""Gets the current state IDs if we have them already cached.
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 129771f183..f0de4d961f 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -46,13 +46,33 @@ class SpamChecker(object):
return self.spam_checker.check_event_for_spam(event)
- def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
+ ):
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- userid (string): The sender's user ID
+ inviter_userid (str)
+ invitee_userid (str|None): The user ID of the invitee. Is None
+ if this is a third party invite and the 3PID is not bound to a
+ user ID.
+ third_party_invite (dict|None): If a third party invite then is a
+ dict containing the medium and address of the invitee.
+ room_id (str)
+ new_room (bool): Whether the user is being invited to the room as
+ part of a room creation, if so the invitee would have been
+ included in the call to `user_may_create_room`.
+ published_room (bool): Whether the room the user is being invited
+ to has been published in the local homeserver's public room
+ directory.
Returns:
bool: True if the user may send an invite, otherwise False
@@ -61,16 +81,29 @@ class SpamChecker(object):
return True
return self.spam_checker.user_may_invite(
- inviter_userid, invitee_userid, room_id
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
)
- def user_may_create_room(self, userid):
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid (string): The sender's user ID
+ invite_list (list[str]): List of user IDs that would be invited to
+ the new room.
+ third_party_invite_list (list[dict]): List of third party invites
+ for the new room.
+ cloning (bool): Whether the user is cloning an existing room, e.g.
+ upgrading a room.
Returns:
bool: True if the user may create a room, otherwise False
@@ -78,7 +111,9 @@ class SpamChecker(object):
if self.spam_checker is None:
return True
- return self.spam_checker.user_may_create_room(userid)
+ return self.spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
def user_may_create_room_alias(self, userid, room_alias):
"""Checks if a given user may create a room alias
@@ -113,3 +148,21 @@ class SpamChecker(object):
return True
return self.spam_checker.user_may_publish_room(userid, room_id)
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Checks if a given users is allowed to join a room.
+
+ Is not called when the user creates a room.
+
+ Args:
+ userid (str)
+ room_id (str)
+ is_invited (bool): Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_join_room(userid, room_id, is_invited)
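
A hedged sketch of a spam checker module implementing the callbacks this diff widens or adds. The method signatures follow the docstrings above; the policy decisions themselves are invented for illustration:

    class ExampleSpamChecker(object):
        def __init__(self, config):
            self.config = config

        def user_may_invite(
            self,
            inviter_userid,
            invitee_userid,
            third_party_invite,
            room_id,
            new_room,
            published_room,
        ):
            # Invented policy: no third party invites into published rooms.
            return not (third_party_invite is not None and published_room)

        def user_may_create_room(
            self, userid, invite_list, third_party_invite_list, cloning
        ):
            # Invented policy: cap the number of invites sent at creation.
            return len(invite_list) + len(third_party_invite_list) <= 100

        def user_may_join_room(self, userid, room_id, is_invited):
            # Invented policy: only allow joins that follow an invite.
            return is_invited
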
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 8f5d95696b..714a9b1579 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -51,7 +51,7 @@ class ThirdPartyEventRules(object):
defer.Deferred[bool]: True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
- defer.returnValue(True)
+ return True
prev_state_ids = yield context.get_prev_state_ids(self.store)
@@ -61,7 +61,7 @@ class ThirdPartyEventRules(object):
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
ret = yield self.third_party_rules.check_event_allowed(event, state_events)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def on_create_room(self, requester, config, is_requester_admin):
@@ -98,7 +98,7 @@ class ThirdPartyEventRules(object):
"""
if self.third_party_rules is None:
- defer.returnValue(True)
+ return True
state_ids = yield self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values())
@@ -110,4 +110,4 @@ class ThirdPartyEventRules(object):
ret = yield self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
- defer.returnValue(ret)
+ return ret
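
The `defer.returnValue` replacements throughout this patch are purely mechanical: on Python 3, a `return value` statement inside a generator raises `StopIteration(value)`, which `@defer.inlineCallbacks` unwraps exactly as it unwraps `defer.returnValue(value)`. A minimal sketch of the equivalence:

from twisted.internet import defer

@defer.inlineCallbacks
def old_style():
    result = yield defer.succeed(41)
    defer.returnValue(result + 1)  # the only option on Python 2

@defer.inlineCallbacks
def new_style():
    result = yield defer.succeed(41)
    return result + 1  # equivalent on Python 3

# Both calls produce a Deferred that fires with 42.
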
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9487a886f5..07d1c5bcf0 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -360,7 +360,7 @@ class EventClientSerializer(object):
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
- defer.returnValue(event)
+ return event
event_id = event.event_id
serialized_event = serialize_event(event, time_now, **kwargs)
@@ -406,7 +406,7 @@ class EventClientSerializer(object):
"sender": edit.sender,
}
- defer.returnValue(serialized_event)
+ return serialized_event
def serialize_events(self, events, time_now, **kwargs):
"""Serializes multiple events.
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index f7ffd1d561..9df9287aa7 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import string_types
+from six import integer_types, string_types
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -22,11 +22,12 @@ from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
- def validate_new(self, event):
+ def validate_new(self, event, config):
"""Validates the event has roughly the right format
Args:
- event (FrozenEvent)
+ event (FrozenEvent): The event to validate.
+ config (Config): The homeserver's configuration.
"""
self.validate_builder(event)
@@ -67,6 +68,99 @@ class EventValidator(object):
Codes.INVALID_PARAM,
)
+ if event.type == EventTypes.Retention:
+ self._validate_retention(event, config)
+
+ def _validate_retention(self, event, config):
+ """Checks that an event that defines the retention policy for a room respects the
+ boundaries imposed by the server's administrator.
+
+ Args:
+ event (FrozenEvent): The event to validate.
+ config (Config): The homeserver's configuration.
+ """
+ min_lifetime = event.content.get("min_lifetime")
+ max_lifetime = event.content.get("max_lifetime")
+
+ if min_lifetime is not None:
+ if not isinstance(min_lifetime, integer_types):
+ raise SynapseError(
+ code=400,
+ msg="'min_lifetime' must be an integer",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_min is not None
+ and min_lifetime < config.retention_allowed_lifetime_min
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'min_lifetime' can't be lower than the minimum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_max is not None
+ and min_lifetime > config.retention_allowed_lifetime_max
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'min_lifetime' can't be greater than the maximum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if max_lifetime is not None:
+ if not isinstance(max_lifetime, integer_types):
+ raise SynapseError(
+ code=400,
+ msg="'max_lifetime' must be an integer",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_min is not None
+ and max_lifetime < config.retention_allowed_lifetime_min
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'max_lifetime' can't be lower than the minimum allowed value"
+ " enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_max is not None
+ and max_lifetime > config.retention_allowed_lifetime_max
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'max_lifetime' can't be greater than the maximum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ min_lifetime is not None
+ and max_lifetime is not None
+ and min_lifetime > max_lifetime
+ ):
+ raise SynapseError(
+ code=400,
+ msg="'min_lifetime' can't be greater than 'max_lifetime",
+ errcode=Codes.BAD_JSON,
+ )
+
def validate_builder(self, event):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
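
A worked example of the retention validation above, with hypothetical config values retention_allowed_lifetime_min = 86400000 (one day) and retention_allowed_lifetime_max = 31536000000 (one year); each line is an m.room.retention content dict and the outcome:

{"min_lifetime": 3600000}      # rejected: one hour is below the allowed minimum
{"max_lifetime": 63072000000}  # rejected: two years exceeds the allowed maximum
{"min_lifetime": "1d"}         # rejected: 'min_lifetime' must be an integer
{"min_lifetime": 604800000, "max_lifetime": 2592000000}  # accepted: both within bounds
{"min_lifetime": 2592000000, "max_lifetime": 604800000}  # rejected: min_lifetime > max_lifetime
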
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index f7bb806ae7..5a1e23a145 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -106,7 +106,7 @@ class FederationBase(object):
"Failed to find copy of %s with valid signature", pdu.event_id
)
- defer.returnValue(res)
+ return res
handle = preserve_fn(handle_check_result)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
@@ -116,9 +116,9 @@ class FederationBase(object):
).addErrback(unwrapFirstError)
if include_none:
- defer.returnValue(valid_pdus)
+ return valid_pdus
else:
- defer.returnValue([p for p in valid_pdus if p])
+ return [p for p in valid_pdus if p]
def _check_sigs_and_hash(self, room_version, pdu):
return make_deferred_yieldable(
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3cb4b94420..25ed1257f1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -213,7 +213,7 @@ class FederationClient(FederationBase):
).addErrback(unwrapFirstError)
)
- defer.returnValue(pdus)
+ return pdus
@defer.inlineCallbacks
@log_function
@@ -245,7 +245,7 @@ class FederationClient(FederationBase):
ev = self._get_pdu_cache.get(event_id)
if ev:
- defer.returnValue(ev)
+ return ev
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
@@ -307,7 +307,7 @@ class FederationClient(FederationBase):
if signed_pdu:
self._get_pdu_cache[event_id] = signed_pdu
- defer.returnValue(signed_pdu)
+ return signed_pdu
@defer.inlineCallbacks
@log_function
@@ -355,7 +355,7 @@ class FederationClient(FederationBase):
auth_chain.sort(key=lambda e: e.depth)
- defer.returnValue((pdus, auth_chain))
+ return (pdus, auth_chain)
except HttpResponseException as e:
if e.code == 400 or e.code == 404:
logger.info("Failed to use get_room_state_ids API, falling back")
@@ -404,7 +404,7 @@ class FederationClient(FederationBase):
signed_auth.sort(key=lambda e: e.depth)
- defer.returnValue((signed_pdus, signed_auth))
+ return (signed_pdus, signed_auth)
@defer.inlineCallbacks
def get_events_from_store_or_dest(self, destination, room_id, event_ids):
@@ -429,7 +429,7 @@ class FederationClient(FederationBase):
missing_events.discard(k)
if not missing_events:
- defer.returnValue((signed_events, failed_to_fetch))
+ return (signed_events, failed_to_fetch)
logger.debug(
"Fetching unknown state/auth events %s for room %s",
@@ -465,7 +465,7 @@ class FederationClient(FederationBase):
# We removed all events we successfully fetched from `batch`
failed_to_fetch.update(batch)
- defer.returnValue((signed_events, failed_to_fetch))
+ return (signed_events, failed_to_fetch)
@defer.inlineCallbacks
@log_function
@@ -485,7 +485,7 @@ class FederationClient(FederationBase):
signed_auth.sort(key=lambda e: e.depth)
- defer.returnValue(signed_auth)
+ return signed_auth
@defer.inlineCallbacks
def _try_destination_list(self, description, destinations, callback):
@@ -521,7 +521,7 @@ class FederationClient(FederationBase):
try:
res = yield callback(destination)
- defer.returnValue(res)
+ return res
except InvalidResponseError as e:
logger.warn("Failed to %s via %s: %s", description, destination, e)
except HttpResponseException as e:
@@ -615,7 +615,7 @@ class FederationClient(FederationBase):
event_dict=pdu_dict,
)
- defer.returnValue((destination, ev, event_format))
+ return (destination, ev, event_format)
return self._try_destination_list(
"make_" + membership, destinations, send_request
@@ -728,13 +728,11 @@ class FederationClient(FederationBase):
check_authchain_validity(signed_auth)
- defer.returnValue(
- {
- "state": signed_state,
- "auth_chain": signed_auth,
- "origin": destination,
- }
- )
+ return {
+ "state": signed_state,
+ "auth_chain": signed_auth,
+ "origin": destination,
+ }
return self._try_destination_list("send_join", destinations, send_request)
@@ -758,7 +756,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
- defer.returnValue(pdu)
+ return pdu
@defer.inlineCallbacks
def _do_send_invite(self, destination, pdu, room_version):
@@ -786,7 +784,7 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
- defer.returnValue(content)
+ return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -821,7 +819,7 @@ class FederationClient(FederationBase):
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- defer.returnValue(content)
+ return content
def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers.
@@ -856,7 +854,7 @@ class FederationClient(FederationBase):
)
logger.debug("Got content: %s", content)
- defer.returnValue(None)
+ return None
return self._try_destination_list("send_leave", destinations, send_request)
@@ -917,7 +915,7 @@ class FederationClient(FederationBase):
"missing": content.get("missing", []),
}
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_missing_events(
@@ -974,7 +972,7 @@ class FederationClient(FederationBase):
# get_missing_events
signed_events = []
- defer.returnValue(signed_events)
+ return signed_events
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
@@ -986,7 +984,7 @@ class FederationClient(FederationBase):
yield self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
- defer.returnValue(None)
+ return None
except CodeMessageException:
raise
except Exception as e:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ed2b6d5eef..d216c46dfe 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -99,7 +99,7 @@ class FederationServer(FederationBase):
res = self._transaction_from_pdus(pdus).get_dict()
- defer.returnValue((200, res))
+ return (200, res)
@defer.inlineCallbacks
@log_function
@@ -126,7 +126,7 @@ class FederationServer(FederationBase):
origin, transaction, request_time
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _handle_incoming_transaction(self, origin, transaction, request_time):
@@ -147,8 +147,7 @@ class FederationServer(FederationBase):
"[%s] We've already responded to this request",
transaction.transaction_id,
)
- defer.returnValue(response)
- return
+ return response
logger.debug("[%s] Transaction is new", transaction.transaction_id)
@@ -163,7 +162,7 @@ class FederationServer(FederationBase):
yield self.transaction_actions.set_response(
origin, transaction, 400, response
)
- defer.returnValue((400, response))
+ return (400, response)
received_pdus_counter.inc(len(transaction.pdus))
@@ -265,7 +264,7 @@ class FederationServer(FederationBase):
logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response(origin, transaction, 200, response)
- defer.returnValue((200, response))
+ return (200, response)
@defer.inlineCallbacks
def received_edu(self, origin, edu_type, content):
@@ -298,7 +297,7 @@ class FederationServer(FederationBase):
event_id,
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_state_ids_request(self, origin, room_id, event_id):
@@ -315,9 +314,7 @@ class FederationServer(FederationBase):
state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
- defer.returnValue(
- (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
- )
+ return (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
@defer.inlineCallbacks
def _on_context_state_request_compute(self, room_id, event_id):
@@ -336,12 +333,10 @@ class FederationServer(FederationBase):
)
)
- defer.returnValue(
- {
- "pdus": [pdu.get_pdu_json() for pdu in pdus],
- "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- }
- )
+ return {
+ "pdus": [pdu.get_pdu_json() for pdu in pdus],
+ "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+ }
@defer.inlineCallbacks
@log_function
@@ -349,15 +344,15 @@ class FederationServer(FederationBase):
pdu = yield self.handler.get_persisted_pdu(origin, event_id)
if pdu:
- defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict()))
+ return (200, self._transaction_from_pdus([pdu]).get_dict())
else:
- defer.returnValue((404, ""))
+ return (404, "")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
received_queries_counter.labels(query_type).inc()
resp = yield self.registry.on_query(query_type, args)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_make_join_request(self, origin, room_id, user_id, supported_versions):
@@ -371,9 +366,7 @@ class FederationServer(FederationBase):
pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
- defer.returnValue(
- {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- )
+ return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
@defer.inlineCallbacks
def on_invite_request(self, origin, content, room_version):
@@ -391,7 +384,7 @@ class FederationServer(FederationBase):
yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
- defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
+ return {"event": ret_pdu.get_pdu_json(time_now)}
@defer.inlineCallbacks
def on_send_join_request(self, origin, content, room_id):
@@ -407,16 +400,14 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
- defer.returnValue(
- (
- 200,
- {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- },
- )
+ return (
+ 200,
+ {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [
+ p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
+ ],
+ },
)
@defer.inlineCallbacks
@@ -428,9 +419,7 @@ class FederationServer(FederationBase):
room_version = yield self.store.get_room_version(room_id)
time_now = self._clock.time_msec()
- defer.returnValue(
- {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- )
+ return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content, room_id):
@@ -445,7 +434,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
@@ -456,7 +445,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
- defer.returnValue((200, res))
+ return (200, res)
@defer.inlineCallbacks
def on_query_auth_request(self, origin, content, room_id, event_id):
@@ -509,7 +498,7 @@ class FederationServer(FederationBase):
"missing": ret.get("missing", []),
}
- defer.returnValue((200, send_content))
+ return (200, send_content)
@log_function
def on_query_client_keys(self, origin, content):
@@ -548,7 +537,7 @@ class FederationServer(FederationBase):
),
)
- defer.returnValue({"one_time_keys": json_result})
+ return {"one_time_keys": json_result}
@defer.inlineCallbacks
@log_function
@@ -580,9 +569,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
- defer.returnValue(
- {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
- )
+ return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
@log_function
def on_openid_userinfo(self, token):
@@ -676,14 +663,14 @@ class FederationServer(FederationBase):
ret = yield self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
ret = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, event_dict
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def check_server_matches_acl(self, server_name, room_id):
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index d46f4aaeb1..36f6d470dc 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -49,7 +49,7 @@ sent_pdus_destination_dist_count = Counter(
sent_pdus_destination_dist_total = Counter(
"synapse_federation_client_sent_pdu_destinations:total",
- "" "Total number of PDUs queued for sending across all destinations",
+ "Total number of PDUs queued for sending across all destinations",
)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 9aab12c0d3..fad980b893 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -374,7 +374,7 @@ class PerDestinationQueue(object):
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
- defer.returnValue((edus, now_stream_id))
+ return (edus, now_stream_id)
@defer.inlineCallbacks
def _get_to_device_message_edus(self, limit):
@@ -393,4 +393,4 @@ class PerDestinationQueue(object):
for content in contents
]
- defer.returnValue((edus, stream_id))
+ return (edus, stream_id)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 0460a8c4ac..52706302f2 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -133,4 +133,4 @@ class TransactionManager(object):
)
success = False
- defer.returnValue(success)
+ return success
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1aae9ec9e7..2a6709ff48 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -183,7 +183,7 @@ class TransportLayerClient(object):
try_trailing_slash_on_400=True,
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -201,7 +201,7 @@ class TransportLayerClient(object):
ignore_backoff=ignore_backoff,
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -259,7 +259,7 @@ class TransportLayerClient(object):
ignore_backoff=ignore_backoff,
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -270,7 +270,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -288,7 +288,7 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -299,7 +299,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content, ignore_backoff=True
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -310,7 +310,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content, ignore_backoff=True
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -339,7 +339,7 @@ class TransportLayerClient(object):
destination=remote_server, path=path, args=args, ignore_backoff=True
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -350,7 +350,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=event_dict
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -359,7 +359,7 @@ class TransportLayerClient(object):
content = yield self.client.get_json(destination=destination, path=path)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -370,7 +370,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -402,7 +402,7 @@ class TransportLayerClient(object):
content = yield self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -426,7 +426,7 @@ class TransportLayerClient(object):
content = yield self.client.get_json(
destination=destination, path=path, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -460,7 +460,7 @@ class TransportLayerClient(object):
content = yield self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -488,7 +488,7 @@ class TransportLayerClient(object):
timeout=timeout,
)
- defer.returnValue(content)
+ return content
@log_function
def get_group_profile(self, destination, group_id, requester_user_id):
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index f497711133..dfd7ae041b 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -157,7 +157,7 @@ class GroupAttestionRenewer(object):
yield self.store.update_remote_attestion(group_id, user_id, attestation)
- defer.returnValue({})
+ return {}
def _start_renew_attestations(self):
return run_as_background_process("renew_attestations", self._renew_attestations)
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 168c9e3f84..d50e691436 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -85,7 +85,7 @@ class GroupsServerHandler(object):
if not is_admin:
raise SynapseError(403, "User is not admin in group")
- defer.returnValue(group)
+ return group
@defer.inlineCallbacks
def get_group_summary(self, group_id, requester_user_id):
@@ -151,22 +151,20 @@ class GroupsServerHandler(object):
group_id, requester_user_id
)
- defer.returnValue(
- {
- "profile": profile,
- "users_section": {
- "users": users,
- "roles": roles,
- "total_user_count_estimate": 0, # TODO
- },
- "rooms_section": {
- "rooms": rooms,
- "categories": categories,
- "total_room_count_estimate": 0, # TODO
- },
- "user": membership_info,
- }
- )
+ return {
+ "profile": profile,
+ "users_section": {
+ "users": users,
+ "roles": roles,
+ "total_user_count_estimate": 0, # TODO
+ },
+ "rooms_section": {
+ "rooms": rooms,
+ "categories": categories,
+ "total_room_count_estimate": 0, # TODO
+ },
+ "user": membership_info,
+ }
@defer.inlineCallbacks
def update_group_summary_room(
@@ -192,7 +190,7 @@ class GroupsServerHandler(object):
is_public=is_public,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_summary_room(
@@ -208,7 +206,7 @@ class GroupsServerHandler(object):
group_id=group_id, room_id=room_id, category_id=category_id
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def set_group_join_policy(self, group_id, requester_user_id, content):
@@ -228,7 +226,7 @@ class GroupsServerHandler(object):
yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def get_group_categories(self, group_id, requester_user_id):
@@ -237,7 +235,7 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
categories = yield self.store.get_group_categories(group_id=group_id)
- defer.returnValue({"categories": categories})
+ return {"categories": categories}
@defer.inlineCallbacks
def get_group_category(self, group_id, requester_user_id, category_id):
@@ -249,7 +247,7 @@ class GroupsServerHandler(object):
group_id=group_id, category_id=category_id
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def update_group_category(self, group_id, requester_user_id, category_id, content):
@@ -269,7 +267,7 @@ class GroupsServerHandler(object):
profile=profile,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_category(self, group_id, requester_user_id, category_id):
@@ -283,7 +281,7 @@ class GroupsServerHandler(object):
group_id=group_id, category_id=category_id
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def get_group_roles(self, group_id, requester_user_id):
@@ -292,7 +290,7 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
roles = yield self.store.get_group_roles(group_id=group_id)
- defer.returnValue({"roles": roles})
+ return {"roles": roles}
@defer.inlineCallbacks
def get_group_role(self, group_id, requester_user_id, role_id):
@@ -301,7 +299,7 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = yield self.store.get_group_role(group_id=group_id, role_id=role_id)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def update_group_role(self, group_id, requester_user_id, role_id, content):
@@ -319,7 +317,7 @@ class GroupsServerHandler(object):
group_id=group_id, role_id=role_id, is_public=is_public, profile=profile
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_role(self, group_id, requester_user_id, role_id):
@@ -331,7 +329,7 @@ class GroupsServerHandler(object):
yield self.store.remove_group_role(group_id=group_id, role_id=role_id)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def update_group_summary_user(
@@ -355,7 +353,7 @@ class GroupsServerHandler(object):
is_public=is_public,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
@@ -369,7 +367,7 @@ class GroupsServerHandler(object):
group_id=group_id, user_id=user_id, role_id=role_id
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def get_group_profile(self, group_id, requester_user_id):
@@ -391,7 +389,7 @@ class GroupsServerHandler(object):
group_description = {key: group[key] for key in cols}
group_description["is_openly_joinable"] = group["join_policy"] == "open"
- defer.returnValue(group_description)
+ return group_description
else:
raise SynapseError(404, "Unknown group")
@@ -461,9 +459,7 @@ class GroupsServerHandler(object):
# TODO: If admin add lists of users whose attestations have timed out
- defer.returnValue(
- {"chunk": chunk, "total_user_count_estimate": len(user_results)}
- )
+ return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
@defer.inlineCallbacks
def get_invited_users_in_group(self, group_id, requester_user_id):
@@ -494,9 +490,7 @@ class GroupsServerHandler(object):
logger.warn("Error getting profile for %s: %s", user_id, e)
user_profiles.append(user_profile)
- defer.returnValue(
- {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
- )
+ return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
@defer.inlineCallbacks
def get_rooms_in_group(self, group_id, requester_user_id):
@@ -533,9 +527,7 @@ class GroupsServerHandler(object):
chunk.sort(key=lambda e: -e["num_joined_members"])
- defer.returnValue(
- {"chunk": chunk, "total_room_count_estimate": len(room_results)}
- )
+ return {"chunk": chunk, "total_room_count_estimate": len(room_results)}
@defer.inlineCallbacks
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
@@ -551,7 +543,7 @@ class GroupsServerHandler(object):
yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def update_room_in_group(
@@ -574,7 +566,7 @@ class GroupsServerHandler(object):
else:
raise SynapseError(400, "Uknown config option")
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def remove_room_from_group(self, group_id, requester_user_id, room_id):
@@ -586,7 +578,7 @@ class GroupsServerHandler(object):
yield self.store.remove_room_from_group(group_id, room_id)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def invite_to_group(self, group_id, user_id, requester_user_id, content):
@@ -644,9 +636,9 @@ class GroupsServerHandler(object):
)
elif res["state"] == "invite":
yield self.store.add_group_invite(group_id, user_id)
- defer.returnValue({"state": "invite"})
+ return {"state": "invite"}
elif res["state"] == "reject":
- defer.returnValue({"state": "reject"})
+ return {"state": "reject"}
else:
raise SynapseError(502, "Unknown state returned by HS")
@@ -679,7 +671,7 @@ class GroupsServerHandler(object):
remote_attestation=remote_attestation,
)
- defer.returnValue(local_attestation)
+ return local_attestation
@defer.inlineCallbacks
def accept_invite(self, group_id, requester_user_id, content):
@@ -699,7 +691,7 @@ class GroupsServerHandler(object):
local_attestation = yield self._add_user(group_id, requester_user_id, content)
- defer.returnValue({"state": "join", "attestation": local_attestation})
+ return {"state": "join", "attestation": local_attestation}
@defer.inlineCallbacks
def join_group(self, group_id, requester_user_id, content):
@@ -716,7 +708,7 @@ class GroupsServerHandler(object):
local_attestation = yield self._add_user(group_id, requester_user_id, content)
- defer.returnValue({"state": "join", "attestation": local_attestation})
+ return {"state": "join", "attestation": local_attestation}
@defer.inlineCallbacks
def knock(self, group_id, requester_user_id, content):
@@ -769,7 +761,7 @@ class GroupsServerHandler(object):
if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def create_group(self, group_id, requester_user_id, content):
@@ -845,7 +837,7 @@ class GroupsServerHandler(object):
avatar_url=user_profile.get("avatar_url"),
)
- defer.returnValue({"group_id": group_id})
+ return {"group_id": group_id}
@defer.inlineCallbacks
def delete_group(self, group_id, requester_user_id):
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index e62e6cab77..8acd9f9a83 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -51,8 +51,8 @@ class AccountDataEventSource(object):
{"type": account_data_type, "content": content, "room_id": room_id}
)
- defer.returnValue((results, current_stream_id))
+ return (results, current_stream_id)
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
- defer.returnValue(([], config.to_id))
+ return ([], config.to_id)
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 1f1708ba7d..51305b0c90 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -43,6 +43,8 @@ class AccountValidityHandler(object):
self.clock = self.hs.get_clock()
self._account_validity = self.hs.config.account_validity
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
# Don't do email-specific configuration if renewal by email is disabled.
@@ -77,6 +79,9 @@ class AccountValidityHandler(object):
self.clock.looping_call(send_emails, 30 * 60 * 1000)
+ # Check every hour to remove expired users from the user directory
+ self.clock.looping_call(self._mark_expired_users_as_inactive, 60 * 60 * 1000)
+
@defer.inlineCallbacks
def send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
@@ -193,7 +198,7 @@ class AccountValidityHandler(object):
if threepid["medium"] == "email":
addresses.append(threepid["address"])
- defer.returnValue(addresses)
+ return addresses
@defer.inlineCallbacks
def _get_renewal_token(self, user_id):
@@ -214,7 +219,7 @@ class AccountValidityHandler(object):
try:
renewal_token = stringutils.random_string(32)
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
- defer.returnValue(renewal_token)
+ return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
@@ -226,11 +231,19 @@ class AccountValidityHandler(object):
Args:
renewal_token (str): Token sent with the renewal request.
+ Returns:
+ bool: Whether the provided token is valid.
"""
- user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ try:
+ user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ except StoreError:
+ return False
+
logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id)
+ return True
+
@defer.inlineCallbacks
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
"""Renews the account attached to a given user by pushing back the
@@ -254,4 +267,27 @@ class AccountValidityHandler(object):
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
- defer.returnValue(expiration_ts)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ yield self.profile_handler.set_active(
+ UserID.from_string(user_id), True, True
+ )
+
+ return expiration_ts
+
+ @defer.inlineCallbacks
+ def _mark_expired_users_as_inactive(self):
+ """Iterate over expired users. Mark them as inactive in order to hide them from the
+ user directory.
+
+ Returns:
+ Deferred
+ """
+ # Get expired users
+ expired_user_ids = yield self.store.get_expired_users()
+ expired_users = [UserID.from_string(user_id) for user_id in expired_user_ids]
+
+ # Mark each one as non-active
+ for user in expired_users:
+ yield self.profile_handler.set_active(user, False, True)
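
Taken together, the account-validity changes above make token renewal report success as a boolean and tie user-directory visibility to expiry. A sketch of a caller (the handler method name `renew_account` and the surrounding request handler are assumptions; neither is shown in this hunk):

from twisted.internet import defer

@defer.inlineCallbacks
def on_renewal_request(account_validity_handler, renewal_token):
    # False now means the renewal token was unknown (the StoreError is
    # swallowed) rather than the error propagating to the caller.
    token_valid = yield account_validity_handler.renew_account(renewal_token)
    if token_valid:
        # Render the "account renewed" page. If
        # show_users_in_user_directory is set, the user's profile has
        # also been re-marked as active.
        pass
    else:
        # Render an "invalid renewal token" page instead of failing.
        pass

The hourly _mark_expired_users_as_inactive loop is the inverse path: expired users have set_active(..., False, True) applied, so they drop out of the user directory until they renew.
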
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index fbef2f3d38..46ac73106d 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -100,4 +100,4 @@ class AcmeHandler(object):
logger.exception("Failed saving!")
raise
- defer.returnValue(True)
+ return True
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index e8a651e231..2f22f56ca4 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -49,7 +49,7 @@ class AdminHandler(BaseHandler):
"devices": {"": {"sessions": [{"connections": connections}]}},
}
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_users(self):
@@ -61,7 +61,7 @@ class AdminHandler(BaseHandler):
"""
ret = yield self.store.get_users()
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_users_paginate(self, order, start, limit):
@@ -78,7 +78,7 @@ class AdminHandler(BaseHandler):
"""
ret = yield self.store.get_users_paginate(order, start, limit)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def search_users(self, term):
@@ -92,7 +92,7 @@ class AdminHandler(BaseHandler):
"""
ret = yield self.store.search_users(term)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def export_user_data(self, user_id, writer):
@@ -225,7 +225,7 @@ class AdminHandler(BaseHandler):
state = yield self.store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state)
- defer.returnValue(writer.finished())
+ return writer.finished()
class ExfiltrationWriter(object):
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 8f089f0e33..d1a51df6f9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -167,8 +167,8 @@ class ApplicationServicesHandler(object):
for user_service in user_query_services:
is_known_user = yield self.appservice_api.query_user(user_service, user_id)
if is_known_user:
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
@defer.inlineCallbacks
def query_room_alias_exists(self, room_alias):
@@ -192,7 +192,7 @@ class ApplicationServicesHandler(object):
if is_known_alias:
# the alias exists now so don't query more ASes.
result = yield self.store.get_association_from_room_alias(room_alias)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
@@ -215,7 +215,7 @@ class ApplicationServicesHandler(object):
if success:
ret.extend(result)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_3pe_protocols(self, only_protocol=None):
@@ -254,7 +254,7 @@ class ApplicationServicesHandler(object):
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
- defer.returnValue(protocols)
+ return protocols
@defer.inlineCallbacks
def _get_services_for_event(self, event):
@@ -276,7 +276,7 @@ class ApplicationServicesHandler(object):
if (yield s.is_interested(event, self.store)):
interested_list.append(s)
- defer.returnValue(interested_list)
+ return interested_list
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
@@ -293,23 +293,23 @@ class ApplicationServicesHandler(object):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
- defer.returnValue(False)
+ return False
return
user_info = yield self.store.get_user_by_id(user_id)
if user_info:
- defer.returnValue(False)
+ return False
return
# user not found; could be the AS though, so check.
services = self.store.get_app_services()
service_list = [s for s in services if s.sender == user_id]
- defer.returnValue(len(service_list) == 0)
+ return len(service_list) == 0
@defer.inlineCallbacks
def _check_user_exists(self, user_id):
unknown_user = yield self._is_unknown_user(user_id)
if unknown_user:
exists = yield self.query_user_exists(user_id)
- defer.returnValue(exists)
- defer.returnValue(True)
+ return exists
+ return True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d4d6574975..bf124032f1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -155,7 +155,7 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
- defer.returnValue(params)
+ return params
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip, password_servlet=False):
@@ -280,7 +280,7 @@ class AuthHandler(BaseHandler):
creds,
list(clientdict),
)
- defer.returnValue((creds, clientdict, session["id"]))
+ return (creds, clientdict, session["id"])
ret = self._auth_dict_for_flows(flows, session)
ret["completed"] = list(creds)
@@ -307,8 +307,8 @@ class AuthHandler(BaseHandler):
if result:
creds[stagetype] = result
self._save_session(sess)
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
def get_session_id(self, clientdict):
"""
@@ -379,7 +379,7 @@ class AuthHandler(BaseHandler):
res = yield checker(
authdict, clientip=clientip, password_servlet=password_servlet
)
- defer.returnValue(res)
+ return res
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
@@ -389,7 +389,7 @@ class AuthHandler(BaseHandler):
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
- defer.returnValue(canonical_id)
+ return canonical_id
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip, **kwargs):
@@ -409,7 +409,7 @@ class AuthHandler(BaseHandler):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
- client = self.hs.get_simple_http_client()
+ client = self.hs.get_proxied_http_client()
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
@@ -433,7 +433,7 @@ class AuthHandler(BaseHandler):
resp_body.get("hostname"),
)
if resp_body["success"]:
- defer.returnValue(True)
+ return True
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
def _check_email_identity(self, authdict, **kwargs):
@@ -502,7 +502,7 @@ class AuthHandler(BaseHandler):
threepid["threepid_creds"] = authdict["threepid_creds"]
- defer.returnValue(threepid)
+ return threepid
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
@@ -606,7 +606,7 @@ class AuthHandler(BaseHandler):
yield self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
- defer.returnValue(access_token)
+ return access_token
@defer.inlineCallbacks
def check_user_exists(self, user_id):
@@ -629,8 +629,8 @@ class AuthHandler(BaseHandler):
self.ratelimit_login_per_account(user_id)
res = yield self._find_user_id_and_pwd_hash(user_id)
if res is not None:
- defer.returnValue(res[0])
- defer.returnValue(None)
+ return res[0]
+ return None
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -661,7 +661,7 @@ class AuthHandler(BaseHandler):
user_id,
user_infos.keys(),
)
- defer.returnValue(result)
+ return result
def get_supported_login_types(self):
"""Get a the login types supported for the /login API
@@ -722,7 +722,7 @@ class AuthHandler(BaseHandler):
known_login_type = True
is_valid = yield provider.check_password(qualified_user_id, password)
if is_valid:
- defer.returnValue((qualified_user_id, None))
+ return (qualified_user_id, None)
if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
@@ -756,7 +756,7 @@ class AuthHandler(BaseHandler):
if result:
if isinstance(result, str):
result = (result, None)
- defer.returnValue(result)
+ return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
@@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
)
if canonical_user_id:
- defer.returnValue((canonical_user_id, None))
+ return (canonical_user_id, None)
if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type)
@@ -814,9 +814,9 @@ class AuthHandler(BaseHandler):
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
- defer.returnValue(result)
+ return result
- defer.returnValue((None, None))
+ return (None, None)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
@@ -838,7 +838,7 @@ class AuthHandler(BaseHandler):
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
- defer.returnValue(None)
+ return None
(user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated
@@ -850,8 +850,8 @@ class AuthHandler(BaseHandler):
result = yield self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
- defer.returnValue(None)
- defer.returnValue(user_id)
+ return None
+ return user_id
@defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token):
@@ -865,7 +865,7 @@ class AuthHandler(BaseHandler):
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
self.ratelimit_login_per_account(user_id)
yield self.auth.check_auth_blocking(user_id)
- defer.returnValue(user_id)
+ return user_id
@defer.inlineCallbacks
def delete_access_token(self, access_token):
@@ -976,7 +976,7 @@ class AuthHandler(BaseHandler):
)
yield self.store.user_delete_threepid(user_id, medium, address)
- defer.returnValue(result)
+ return result
def _save_session(self, session):
# TODO: Persistent storage
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e8f9da6098..ad00dcecfd 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -35,6 +35,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_handlers().identity_handler
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
@@ -102,6 +103,9 @@ class DeactivateAccountHandler(BaseHandler):
yield self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ yield self._profile_handler.set_active(user, False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
yield self.store.add_user_pending_deactivation(user_id)
@@ -118,6 +122,10 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running)
self._start_user_parting()
+ # Reject all pending invites for the user, so that the user doesn't show up in the
+ # "invited" section of rooms' members list.
+ yield self._reject_pending_invites_for_user(user_id)
+
# Remove all information on the user from the account_validity table.
if self._account_validity_enabled:
yield self.store.delete_account_validity_for_user(user_id)
@@ -125,7 +133,40 @@ class DeactivateAccountHandler(BaseHandler):
# Mark the user as deactivated.
yield self.store.set_user_deactivated_status(user_id, True)
- defer.returnValue(identity_server_supports_unbinding)
+ return identity_server_supports_unbinding
+
+ @defer.inlineCallbacks
+ def _reject_pending_invites_for_user(self, user_id):
+ """Reject pending invites addressed to a given user ID.
+
+ Args:
+ user_id (str): The user ID to reject pending invites for.
+ """
+ user = UserID.from_string(user_id)
+ pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
+
+ for room in pending_invites:
+ try:
+ yield self._room_member_handler.update_membership(
+ create_requester(user),
+ user,
+ room.room_id,
+ "leave",
+ ratelimit=False,
+ require_consent=False,
+ )
+ logger.info(
+ "Rejected invite for deactivated user %r in room %r",
+ user_id,
+ room.room_id,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to reject invite for user %r in room %r:"
+ " ignoring and continuing",
+ user_id,
+ room.room_id,
+ )
def _start_user_parting(self):
"""
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 99e8413092..d36dd850fd 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -64,7 +64,7 @@ class DeviceWorkerHandler(BaseHandler):
for device in devices:
_update_device_from_client_ips(device, ips)
- defer.returnValue(devices)
+ return devices
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
- defer.returnValue(device)
+ return device
@measure_func("device.get_user_ids_changed")
@defer.inlineCallbacks
@@ -200,9 +200,7 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = []
possibly_left = []
- defer.returnValue(
- {"changed": list(possibly_joined), "left": list(possibly_left)}
- )
+ return {"changed": list(possibly_joined), "left": list(possibly_left)}
class DeviceHandler(DeviceWorkerHandler):
@@ -250,7 +248,7 @@ class DeviceHandler(DeviceWorkerHandler):
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
- defer.returnValue(device_id)
+ return device_id
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
@@ -264,7 +262,7 @@ class DeviceHandler(DeviceWorkerHandler):
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
- defer.returnValue(device_id)
+ return device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -411,9 +409,7 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- defer.returnValue(
- {"user_id": user_id, "stream_id": stream_id, "devices": devices}
- )
+ return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -556,6 +552,14 @@ class DeviceListEduUpdater(object):
stream_id = result["stream_id"]
devices = result["devices"]
+ for device in devices:
+ logger.debug(
+ "Handling resync update %r/%r, ID: %r",
+ user_id,
+ device["device_id"],
+ stream_id,
+ )
+
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
@@ -623,7 +627,7 @@ class DeviceListEduUpdater(object):
for _, stream_id, prev_ids, _ in updates:
if not prev_ids:
# We always do a resync if there are no previous IDs
- defer.returnValue(True)
+ return True
for prev_id in prev_ids:
if prev_id == extremity:
@@ -633,8 +637,8 @@ class DeviceListEduUpdater(object):
elif prev_id in stream_id_in_updates:
continue
else:
- defer.returnValue(True)
+ return True
stream_id_in_updates.add(stream_id)
- defer.returnValue(False)
+ return False
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 42d5b3db30..0fd423197c 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -210,7 +210,7 @@ class DirectoryHandler(BaseHandler):
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
- defer.returnValue(room_id)
+ return room_id
@defer.inlineCallbacks
def delete_appservice_association(self, service, room_alias):
@@ -229,7 +229,7 @@ class DirectoryHandler(BaseHandler):
room_id = yield self.store.delete_room_alias(room_alias)
- defer.returnValue(room_id)
+ return room_id
@defer.inlineCallbacks
def get_association(self, room_alias):
@@ -277,7 +277,7 @@ class DirectoryHandler(BaseHandler):
else:
servers = list(servers)
- defer.returnValue({"room_id": room_id, "servers": servers})
+ return {"room_id": room_id, "servers": servers}
return
@defer.inlineCallbacks
@@ -289,7 +289,7 @@ class DirectoryHandler(BaseHandler):
result = yield self.get_association_from_room_alias(room_alias)
if result is not None:
- defer.returnValue({"room_id": result.room_id, "servers": result.servers})
+ return {"room_id": result.room_id, "servers": result.servers}
else:
raise SynapseError(
404,
@@ -342,7 +342,7 @@ class DirectoryHandler(BaseHandler):
# Query AS to see if it exists
as_handler = self.appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
- defer.returnValue(result)
+ return result
def can_modify_alias(self, alias, user_id=None):
# Any application service "interested" in an alias they are regexing on
@@ -369,10 +369,10 @@ class DirectoryHandler(BaseHandler):
creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
- defer.returnValue(True)
+ return True
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
- defer.returnValue(is_admin)
+ return is_admin
@defer.inlineCallbacks
def edit_published_room_list(self, requester, room_id, visibility):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index fdfe8611b6..1300b540e3 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -144,7 +144,7 @@ class E2eKeysHandler(object):
)
)
- defer.returnValue({"device_keys": results, "failures": failures})
+ return {"device_keys": results, "failures": failures}
@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -189,7 +189,7 @@ class E2eKeysHandler(object):
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
- defer.returnValue(result_dict)
+ return result_dict
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
@@ -197,7 +197,7 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
- defer.returnValue({"device_keys": res})
+ return {"device_keys": res}
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
@@ -259,7 +259,7 @@ class E2eKeysHandler(object):
),
)
- defer.returnValue({"one_time_keys": json_result, "failures": failures})
+ return {"one_time_keys": json_result, "failures": failures}
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
@@ -297,7 +297,7 @@ class E2eKeysHandler(object):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- defer.returnValue({"one_time_key_counts": result})
+ return {"one_time_key_counts": result}
@defer.inlineCallbacks
def _upload_one_time_keys_for_user(
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index ebd807bca6..41b871fc59 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -84,7 +84,7 @@ class E2eRoomKeysHandler(object):
user_id, version, room_id, session_id
)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
@@ -262,7 +262,7 @@ class E2eRoomKeysHandler(object):
new_version = yield self.store.create_e2e_room_keys_version(
user_id, version_info
)
- defer.returnValue(new_version)
+ return new_version
@defer.inlineCallbacks
def get_version_info(self, user_id, version=None):
@@ -292,7 +292,7 @@ class E2eRoomKeysHandler(object):
raise NotFoundError("Unknown backup version")
else:
raise
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def delete_version(self, user_id, version=None):
@@ -350,4 +350,4 @@ class E2eRoomKeysHandler(object):
user_id, version, version_info
)
- defer.returnValue({})
+ return {}
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 6a38328af3..2f1f10a9af 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -143,7 +143,7 @@ class EventStreamHandler(BaseHandler):
"end": tokens[1].to_string(),
}
- defer.returnValue(chunk)
+ return chunk
class EventHandler(BaseHandler):
@@ -166,7 +166,7 @@ class EventHandler(BaseHandler):
event = yield self.store.get_event(event_id, check_room_id=room_id)
if not event:
- defer.returnValue(None)
+ return None
return
users = yield self.store.get_users_in_room(event.room_id)
@@ -179,4 +179,4 @@ class EventHandler(BaseHandler):
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
- defer.returnValue(event)
+ return event
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 30b69af82c..319ee35d9a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -210,7 +210,7 @@ class FederationHandler(BaseHandler):
event_id,
origin,
)
- defer.returnValue(None)
+ return None
state = None
auth_chain = []
@@ -676,7 +676,7 @@ class FederationHandler(BaseHandler):
events = [e for e in events if e.event_id not in seen_events]
if not events:
- defer.returnValue([])
+ return []
event_map = {e.event_id: e for e in events}
@@ -838,7 +838,7 @@ class FederationHandler(BaseHandler):
# TODO: We can probably do something more clever here.
yield self._handle_new_event(dest, event, backfilled=True)
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
def maybe_backfill(self, room_id, current_depth):
@@ -894,7 +894,7 @@ class FederationHandler(BaseHandler):
)
if not filtered_extremities:
- defer.returnValue(False)
+ return False
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
@@ -965,7 +965,7 @@ class FederationHandler(BaseHandler):
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
- defer.returnValue(True)
+ return True
except SynapseError as e:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
@@ -985,11 +985,11 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to backfill from %s because %s", dom, e)
continue
- defer.returnValue(False)
+ return False
success = yield try_backfill(likely_domains)
if success:
- defer.returnValue(True)
+ return True
# Huh, well *those* domains didn't work out. Lets try some domains
# from the time.
@@ -1031,11 +1031,11 @@ class FederationHandler(BaseHandler):
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
- defer.returnValue(True)
+ return True
tried_domains.update(dom for dom, _ in likely_domains)
- defer.returnValue(False)
+ return False
def _sanity_check_event(self, ev):
"""
@@ -1082,7 +1082,7 @@ class FederationHandler(BaseHandler):
pdu=event,
)
- defer.returnValue(pdu)
+ return pdu
@defer.inlineCallbacks
def on_event_auth(self, event_id):
@@ -1090,7 +1090,7 @@ class FederationHandler(BaseHandler):
auth = yield self.store.get_auth_chain(
[auth_id for auth_id in event.auth_event_ids()], include_given=True
)
- defer.returnValue([e for e in auth])
+ return [e for e in auth]
@log_function
@defer.inlineCallbacks
@@ -1177,7 +1177,7 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
- defer.returnValue(True)
+ return True
@defer.inlineCallbacks
def _handle_queued_pdus(self, room_queue):
@@ -1264,7 +1264,7 @@ class FederationHandler(BaseHandler):
room_version, event, context, do_sig_check=False
)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
@log_function
@@ -1325,7 +1325,7 @@ class FederationHandler(BaseHandler):
state = yield self.store.get_events(list(prev_state_ids.values()))
- defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain})
+ return {"state": list(state.values()), "auth_chain": auth_chain}
@defer.inlineCallbacks
def on_invite_request(self, origin, pdu):
@@ -1345,8 +1345,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = yield self.store.is_room_published(event.room_id)
+
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1381,7 +1388,7 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event)
yield self.persist_events_and_notify([(event, context)])
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
@@ -1406,7 +1413,7 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event)
yield self.persist_events_and_notify([(event, context)])
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def _make_and_verify_event(
@@ -1424,7 +1431,7 @@ class FederationHandler(BaseHandler):
assert event.user_id == user_id
assert event.state_key == user_id
assert event.room_id == room_id
- defer.returnValue((origin, event, format_ver))
+ return (origin, event, format_ver)
@defer.inlineCallbacks
@log_function
@@ -1484,7 +1491,7 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
@log_function
@@ -1517,7 +1524,7 @@ class FederationHandler(BaseHandler):
event.signatures,
)
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
@@ -1545,9 +1552,9 @@ class FederationHandler(BaseHandler):
del results[(event.type, event.state_key)]
res = list(results.values())
- defer.returnValue(res)
+ return res
else:
- defer.returnValue([])
+ return []
@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
@@ -1572,9 +1579,9 @@ class FederationHandler(BaseHandler):
else:
results.pop((event.type, event.state_key), None)
- defer.returnValue(list(results.values()))
+ return list(results.values())
else:
- defer.returnValue([])
+ return []
@defer.inlineCallbacks
@log_function
@@ -1587,7 +1594,7 @@ class FederationHandler(BaseHandler):
events = yield filter_events_for_server(self.store, origin, events)
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
@log_function
@@ -1617,9 +1624,9 @@ class FederationHandler(BaseHandler):
events = yield filter_events_for_server(self.store, origin, [event])
event = events[0]
- defer.returnValue(event)
+ return event
else:
- defer.returnValue(None)
+ return None
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
@@ -1651,7 +1658,7 @@ class FederationHandler(BaseHandler):
self.store.remove_push_actions_from_staging, event.event_id
)
- defer.returnValue(context)
+ return context
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False):
@@ -1674,7 +1681,7 @@ class FederationHandler(BaseHandler):
auth_events=ev_info.get("auth_events"),
backfilled=backfilled,
)
- defer.returnValue(res)
+ return res
contexts = yield make_deferred_yieldable(
defer.gatherResults(
@@ -1833,7 +1840,7 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
yield self.maybe_kick_guest_users(event)
- defer.returnValue(context)
+ return context
@defer.inlineCallbacks
def _check_for_soft_fail(self, event, state, backfilled):
@@ -1952,7 +1959,7 @@ class FederationHandler(BaseHandler):
logger.debug("on_query_auth returning: %s", ret)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def on_get_missing_events(
@@ -1975,7 +1982,7 @@ class FederationHandler(BaseHandler):
self.store, origin, missing_events
)
- defer.returnValue(missing_events)
+ return missing_events
@defer.inlineCallbacks
@log_function
@@ -2451,16 +2458,14 @@ class FederationHandler(BaseHandler):
logger.debug("construct_auth_difference returning")
- defer.returnValue(
- {
- "auth_chain": local_auth,
- "rejects": {
- e.event_id: {"reason": reason_map[e.event_id], "proof": None}
- for e in base_remote_rejected
- },
- "missing": [e.event_id for e in missing_locals],
- }
- )
+ return {
+ "auth_chain": local_auth,
+ "rejects": {
+ e.event_id: {"reason": reason_map[e.event_id], "proof": None}
+ for e in base_remote_rejected
+ },
+ "missing": [e.event_id for e in missing_locals],
+ }
@defer.inlineCallbacks
@log_function
@@ -2505,7 +2510,7 @@ class FederationHandler(BaseHandler):
room_version, event_dict, event, context
)
- EventValidator().validate_new(event)
+ EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
@@ -2563,7 +2568,7 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check_from_context(room_version, event, context)
+ yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
@@ -2592,7 +2597,12 @@ class FederationHandler(BaseHandler):
original_invite_id, allow_none=True
)
if original_invite:
- display_name = original_invite.content["display_name"]
+ # If the m.room.third_party_invite event's content is empty, it means the
+ # invite has been revoked. In this case, we don't have to raise an error here
+ # because the auth check will fail on the invite (because it's not able to
+ # fetch public keys from the m.room.third_party_invite event's content, which
+ # is empty).
+ display_name = original_invite.content.get("display_name")
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info(
@@ -2607,8 +2617,8 @@ class FederationHandler(BaseHandler):
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
- EventValidator().validate_new(event)
- defer.returnValue((event, context))
+ EventValidator().validate_new(event, self.config)
+ return (event, context)
@defer.inlineCallbacks
def _check_signature(self, event, context):
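The `on_invite_request` hunk above widens `user_may_invite` to take a third-party invite slot (passed as `None` here) plus `room_id`, `new_room` and `published_room` keyword arguments. A hedged sketch of a spam-checker class matching that call shape; the class name and config key are invented for illustration:

```python
class ExampleSpamChecker(object):
    """Illustrative module matching the extended call site in federation.py."""

    def __init__(self, config):
        # Hypothetical policy knob, not part of this patch.
        self._block_published_invites = config.get("block_published_invites", False)

    def user_may_invite(
        self,
        inviter_userid,
        invitee_userid,
        third_party_invite,  # None for ordinary invites, as passed above
        room_id,
        new_room,
        published_room,
    ):
        # Optionally refuse invites into rooms that appear in the public
        # room directory; allow everything else.
        if published_room and self._block_published_invites:
            return False
        return True
```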
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 7da63bb643..7b67c8ae0f 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -162,7 +162,7 @@ class GroupsLocalHandler(object):
res.setdefault("user", {})["is_publicised"] = is_publicised
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content):
@@ -207,7 +207,7 @@ class GroupsLocalHandler(object):
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_users_in_group(self, group_id, requester_user_id):
@@ -217,7 +217,7 @@ class GroupsLocalHandler(object):
res = yield self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
- defer.returnValue(res)
+ return res
group_server_name = get_domain_from_id(group_id)
@@ -244,7 +244,7 @@ class GroupsLocalHandler(object):
res["chunk"] = valid_entries
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def join_group(self, group_id, user_id, content):
@@ -285,7 +285,7 @@ class GroupsLocalHandler(object):
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def accept_invite(self, group_id, user_id, content):
@@ -326,7 +326,7 @@ class GroupsLocalHandler(object):
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def invite(self, group_id, user_id, requester_user_id, config):
@@ -346,7 +346,7 @@ class GroupsLocalHandler(object):
content,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def on_invite(self, group_id, user_id, content):
@@ -377,7 +377,7 @@ class GroupsLocalHandler(object):
logger.warn("No profile for user %s: %s", user_id, e)
user_profile = {}
- defer.returnValue({"state": "invite", "user_profile": user_profile})
+ return {"state": "invite", "user_profile": user_profile}
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
@@ -406,7 +406,7 @@ class GroupsLocalHandler(object):
content,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def user_removed_from_group(self, group_id, user_id, content):
@@ -421,7 +421,7 @@ class GroupsLocalHandler(object):
@defer.inlineCallbacks
def get_joined_groups(self, user_id):
group_ids = yield self.store.get_joined_groups(user_id)
- defer.returnValue({"groups": group_ids})
+ return {"groups": group_ids}
@defer.inlineCallbacks
def get_publicised_groups_for_user(self, user_id):
@@ -433,14 +433,14 @@ class GroupsLocalHandler(object):
for app_service in self.store.get_app_services():
result.extend(app_service.get_groups_for_user(user_id))
- defer.returnValue({"groups": result})
+ return {"groups": result}
else:
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id]
)
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
- defer.returnValue({"groups": result})
+ return {"groups": result}
@defer.inlineCallbacks
def bulk_get_publicised_groups(self, user_ids, proxy=True):
@@ -475,4 +475,4 @@ class GroupsLocalHandler(object):
for app_service in self.store.get_app_services():
results[uid].extend(app_service.get_groups_for_user(uid))
- defer.returnValue({"users": results})
+ return {"users": results}
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 546d6169e9..339e0dd04d 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,13 +20,18 @@
import logging
from canonicaljson import json
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json
+from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse.api.errors import (
+ AuthError,
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
@@ -46,6 +51,8 @@ class IdentityHandler(BaseHandler):
self.trust_any_id_server_just_for_testing_do_not_use = (
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
def _should_trust_id_server(self, id_server):
if id_server not in self.trusted_id_servers:
@@ -82,8 +89,11 @@ class IdentityHandler(BaseHandler):
"%s is not a trusted ID server: rejecting 3pid " + "credentials",
id_server,
)
- defer.returnValue(None)
-
+ return None
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
try:
data = yield self.http_client.get_json(
"https://%s%s"
@@ -95,8 +105,8 @@ class IdentityHandler(BaseHandler):
raise e.to_synapse_error()
if "medium" in data:
- defer.returnValue(data)
- defer.returnValue(None)
+ return data
+ return None
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
@@ -117,9 +127,17 @@ class IdentityHandler(BaseHandler):
else:
raise SynapseError(400, "No client_secret in creds")
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ if id_server in self.rewrite_identity_server_urls:
+ id_server_host = self.rewrite_identity_server_urls[id_server]
+ else:
+ id_server_host = id_server
+
try:
data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"),
+ "https://%s%s" % (id_server_host, "/_matrix/identity/api/v1/3pid/bind"),
{"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid},
)
logger.debug("bound threepid %r to %s", creds, mxid)
@@ -133,7 +151,7 @@ class IdentityHandler(BaseHandler):
)
except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT?
- defer.returnValue(data)
+ return data
@defer.inlineCallbacks
def try_unbind_threepid(self, mxid, threepid):
@@ -161,7 +179,7 @@ class IdentityHandler(BaseHandler):
# We don't know where to unbind, so we don't have a choice but to return
if not id_servers:
- defer.returnValue(False)
+ return False
changed = True
for id_server in id_servers:
@@ -169,7 +187,7 @@ class IdentityHandler(BaseHandler):
mxid, threepid, id_server
)
- defer.returnValue(changed)
+ return changed
@defer.inlineCallbacks
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
@@ -205,6 +223,16 @@ class IdentityHandler(BaseHandler):
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
+
+ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+
try:
yield self.http_client.post_json_get_json(url, content, headers)
changed = True
@@ -224,7 +252,7 @@ class IdentityHandler(BaseHandler):
id_server=id_server,
)
- defer.returnValue(changed)
+ return changed
@defer.inlineCallbacks
def requestEmailToken(
@@ -241,6 +269,11 @@ class IdentityHandler(BaseHandler):
"send_attempt": send_attempt,
}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
+
if next_link:
params.update({"next_link": next_link})
@@ -250,7 +283,7 @@ class IdentityHandler(BaseHandler):
% (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"),
params,
)
- defer.returnValue(data)
+ return data
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
@@ -271,14 +304,134 @@ class IdentityHandler(BaseHandler):
"send_attempt": send_attempt,
}
params.update(kwargs)
-
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
try:
data = yield self.http_client.post_json_get_json(
"https://%s%s"
% (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"),
params,
)
- defer.returnValue(data)
+ return data
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
+
+ @defer.inlineCallbacks
+ def lookup_3pid(self, id_server, medium, address):
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
+ )
+
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ target = self.rewrite_identity_server_urls.get(id_server, id_server)
+
+ try:
+ data = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/lookup" % (target,),
+ {"medium": medium, "address": address},
+ )
+
+ if "mxid" in data:
+ if "signatures" not in data:
+ raise AuthError(401, "No signatures on 3pid binding")
+ yield self._verify_any_signature(data, id_server)
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %r: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
+ @defer.inlineCallbacks
+ def bulk_lookup_3pid(self, id_server, threepids):
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ threepids ([[str, str]]): The third party identifiers to look up, as
+ a list of two-element lists of strings ([medium, address]).
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
+ )
+
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ target = self.rewrite_identity_server_urls.get(id_server, id_server)
+
+ try:
+ data = yield self.http_client.post_json_get_json(
+ "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %r: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
+ @defer.inlineCallbacks
+ def _verify_any_signature(self, data, server_hostname):
+ if server_hostname not in data["signatures"]:
+ raise AuthError(401, "No signature from server %s" % (server_hostname,))
+
+ for key_name, signature in data["signatures"][server_hostname].items():
+ target = self.rewrite_identity_server_urls.get(
+ server_hostname, server_hostname
+ )
+
+ key_data = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/pubkey/%s" % (target, key_name)
+ )
+ if "public_key" not in key_data:
+ raise AuthError(
+ 401, "No public key named %s from %s" % (key_name, server_hostname)
+ )
+ verify_signed_json(
+ data,
+ server_hostname,
+ decode_verify_key_bytes(
+ key_name, decode_base64(key_data["public_key"])
+ ),
+ )
+ return
+
+ raise AuthError(401, "No signature from server %s" % (server_hostname,))
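The repeated `rewrite_identity_server_urls` lookups above all follow one rule: trust decisions and signature checks use the configured identity server name, while the rewritten value is only used as the host to connect to. A standalone sketch of that split, with an assumed mapping:

```python
# Assumed shape of hs.config.rewrite_identity_server_urls.
REWRITE_RULES = {"matrix.org": "id.internal.example.com"}

def resolve_id_server(id_server, rewrite_rules=REWRITE_RULES):
    """Return (logical_name, connect_host) for an identity server."""
    return id_server, rewrite_rules.get(id_server, id_server)

logical, target = resolve_id_server("matrix.org")
lookup_url = "https://%s/_matrix/identity/api/v1/lookup" % (target,)
# Response signatures are still checked against `logical`, exactly as
# _verify_any_signature does above.
```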
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 54c966c8a6..42d6650ed9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -250,7 +250,7 @@ class InitialSyncHandler(BaseHandler):
"end": now_token.to_string(),
}
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
@@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler):
result["account_data"] = account_data_events
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _room_initial_sync_parted(
@@ -330,28 +330,24 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
- defer.returnValue(
- {
- "membership": membership,
- "room_id": room_id,
- "messages": {
- "chunk": (
- yield self._event_serializer.serialize_events(
- messages, time_now
- )
- ),
- "start": start_token.to_string(),
- "end": end_token.to_string(),
- },
- "state": (
- yield self._event_serializer.serialize_events(
- room_state.values(), time_now
- )
+ return {
+ "membership": membership,
+ "room_id": room_id,
+ "messages": {
+ "chunk": (
+ yield self._event_serializer.serialize_events(messages, time_now)
),
- "presence": [],
- "receipts": [],
- }
- )
+ "start": start_token.to_string(),
+ "end": end_token.to_string(),
+ },
+ "state": (
+ yield self._event_serializer.serialize_events(
+ room_state.values(), time_now
+ )
+ ),
+ "presence": [],
+ "receipts": [],
+ }
@defer.inlineCallbacks
def _room_initial_sync_joined(
@@ -384,13 +380,13 @@ class InitialSyncHandler(BaseHandler):
def get_presence():
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
- defer.returnValue([])
+ return []
states = yield presence_handler.get_states(
[m.user_id for m in room_members], as_event=True
)
- defer.returnValue(states)
+ return states
@defer.inlineCallbacks
def get_receipts():
@@ -399,7 +395,7 @@ class InitialSyncHandler(BaseHandler):
)
if not receipts:
receipts = []
- defer.returnValue(receipts)
+ return receipts
presence, receipts, (messages, token) = yield make_deferred_yieldable(
defer.gatherResults(
@@ -442,7 +438,7 @@ class InitialSyncHandler(BaseHandler):
if not is_peeking:
ret["membership"] = membership
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id):
@@ -453,7 +449,7 @@ class InitialSyncHandler(BaseHandler):
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
+ return (member_event.membership, member_event.event_id)
return
except AuthError:
visibility = yield self.state_handler.get_current_state(
@@ -463,7 +459,7 @@ class InitialSyncHandler(BaseHandler):
visibility
and visibility.content["history_visibility"] == "world_readable"
):
- defer.returnValue((Membership.JOIN, None))
+ return (Membership.JOIN, None)
return
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 6d7a987f13..2e1a989782 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -87,7 +87,7 @@ class MessageHandler(object):
)
data = room_state[membership_event_id].get(key)
- defer.returnValue(data)
+ return data
@defer.inlineCallbacks
def get_state_events(
@@ -135,7 +135,7 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.store, user_id, last_events
+ self.store, user_id, last_events, apply_retention_policies=False
)
event = last_events[0]
@@ -174,7 +174,7 @@ class MessageHandler(object):
# events, as clients won't use them.
bundle_aggregations=False,
)
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):
@@ -213,15 +213,13 @@ class MessageHandler(object):
# Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room")
- defer.returnValue(
- {
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
- }
- for user_id, profile in iteritems(users_with_profile)
+ return {
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
}
- )
+ for user_id, profile in iteritems(users_with_profile)
+ }
class EventCreationHandler(object):
@@ -380,7 +378,11 @@ class EventCreationHandler(object):
# tolerate them in event_auth.check().
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ prev_event = (
+ yield self.store.get_event(prev_event_id, allow_none=True)
+ if prev_event_id
+ else None
+ )
if not prev_event or prev_event.membership != Membership.JOIN:
logger.warning(
(
@@ -396,9 +398,9 @@ class EventCreationHandler(object):
403, "You must be in the room to create an alias for it"
)
- self.validator.validate_new(event)
+ self.validator.validate_new(event, self.config)
- defer.returnValue((event, context))
+ return (event, context)
def _is_exempt_from_privacy_policy(self, builder, requester):
""""Determine if an event to be sent is exempt from having to consent
@@ -425,9 +427,9 @@ class EventCreationHandler(object):
@defer.inlineCallbacks
def _is_server_notices_room(self, room_id):
if self.config.server_notices_mxid is None:
- defer.returnValue(False)
+ return False
user_ids = yield self.store.get_users_in_room(room_id)
- defer.returnValue(self.config.server_notices_mxid in user_ids)
+ return self.config.server_notices_mxid in user_ids
@defer.inlineCallbacks
def assert_accepted_privacy_policy(self, requester):
@@ -507,7 +509,7 @@ class EventCreationHandler(object):
event.event_id,
prev_state.event_id,
)
- defer.returnValue(prev_state)
+ return prev_state
yield self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
@@ -523,6 +525,8 @@ class EventCreationHandler(object):
"""
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_event_id = prev_state_ids.get((event.type, event.state_key))
+ if not prev_event_id:
+ return
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@@ -531,7 +535,7 @@ class EventCreationHandler(object):
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
- defer.returnValue(prev_event)
+ return prev_event
return
@defer.inlineCallbacks
@@ -563,7 +567,7 @@ class EventCreationHandler(object):
yield self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
- defer.returnValue(event)
+ return event
@measure_func("create_new_client_event")
@defer.inlineCallbacks
@@ -608,7 +612,7 @@ class EventCreationHandler(object):
if requester:
context.app_service = requester.app_service
- self.validator.validate_new(event)
+ self.validator.validate_new(event, self.config)
# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
@@ -626,7 +630,7 @@ class EventCreationHandler(object):
logger.debug("Created event %s", event.event_id)
- defer.returnValue((event, context))
+ return (event, context)
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 20bcfed334..6711ced51a 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,12 +15,15 @@
# limitations under the License.
import logging
+from six import iteritems
+
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
@@ -77,6 +80,111 @@ class PaginationHandler(object):
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
+ self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+
+ if hs.config.retention_enabled:
+ # Run the purge jobs described in the configuration file.
+ for job in hs.config.retention_purge_jobs:
+ self.clock.looping_call(
+ run_as_background_process,
+ job["interval"],
+ "purge_history_for_rooms_in_range",
+ self.purge_history_for_rooms_in_range,
+ job["shortest_max_lifetime"],
+ job["longest_max_lifetime"],
+ )
+
+ @defer.inlineCallbacks
+ def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+ """Purge outdated events from rooms within the given retention range.
+
+ If a default retention policy is defined in the server's configuration and its
+ 'max_lifetime' is within this range, also targets rooms which don't have a
+ retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that defines the lower limit of
+ the range to handle (exclusive). If None, the range has no lower limit.
+ max_ms (int|None): Duration in milliseconds that defines the upper limit of
+ the range to handle (inclusive). If None, the range has no upper limit.
+ """
+ # We want the storage layer to include rooms with no retention policy in its
+ # return value only if a default retention policy is defined in the server's
+ # configuration and that policy's 'max_lifetime' is within this range, i.e.
+ # higher than min_ms and lower than (or equal to) max_ms.
+ if self._retention_default_max_lifetime is not None:
+ include_null = True
+
+ if min_ms is not None and min_ms >= self._retention_default_max_lifetime:
+ # The default max_lifetime is lower than (or equal to) min_ms.
+ include_null = False
+
+ if max_ms is not None and max_ms < self._retention_default_max_lifetime:
+ # The default max_lifetime is higher than max_ms.
+ include_null = False
+ else:
+ include_null = False
+
+ rooms = yield self.store.get_rooms_for_retention_period_in_range(
+ min_ms, max_ms, include_null
+ )
+
+ for room_id, retention_policy in iteritems(rooms):
+ if room_id in self._purges_in_progress_by_room:
+ logger.warning(
+ "[purge] not purging room %s as there's an ongoing purge running"
+ " for this room",
+ room_id,
+ )
+ continue
+
+ max_lifetime = retention_policy["max_lifetime"]
+
+ if max_lifetime is None:
+ # If max_lifetime is None, it means that include_null equals True,
+ # therefore we can safely assume that there is a default policy defined
+ # in the server's configuration.
+ max_lifetime = self._retention_default_max_lifetime
+
+ # Figure out what token we should start purging at.
+ ts = self.clock.time_msec() - max_lifetime
+
+ stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts))
+
+ r = (
+ yield self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering
+ )
+ )
+ if not r:
+ logger.warning(
+ "[purge] purging events not possible: No event found "
+ "(ts %i => stream_ordering %i)",
+ ts,
+ stream_ordering,
+ )
+ continue
+
+ (stream, topo, _event_id) = r
+ token = "t%d-%d" % (topo, stream)
+
+ purge_id = random_string(16)
+
+ self._purges_by_id[purge_id] = PurgeStatus()
+
+ logger.info(
+ "Starting purging events in room %s (purge_id %s)", room_id, purge_id
+ )
+
+ # We want to purge everything, including local events, and to run the purge in
+ # the background so that it doesn't block any other operation apart from
+ # other purges in the same room.
+ run_as_background_process(
+ "_purge_history", self._purge_history, purge_id, room_id, token, True
+ )
+
def start_purge_history(self, room_id, token, delete_local_events=False):
"""Start off a history purge on a room.
@@ -242,13 +350,11 @@ class PaginationHandler(object):
)
if not events:
- defer.returnValue(
- {
- "chunk": [],
- "start": pagin_config.from_token.to_string(),
- "end": next_token.to_string(),
- }
- )
+ return {
+ "chunk": [],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ }
state = None
if event_filter and event_filter.lazy_load_members() and len(events) > 0:
@@ -286,4 +392,4 @@ class PaginationHandler(object):
)
)
- defer.returnValue(chunk)
+ return chunk
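The `include_null` computation in `purge_history_for_rooms_in_range` reduces to a pure predicate: rooms with no retention policy are only swept when the server's default `max_lifetime` lies in the job's `(min_ms, max_ms]` window. A sketch of just that predicate (the function name is illustrative):

```python
def should_include_rooms_without_policy(default_max_lifetime, min_ms, max_ms):
    """True iff a server default exists and falls in (min_ms, max_ms]."""
    if default_max_lifetime is None:
        return False
    if min_ms is not None and min_ms >= default_max_lifetime:
        return False  # default lifetime at or below the exclusive lower bound
    if max_ms is not None and max_ms < default_max_lifetime:
        return False  # default lifetime above the inclusive upper bound
    return True

# A job covering (None, 1 day] with a 12-hour server default matches:
assert should_include_rooms_without_policy(12 * 60 * 60 * 1000, None, 86400000)
```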
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
new file mode 100644
index 0000000000..d06b110269
--- /dev/null
+++ b/synapse/handlers/password_policy.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from synapse.api.errors import Codes, PasswordRefusedError
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyHandler(object):
+ def __init__(self, hs):
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ # Regexps for the spec'd policy parameters.
+ self.regexp_digit = re.compile("[0-9]")
+ self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
+ self.regexp_uppercase = re.compile("[A-Z]")
+ self.regexp_lowercase = re.compile("[a-z]")
+
+ def validate_password(self, password):
+ """Checks whether a given password complies with the server's policy.
+
+ Args:
+ password (str): The password to check against the server's policy.
+
+ Raises:
+ PasswordRefusedError: The password doesn't comply with the server's policy.
+ """
+
+ if not self.enabled:
+ return
+
+ minimum_accepted_length = self.policy.get("minimum_length", 0)
+ if len(password) < minimum_accepted_length:
+ raise PasswordRefusedError(
+ msg=(
+ "The password must be at least %d characters long"
+ % minimum_accepted_length
+ ),
+ errcode=Codes.PASSWORD_TOO_SHORT,
+ )
+
+ if (
+ self.policy.get("require_digit", False)
+ and self.regexp_digit.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one digit",
+ errcode=Codes.PASSWORD_NO_DIGIT,
+ )
+
+ if (
+ self.policy.get("require_symbol", False)
+ and self.regexp_symbol.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one symbol",
+ errcode=Codes.PASSWORD_NO_SYMBOL,
+ )
+
+ if (
+ self.policy.get("require_uppercase", False)
+ and self.regexp_uppercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one uppercase letter",
+ errcode=Codes.PASSWORD_NO_UPPERCASE,
+ )
+
+ if (
+ self.policy.get("require_lowercase", False)
+ and self.regexp_lowercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one lowercase letter",
+ errcode=Codes.PASSWORD_NO_LOWERCASE,
+ )
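A hedged usage sketch of the new handler: the stub `hs` below only mirrors the two config attributes the constructor reads (`password_policy` and `password_policy_enabled`) and is not a real HomeServer; `PasswordRefusedError` is added to `synapse.api.errors` alongside this handler.

```python
from synapse.api.errors import PasswordRefusedError
from synapse.handlers.password_policy import PasswordPolicyHandler

class _StubConfig(object):
    password_policy = {
        "minimum_length": 10,
        "require_digit": True,
        "require_symbol": True,
        "require_uppercase": True,
        "require_lowercase": True,
    }
    password_policy_enabled = True

class _StubHS(object):
    config = _StubConfig()

handler = PasswordPolicyHandler(_StubHS())
try:
    handler.validate_password("weak")
except PasswordRefusedError as e:
    print("rejected: %s" % e)  # trips the minimum_length check first

handler.validate_password("Str0ng&Secret")  # satisfies every enabled check
```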
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6f3537e435..ea54d0b991 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -461,7 +461,7 @@ class PresenceHandler(object):
if affect_presence:
run_in_background(_end)
- defer.returnValue(_user_syncing())
+ return _user_syncing()
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
@@ -556,7 +556,7 @@ class PresenceHandler(object):
"""Get the current presence state for a user.
"""
res = yield self.current_state_for_users([user_id])
- defer.returnValue(res[user_id])
+ return res[user_id]
@defer.inlineCallbacks
def current_state_for_users(self, user_ids):
@@ -585,7 +585,7 @@ class PresenceHandler(object):
states.update(new)
self.user_to_current_state.update(new)
- defer.returnValue(states)
+ return states
@defer.inlineCallbacks
def _persist_and_notify(self, states):
@@ -681,7 +681,7 @@ class PresenceHandler(object):
def get_state(self, target_user, as_event=False):
results = yield self.get_states([target_user.to_string()], as_event=as_event)
- defer.returnValue(results[0])
+ return results[0]
@defer.inlineCallbacks
def get_states(self, target_user_ids, as_event=False):
@@ -703,17 +703,15 @@ class PresenceHandler(object):
now = self.clock.time_msec()
if as_event:
- defer.returnValue(
- [
- {
- "type": "m.presence",
- "content": format_user_presence_state(state, now),
- }
- for state in updates
- ]
- )
+ return [
+ {
+ "type": "m.presence",
+ "content": format_user_presence_state(state, now),
+ }
+ for state in updates
+ ]
else:
- defer.returnValue(updates)
+ return updates
@defer.inlineCallbacks
def set_state(self, target_user, state, ignore_status_msg=False):
@@ -757,9 +755,9 @@ class PresenceHandler(object):
)
if observer_room_ids & observed_room_ids:
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def get_all_presence_updates(self, last_id, current_id):
@@ -778,7 +776,7 @@ class PresenceHandler(object):
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
rows = yield self.store.get_all_presence_updates(last_id, current_id)
- defer.returnValue(rows)
+ return rows
def notify_new_event(self):
"""Called when new events have happened. Handles users and servers
@@ -1034,7 +1032,7 @@ class PresenceEventSource(object):
#
# Hence this guard where we just return nothing so that the sync
# doesn't return. C.f. #5503.
- defer.returnValue(([], max_token))
+ return ([], max_token)
presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache
@@ -1068,17 +1066,11 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed)
if include_offline:
- defer.returnValue((list(updates.values()), max_token))
+ return (list(updates.values()), max_token)
else:
- defer.returnValue(
- (
- [
- s
- for s in itervalues(updates)
- if s.state != PresenceState.OFFLINE
- ],
- max_token,
- )
+ return (
+ [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE],
+ max_token,
)
def get_current_key(self):
@@ -1107,7 +1099,7 @@ class PresenceEventSource(object):
)
users_interested_in.update(user_ids)
- defer.returnValue(users_interested_in)
+ return users_interested_in
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
@@ -1287,7 +1279,7 @@ def get_interested_parties(store, states):
# Always notify self
users_to_states.setdefault(state.user_id, []).append(state)
- defer.returnValue((room_ids_to_states, users_to_states))
+ return (room_ids_to_states, users_to_states)
@defer.inlineCallbacks
@@ -1321,4 +1313,4 @@ def get_interested_remotes(store, states, state_handler):
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
- defer.returnValue(hosts_and_states)
+ return hosts_and_states
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index a2388a7091..136128b625 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +17,11 @@
import logging
from six import raise_from
+from six.moves import range
-from twisted.internet import defer
+from signedjson.sign import sign_json
+
+from twisted.internet import defer, reactor
from synapse.api.errors import (
AuthError,
@@ -27,6 +31,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, get_domain_from_id
@@ -46,6 +51,8 @@ class BaseProfileHandler(BaseHandler):
subclass MasterProfileHandler
"""
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs):
super(BaseProfileHandler, self).__init__(hs)
@@ -56,6 +63,87 @@ class BaseProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._assign_profile_replication_batches)
+ reactor.callWhenRunning(self._replicate_profiles)
+ # Add a looping call to replicate_profiles: this handles retries
+ # if the replication is unsuccessful when the user updated their
+ # profile.
+ self.clock.looping_call(
+ self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ @defer.inlineCallbacks
+ def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = yield self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ @defer.inlineCallbacks
+ def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = yield self.store.get_replication_hosts()
+ latest_batch = yield self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ yield self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ @defer.inlineCallbacks
+ def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = yield self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ yield self.http_client.post_json_get_json(url, signed_body)
+ yield self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
@@ -73,7 +161,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({"displayname": displayname, "avatar_url": avatar_url})
+ return {"displayname": displayname, "avatar_url": avatar_url}
else:
try:
result = yield self.federation.make_query(
@@ -82,7 +170,7 @@ class BaseProfileHandler(BaseHandler):
args={"user_id": user_id},
ignore_backoff=True,
)
- defer.returnValue(result)
+ return result
except RequestSendFailed as e:
raise_from(SynapseError(502, "Failed to fetch profile"), e)
except HttpResponseException as e:
@@ -108,10 +196,10 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({"displayname": displayname, "avatar_url": avatar_url})
+ return {"displayname": displayname, "avatar_url": avatar_url}
else:
profile = yield self.store.get_from_remote_profile_cache(user_id)
- defer.returnValue(profile or {})
+ return profile or {}
@defer.inlineCallbacks
def get_displayname(self, target_user):
@@ -125,7 +213,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(displayname)
+ return displayname
else:
try:
result = yield self.federation.make_query(
@@ -139,7 +227,7 @@ class BaseProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- defer.returnValue(result["displayname"])
+ return result["displayname"]
@defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
@@ -154,9 +242,16 @@ class BaseProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
- if not by_admin and target_user != requester.user:
+ if not by_admin and requester and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
+ if not by_admin and self.hs.config.disable_set_displayname:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ if profile.display_name:
+ raise SynapseError(
+ 400, "Changing displayname is disabled on this server"
+ )
+
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -165,7 +260,17 @@ class BaseProfileHandler(BaseHandler):
if new_displayname == "":
new_displayname = None
- yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ yield self.store.set_profile_displayname(
+ target_user.localpart, new_displayname, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -173,7 +278,39 @@ class BaseProfileHandler(BaseHandler):
target_user.to_string(), profile
)
- yield self._update_join_states(requester, target_user)
+ if requester:
+ yield self._update_join_states(requester, target_user)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ @defer.inlineCallbacks
+ def set_active(self, target_user, active, hide):
+ """
+ Sets the 'active' flag on a user profile. If set to false, the user
+ account is considered deactivated or hidden.
+
+ If 'hide' is true, then we interpret active=False as a request to try to
+ hide the user rather than deactivating it. This means withholding the
+ profile from replication (and mark it as inactive) rather than clearing
+ the profile from the HS DB. Note that unlike set_displayname and
+ set_avatar_url, this does *not* perform authorization checks! This is
+ because the only place it's used currently is in account deactivation
+ where we've already done these checks anyway.
+ """
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+ yield self.store.set_profile_active(
+ target_user.localpart, active, hide, new_batchnum
+ )
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -186,7 +323,7 @@ class BaseProfileHandler(BaseHandler):
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(avatar_url)
+ return avatar_url
else:
try:
result = yield self.federation.make_query(
@@ -200,7 +337,7 @@ class BaseProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- defer.returnValue(result["avatar_url"])
+ return result["avatar_url"]
@defer.inlineCallbacks
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
@@ -212,12 +349,59 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
+ if not by_admin and self.hs.config.disable_set_avatar_url:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ if profile.avatar_url:
+ raise SynapseError(
+ 400, "Changing avatar url is disabled on this server"
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
- yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ # Enforce a max avatar size if one is defined
+ if self.max_avatar_size or self.allowed_avatar_mimetypes:
+ media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url)
+
+ # Check that this media exists locally
+ media_info = yield self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
+ yield self.store.set_profile_avatar_url(
+ target_user.localpart, new_avatar_url, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -227,6 +411,23 @@ class BaseProfileHandler(BaseHandler):
yield self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
@defer.inlineCallbacks
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
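`_validate_and_parse_media_id_from_avatar_url` leans on the fixed shape of Matrix content URIs, `mxc://<server>/<media_id>`, which splits on `/` into exactly four pieces. A quick illustration (URL values invented):

```python
pieces = "mxc://matrix.example.com/abcdef123".split("/")
assert pieces == ["mxc:", "", "matrix.example.com", "abcdef123"]
media_id = pieces[-1]  # the local media id the size/mimetype checks look up
```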
@@ -251,7 +452,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
def _update_join_states(self, requester, target_user):
@@ -282,7 +483,7 @@ class BaseProfileHandler(BaseHandler):
@defer.inlineCallbacks
def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the
- 'require_auth_for_profile_requests' config flag is set to True and a
+ 'limit_profile_requests_to_known_users' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.
@@ -300,7 +501,11 @@ class BaseProfileHandler(BaseHandler):
# be None when this function is called outside of a profile query, e.g.
# when building a membership event. In this case, we must allow the
# lookup.
- if not self.hs.config.require_auth_for_profile_requests or not requester:
+ if not self.hs.config.limit_profile_requests_to_known_users or not requester:
+ return
+
+ # Always allow the user to query their own profile.
+ if target_user.to_string() == requester.to_string():
return
# Always allow the user to query their own profile.
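`_replicate_host_profile_batch` sends each batch as signed JSON so the receiving identity server can authenticate the origin homeserver. A minimal round-trip sketch with `signedjson`; the freshly generated key stands in for Synapse's configured signing key:

```python
from signedjson.key import generate_signing_key, get_verify_key
from signedjson.sign import sign_json, verify_signed_json

signing_key = generate_signing_key("a_key_id")  # illustrative key

body = {
    "batchnum": 7,
    "batch": {"@alice:example.com": {"display_name": "Alice", "avatar_url": None}},
    "origin_server": "example.com",
}
signed_body = sign_json(body, "example.com", signing_key)

# The receiver verifies against the origin's published verify key;
# verify_signed_json raises if the signature does not match.
verify_signed_json(signed_body, "example.com", get_verify_key(signing_key))
```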
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e58bf7e360..73973502a4 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -93,7 +93,7 @@ class ReceiptsHandler(BaseHandler):
if min_batch_id is None:
# no new receipts
- defer.returnValue(False)
+ return False
affected_room_ids = list(set([r.room_id for r in receipts]))
@@ -103,7 +103,7 @@ class ReceiptsHandler(BaseHandler):
min_batch_id, max_batch_id, affected_room_ids
)
- defer.returnValue(True)
+ return True
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
@@ -133,9 +133,9 @@ class ReceiptsHandler(BaseHandler):
)
if not result:
- defer.returnValue([])
+ return []
- defer.returnValue(result)
+ return result
class ReceiptEventSource(object):
@@ -148,13 +148,13 @@ class ReceiptEventSource(object):
to_key = yield self.get_current_key()
if from_key == to_key:
- defer.returnValue(([], to_key))
+ return ([], to_key)
events = yield self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
- defer.returnValue((events, to_key))
+ return (events, to_key)
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
@@ -173,4 +173,4 @@ class ReceiptEventSource(object):
room_ids, from_key=from_key, to_key=to_key
)
- defer.returnValue((events, to_key))
+ return (events, to_key)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index bb7cfd71b9..0daf193945 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -60,6 +60,7 @@ class RegistrationHandler(BaseHandler):
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
@@ -72,6 +73,8 @@ class RegistrationHandler(BaseHandler):
)
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -213,6 +216,11 @@ class RegistrationHandler(BaseHandler):
address=address,
)
+ if default_display_name:
+ yield self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart)
yield self.user_directory_handler.handle_local_profile_change(
@@ -238,6 +246,11 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=default_display_name,
address=address,
)
+
+ yield self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
except SynapseError:
# if user id is taken, just generate another
user = None
@@ -265,7 +278,15 @@ class RegistrationHandler(BaseHandler):
# Bind email to new account
yield self._register_email_threepid(user_id, threepid_dict, None, False)
- defer.returnValue(user_id)
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ yield self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ yield self.profile_handler.set_active(user, False, True)
+
+ return user_id
@defer.inlineCallbacks
def _auto_join_rooms(self, user_id):
@@ -335,7 +356,9 @@ class RegistrationHandler(BaseHandler):
yield self._auto_join_rooms(user_id)
@defer.inlineCallbacks
- def appservice_register(self, user_localpart, as_token):
+ def appservice_register(self, user_localpart, as_token, password, display_name):
+ # FIXME: this should be factored out and merged with normal register()
+
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -354,13 +377,30 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
+ password_hash = ""
+ if password:
+ password_hash = yield self.auth_handler().hash(password)
+
+ display_name = display_name or user.localpart
+
yield self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
)
- defer.returnValue(user_id)
+
+ yield self.profile_handler.set_displayname(
+ user, None, display_name, by_admin=True
+ )
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(user_localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
+ return user_id
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
@@ -384,6 +424,38 @@ class RegistrationHandler(BaseHandler):
logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
+ def register_saml2(self, localpart):
+ """
+ Registers email_id as SAML2 Based Auth.
+ """
+ if types.contains_invalid_mxid_characters(localpart):
+ raise SynapseError(
+ 400, "User ID can only contain characters a-z, 0-9, or '=_-./'"
+ )
+ yield self.auth.check_auth_blocking()
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+
+ yield self.check_user_id_not_appservice_exclusive(user_id)
+ token = self.macaroon_gen.generate_access_token(user_id)
+ try:
+ yield self.register_with_store(
+ user_id=user_id,
+ token=token,
+ password_hash=None,
+ create_profile_with_displayname=user.localpart,
+ )
+
+ yield self.profile_handler.set_displayname(
+ user, None, user.localpart, by_admin=True
+ )
+ except Exception as e:
+ yield self.store.add_access_token_to_user(user_id, token)
+ # Ignore Registration errors
+ logger.exception(e)
+ defer.returnValue((user_id, token))
+
+ @defer.inlineCallbacks
def register_email(self, threepidCreds):
"""
Registers emails with an identity server.
@@ -451,6 +523,39 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
+ def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+ other registration UI auth flows (e.g. validated 3pids)
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+ # that the displayname is correctly set by the master server
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_email": params.get("bind_email"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
+ @defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
@@ -461,7 +566,7 @@ class RegistrationHandler(BaseHandler):
id = self._next_generated_user_id
self._next_generated_user_id += 1
- defer.returnValue(str(id))
+ return str(id)
@defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response):
@@ -481,7 +586,7 @@ class RegistrationHandler(BaseHandler):
"error_url": "http://www.recaptcha.net/recaptcha/api/challenge?"
+ "error=%s" % lines[1],
}
- defer.returnValue(json)
+ return json
@defer.inlineCallbacks
def _submit_captcha(self, ip_addr, private_key, challenge, response):
@@ -497,7 +602,56 @@ class RegistrationHandler(BaseHandler):
"response": response,
},
)
- defer.returnValue(data)
+ return data
+
+ @defer.inlineCallbacks
+ def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+ """Creates a new user if the user does not exist,
+ else revokes all previous access tokens and generates a new one.
+
+ Args:
+ localpart : The local part of the user ID to register. Must not be
+ None; if it is, a SynapseError is raised.
+ Returns:
+ A tuple of (user_id, access_token).
+ Raises:
+ RegistrationError if there was a problem registering.
+
+ NB this is only used in tests. TODO: move it to the test package!
+ """
+ if localpart is None:
+ raise SynapseError(400, "Request must include user id")
+ yield self.auth.check_auth_blocking()
+ need_register = True
+
+ try:
+ yield self.check_username(localpart)
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ need_register = False
+ else:
+ raise
+
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+ token = self.macaroon_gen.generate_access_token(user_id)
+
+ if need_register:
+ yield self.register_with_store(
+ user_id=user_id,
+ token=token,
+ password_hash=password_hash,
+ create_profile_with_displayname=displayname or user.localpart,
+ )
+ if displayname is not None:
+ yield self.profile_handler.set_displayname(
+ user, None, displayname or user.localpart, by_admin=True
+ )
+ else:
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
+ yield self.store.add_access_token_to_user(user_id=user_id, token=token)
+
+ defer.returnValue((user_id, token))
@defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier):
@@ -622,7 +776,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name=initial_display_name,
is_guest=is_guest,
)
- defer.returnValue((r["device_id"], r["access_token"]))
+ return (r["device_id"], r["access_token"])
valid_until_ms = None
if self.session_lifetime is not None:
@@ -645,7 +799,7 @@ class RegistrationHandler(BaseHandler):
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
- defer.returnValue((device_id, access_token))
+ return (device_id, access_token)
@defer.inlineCallbacks
def post_registration_actions(
@@ -798,7 +952,7 @@ class RegistrationHandler(BaseHandler):
if ex.errcode == Codes.MISSING_PARAM:
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
- defer.returnValue(None)
+ return None
raise
yield self._auth_handler.add_threepid(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index db3f8cb76b..af7cfa7888 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -52,12 +52,14 @@ class RoomCreationHandler(BaseHandler):
"history_visibility": "shared",
"original_invitees_have_ops": False,
"guest_can_join": True,
+ "encryption_alg": "m.megolm.v1.aes-sha2",
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "shared",
"original_invitees_have_ops": True,
"guest_can_join": True,
+ "encryption_alg": "m.megolm.v1.aes-sha2",
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
@@ -128,7 +130,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id,
new_version, # args for _upgrade_room
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def _upgrade_room(self, requester, old_room_id, new_version):
@@ -193,7 +195,7 @@ class RoomCreationHandler(BaseHandler):
requester, old_room_id, new_room_id, old_room_state
)
- defer.returnValue(new_room_id)
+ return new_room_id
@defer.inlineCallbacks
def _update_upgraded_room_pls(
@@ -294,7 +296,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -516,8 +530,14 @@ class RoomCreationHandler(BaseHandler):
requester, config, is_requester_admin=is_requester_admin
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -551,7 +571,6 @@ class RoomCreationHandler(BaseHandler):
else:
room_alias = None
- invite_list = config.get("invite", [])
for i in invite_list:
try:
UserID.from_string(i)
@@ -560,8 +579,6 @@ class RoomCreationHandler(BaseHandler):
yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
- invite_3pid_list = config.get("invite_3pid", [])
-
visibility = config.get("visibility", None)
is_public = visibility == "public"
@@ -649,6 +666,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -663,6 +681,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
)
result = {"room_id": room_id}
@@ -671,7 +690,7 @@ class RoomCreationHandler(BaseHandler):
result["room_alias"] = room_alias.to_string()
yield directory_handler.send_room_alias_update_event(requester, room_id)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _send_events_for_new_room(
@@ -719,6 +738,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=False,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
@@ -780,6 +800,13 @@ class RoomCreationHandler(BaseHandler):
for (etype, state_key), content in initial_state.items():
yield send(etype=etype, state_key=state_key, content=content)
+ if "encryption_alg" in config:
+ yield send(
+ etype=EventTypes.Encryption,
+ state_key="",
+ content={"algorithm": config["encryption_alg"]},
+ )
+
@defer.inlineCallbacks
def _generate_room_id(self, creator_id, is_public):
# autogen room IDs and try to create it. We may clash, so just
@@ -796,7 +823,7 @@ class RoomCreationHandler(BaseHandler):
room_creator_user_id=creator_id,
is_public=is_public,
)
- defer.returnValue(gen_room_id)
+ return gen_room_id
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a room ID.")
@@ -839,7 +866,7 @@ class RoomContextHandler(object):
event_id, get_prev_content=True, allow_none=True
)
if not event:
- defer.returnValue(None)
+ return None
return
filtered = yield (filter_evts([event]))
@@ -890,7 +917,7 @@ class RoomContextHandler(object):
results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
- defer.returnValue(results)
+ return results
class RoomEventSource(object):
@@ -941,7 +968,7 @@ class RoomEventSource(object):
else:
end_key = to_key
- defer.returnValue((events, end_key))
+ return (events, end_key)
def get_current_key(self):
return self.store.get_room_events_max_id()
@@ -959,4 +986,4 @@ class RoomEventSource(object):
limit=config.limit,
)
- defer.returnValue((events, next_key))
+ return (events, next_key)
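With the preset changes at the top of this file, _send_events_for_new_room now finishes by sending an m.room.encryption state event whenever the chosen preset carries an "encryption_alg", so private and trusted-private rooms are encrypted from creation. The event it sends looks roughly like this (a sketch as a plain dict; EventTypes.Encryption is "m.room.encryption"):

    {
        "type": "m.room.encryption",
        "state_key": "",
        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
    }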
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index aae696a7e8..e9094ad02b 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -325,7 +325,7 @@ class RoomListHandler(BaseHandler):
current_limit=since_token.current_limit - 1,
).to_token()
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def _append_room_entry_to_chunk(
@@ -420,7 +420,7 @@ class RoomListHandler(BaseHandler):
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
- defer.returnValue(None)
+ return None
# Return whether this room is open to federation users or not
create_event = current_state.get((EventTypes.Create, ""))
@@ -469,7 +469,7 @@ class RoomListHandler(BaseHandler):
if avatar_url:
result["avatar_url"] = avatar_url
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def get_remote_public_room_list(
@@ -482,7 +482,7 @@ class RoomListHandler(BaseHandler):
third_party_instance_id=None,
):
if not self.enable_room_list_search:
- defer.returnValue({"chunk": [], "total_room_count_estimate": 0})
+ return {"chunk": [], "total_room_count_estimate": 0}
if search_filter:
# We currently don't support searching across federation, so we have
@@ -507,7 +507,7 @@ class RoomListHandler(BaseHandler):
]
}
- defer.returnValue(res)
+ return res
def _get_remote_list_cached(
self,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e0196ef83e..e2ac2637c5 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -20,22 +20,23 @@ import logging
from six.moves import http_client
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json
-from unpaddedbase64 import decode_base64
-
from twisted.internet import defer
import synapse.server
import synapse.types
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ ProxiedRequestError,
+ HttpResponseException,
+ SynapseError,
+)
from synapse.types import RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
-from ._base import BaseHandler
-
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -67,6 +68,7 @@ class RoomMemberHandler(object):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.identity_handler = hs.get_handlers().identity_handler
self.member_linearizer = Linearizer(name="member")
@@ -74,13 +76,10 @@ class RoomMemberHandler(object):
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
self._server_notices_mxid = self.config.server_notices_mxid
+ self.rewrite_identity_server_urls = self.config.rewrite_identity_server_urls
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
-
- # This is only used to get at ratelimit function, and
- # maybe_kick_guest_users. It's fine there are multiple of these as
- # it doesn't store state.
- self.base_handler = BaseHandler(hs)
+ self.ratelimiter = Ratelimiter()
@abc.abstractmethod
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
@@ -191,7 +190,7 @@ class RoomMemberHandler(object):
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
- defer.returnValue(duplicate)
+ return duplicate
yield self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
@@ -233,7 +232,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target, room_id)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id):
@@ -285,8 +284,31 @@ class RoomMemberHandler(object):
third_party_signed=None,
ratelimit=True,
content=None,
+ new_room=False,
require_consent=True,
):
+ """Update a users membership in a room
+
+ Args:
+ requester (Requester)
+ target (UserID)
+ room_id (str)
+ action (str): The "action" the requester is performing against the
+ target. One of join/leave/kick/ban/invite/unban.
+ txn_id (str|None): The transaction ID associated with the request,
+ or None if not provided.
+ remote_room_hosts (list[str]|None): List of remote servers to try
+ and join via if the server isn't already in the room.
+ third_party_signed (dict|None): The signed object for third party
+ invites.
+ ratelimit (bool): Whether to apply ratelimiting to this request.
+ content (dict|None): Fields to include in the new events content.
+ new_room (bool): Whether these membership changes are happening
+ as part of a room creation (e.g. initial joins and invites).
+
+ Returns:
+ Deferred[FrozenEvent]
+ """
key = (room_id,)
with (yield self.member_linearizer.queue(key)):
@@ -300,10 +322,11 @@ class RoomMemberHandler(object):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _update_membership(
@@ -317,6 +340,7 @@ class RoomMemberHandler(object):
third_party_signed=None,
ratelimit=True,
content=None,
+ new_room=False,
require_consent=True,
):
content_specified = bool(content)
@@ -381,8 +405,15 @@ class RoomMemberHandler(object):
)
block_invite = True
+ is_published = yield self.store.is_room_published(room_id)
+
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(),
+ target.to_string(),
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -423,7 +454,7 @@ class RoomMemberHandler(object):
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
- defer.returnValue(old_state)
+ return old_state
if old_membership in ["ban", "leave"] and action == "kick":
raise AuthError(403, "The target user is not in the room")
@@ -455,8 +486,26 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
+
+ inviter = yield self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if not is_host_in_room:
- inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -473,7 +522,7 @@ class RoomMemberHandler(object):
ret = yield self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
- defer.returnValue(ret)
+ return ret
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
@@ -495,7 +544,7 @@ class RoomMemberHandler(object):
res = yield self._remote_reject_invite(
requester, remote_room_hosts, room_id, target
)
- defer.returnValue(res)
+ return res
res = yield self._local_membership_update(
requester=requester,
@@ -508,7 +557,7 @@ class RoomMemberHandler(object):
content=content,
require_consent=require_consent,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def send_membership_event(
@@ -596,11 +645,11 @@ class RoomMemberHandler(object):
"""
guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
if not guest_access_id:
- defer.returnValue(False)
+ return False
guest_access = yield self.store.get_event(guest_access_id)
- defer.returnValue(
+ return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -635,7 +684,7 @@ class RoomMemberHandler(object):
servers.remove(room_alias.domain)
servers.insert(0, room_alias.domain)
- defer.returnValue((RoomID.from_string(room_id), servers))
+ return (RoomID.from_string(room_id), servers)
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
@@ -643,11 +692,19 @@ class RoomMemberHandler(object):
user_id=user_id, room_id=room_id
)
if invite:
- defer.returnValue(UserID.from_string(invite.sender))
+ return UserID.from_string(invite.sender)
@defer.inlineCallbacks
def do_3pid_invite(
- self, room_id, inviter, medium, address, id_server, requester, txn_id
+ self,
+ room_id,
+ inviter,
+ medium,
+ address,
+ id_server,
+ requester,
+ txn_id,
+ new_room=False,
):
if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
@@ -658,7 +715,23 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
- yield self.base_handler.ratelimit(requester)
+ self.ratelimiter.ratelimit(
+ requester.user.to_string(),
+ time_now_s=self.hs.clock.time(),
+ rate_hz=self.hs.config.rc_third_party_invite.per_second,
+ burst_count=self.hs.config.rc_third_party_invite.burst_count,
+ update=True,
+ )
+
+ can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
@@ -672,6 +745,19 @@ class RoomMemberHandler(object):
invitee = yield self._lookup_3pid(id_server, medium, address)
+ is_published = yield self.store.is_room_published(room_id)
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
yield self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
@@ -681,6 +767,20 @@ class RoomMemberHandler(object):
requester, id_server, medium, address, room_id, inviter, txn_id=txn_id
)
+ def _get_id_server_target(self, id_server):
+ """Looks up an id_server's actual http endpoint
+
+ Args:
+ id_server (str): the server name to lookup.
+
+ Returns:
+ the http endpoint to connect to.
+ """
+ if id_server in self.rewrite_identity_server_urls:
+ return self.rewrite_identity_server_urls[id_server]
+
+ return id_server
+
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
@@ -694,47 +794,12 @@ class RoomMemberHandler(object):
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
- if not self._enable_lookup:
- raise SynapseError(
- 403, "Looking up third-party identifiers is denied from this server"
- )
try:
- data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
- {"medium": medium, "address": address},
- )
-
- if "mxid" in data:
- if "signatures" not in data:
- raise AuthError(401, "No signatures on 3pid binding")
- yield self._verify_any_signature(data, id_server)
- defer.returnValue(data["mxid"])
-
- except IOError as e:
+ data = yield self.identity_handler.lookup_3pid(id_server, medium, address)
+ return data.get("mxid")
+ except ProxiedRequestError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
- defer.returnValue(None)
-
- @defer.inlineCallbacks
- def _verify_any_signature(self, data, server_hostname):
- if server_hostname not in data["signatures"]:
- raise AuthError(401, "No signature from server %s" % (server_hostname,))
- for key_name, signature in data["signatures"][server_hostname].items():
- key_data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s"
- % (id_server_scheme, server_hostname, key_name)
- )
- if "public_key" not in key_data:
- raise AuthError(
- 401, "No public key named %s from %s" % (key_name, server_hostname)
- )
- verify_signed_json(
- data,
- server_hostname,
- decode_verify_key_bytes(
- key_name, decode_base64(key_data["public_key"])
- ),
- )
- return
+ return None
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
@@ -854,9 +919,10 @@ class RoomMemberHandler(object):
user.
"""
+ target = self._get_id_server_target(id_server)
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme,
- id_server,
+ target,
)
invite_config = {
@@ -896,7 +962,7 @@ class RoomMemberHandler(object):
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid"
- % (id_server_scheme, id_server),
+ % (id_server_scheme, target),
}
else:
fallback_public_key = public_keys[0]
@@ -904,7 +970,7 @@ class RoomMemberHandler(object):
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
- defer.returnValue((token, public_keys, fallback_public_key, display_name))
+ return (token, public_keys, fallback_public_key, display_name)
@defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
@@ -913,7 +979,7 @@ class RoomMemberHandler(object):
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
# We can only get here if we're in the process of creating the room
- defer.returnValue(True)
+ return True
for etype, state_key in current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
@@ -925,16 +991,16 @@ class RoomMemberHandler(object):
continue
if event.membership == Membership.JOIN:
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def _is_server_notice_room(self, room_id):
if self._server_notices_mxid is None:
- defer.returnValue(False)
+ return False
user_ids = yield self.store.get_users_in_room(room_id)
- defer.returnValue(self._server_notices_mxid in user_ids)
+ return self._server_notices_mxid in user_ids
class RoomMemberMasterHandler(RoomMemberHandler):
@@ -978,7 +1044,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, target.to_string()
)
- defer.returnValue(ret)
+ return ret
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
@@ -989,7 +1055,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(target.to_string(), room_id)
- defer.returnValue({})
+ return {}
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index fc873a3ba6..75e96ae1a2 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -53,7 +53,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
yield self._user_joined_room(user, room_id)
- defer.returnValue(ret)
+ return ret
def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
"""Implements RoomMemberHandler._remote_reject_invite
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ddc4430d03..cd5e90bacb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -69,7 +69,7 @@ class SearchHandler(BaseHandler):
# Scan through the old room for further predecessors
room_id = predecessor["room_id"]
- defer.returnValue(historical_room_ids)
+ return historical_room_ids
@defer.inlineCallbacks
def search(self, user, content, batch=None):
@@ -186,13 +186,11 @@ class SearchHandler(BaseHandler):
room_ids.intersection_update({batch_group_key})
if not room_ids:
- defer.returnValue(
- {
- "search_categories": {
- "room_events": {"results": [], "count": 0, "highlights": []}
- }
+ return {
+ "search_categories": {
+ "room_events": {"results": [], "count": 0, "highlights": []}
}
- )
+ }
rank_map = {} # event_id -> rank of event
allowed_events = []
@@ -455,4 +453,4 @@ class SearchHandler(BaseHandler):
if global_next_batch:
rooms_cat_res["next_batch"] = global_next_batch
- defer.returnValue({"search_categories": {"room_events": rooms_cat_res}})
+ return {"search_categories": {"room_events": rooms_cat_res}}
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index d90c9e0108..3f50d6de47 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,12 +31,15 @@ class SetPasswordHandler(BaseHandler):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
+ self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
+ self._password_policy_handler.validate_password(newpassword)
+
password_hash = yield self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
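set_password now validates the new password before hashing it, via the password policy handler introduced elsewhere in this change. A sketch of the kind of check validate_password might perform (hypothetical rules; the real handler would read its policy from config):

    import re

    from synapse.api.errors import SynapseError

    def validate_password(password, minimum_length=8):
        # Illustrative policy: a length floor plus at least one digit.
        if len(password) < minimum_length:
            raise SynapseError(400, "Password is too short")
        if not re.search(r"[0-9]", password):
            raise SynapseError(400, "Password must contain a digit")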
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index 6b364befd5..f065970c40 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -48,7 +48,7 @@ class StateDeltasHandler(object):
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
- defer.returnValue(None)
+ return None
prev_value = None
value = None
@@ -62,8 +62,8 @@ class StateDeltasHandler(object):
logger.debug("prev_value: %r -> value: %r", prev_value, value)
if value == public_value and prev_value != public_value:
- defer.returnValue(True)
+ return True
elif value != public_value and prev_value == public_value:
- defer.returnValue(False)
+ return False
else:
- defer.returnValue(None)
+ return None
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index a0ee8db988..4449da6669 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -86,7 +86,7 @@ class StatsHandler(StateDeltasHandler):
# If still None then the initial background update hasn't happened yet
if self.pos is None:
- defer.returnValue(None)
+ return None
# Loop round handling deltas until we're up to date
while True:
@@ -328,6 +328,6 @@ class StatsHandler(StateDeltasHandler):
== "world_readable"
)
):
- defer.returnValue(True)
+ return True
else:
- defer.returnValue(False)
+ return False
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index cd1ac0a27a..4007284e5b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -263,7 +263,7 @@ class SyncHandler(object):
timeout,
full_state,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
@@ -303,7 +303,7 @@ class SyncHandler(object):
lazy_loaded = "false"
non_empty_sync_counter.labels(sync_type, lazy_loaded).inc()
- defer.returnValue(result)
+ return result
def current_sync_for_user(self, sync_config, since_token=None, full_state=False):
"""Get the sync for client needed to match what the server has now.
@@ -317,7 +317,7 @@ class SyncHandler(object):
user_id = user.to_string()
rules = yield self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
- defer.returnValue(rules)
+ return rules
@defer.inlineCallbacks
def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
@@ -378,7 +378,7 @@ class SyncHandler(object):
event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
- defer.returnValue((now_token, ephemeral_by_room))
+ return (now_token, ephemeral_by_room)
@defer.inlineCallbacks
def _load_filtered_recents(
@@ -426,8 +426,8 @@ class SyncHandler(object):
recents = []
if not limited or block_all_timeline:
- defer.returnValue(
- TimelineBatch(events=recents, prev_batch=now_token, limited=False)
+ return TimelineBatch(
+ events=recents, prev_batch=now_token, limited=False
)
filtering_factor = 2
@@ -490,12 +490,10 @@ class SyncHandler(object):
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
- defer.returnValue(
- TimelineBatch(
- events=recents,
- prev_batch=prev_batch_token,
- limited=limited or newly_joined_room,
- )
+ return TimelineBatch(
+ events=recents,
+ prev_batch=prev_batch_token,
+ limited=limited or newly_joined_room,
)
@defer.inlineCallbacks
@@ -517,7 +515,7 @@ class SyncHandler(object):
if event.is_state():
state_ids = state_ids.copy()
state_ids[(event.type, event.state_key)] = event.event_id
- defer.returnValue(state_ids)
+ return state_ids
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
@@ -549,7 +547,7 @@ class SyncHandler(object):
else:
# no events in this room - so presumably no state
state = {}
- defer.returnValue(state)
+ return state
@defer.inlineCallbacks
def compute_summary(self, room_id, sync_config, batch, state, now_token):
@@ -579,7 +577,7 @@ class SyncHandler(object):
)
if not last_events:
- defer.returnValue(None)
+ return None
return
last_event = last_events[-1]
@@ -611,14 +609,14 @@ class SyncHandler(object):
if name_id:
name = yield self.store.get_event(name_id, allow_none=True)
if name and name.content.get("name"):
- defer.returnValue(summary)
+ return summary
if canonical_alias_id:
canonical_alias = yield self.store.get_event(
canonical_alias_id, allow_none=True
)
if canonical_alias and canonical_alias.content.get("alias"):
- defer.returnValue(summary)
+ return summary
me = sync_config.user.to_string()
@@ -652,7 +650,7 @@ class SyncHandler(object):
summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5]
if not sync_config.filter_collection.lazy_load_members():
- defer.returnValue(summary)
+ return summary
# ensure we send membership events for heroes if needed
cache_key = (sync_config.user.to_string(), sync_config.device_id)
@@ -686,7 +684,7 @@ class SyncHandler(object):
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
- defer.returnValue(summary)
+ return summary
def get_lazy_loaded_members_cache(self, cache_key):
cache = self.lazy_loaded_members_cache.get(cache_key)
@@ -871,14 +869,12 @@ class SyncHandler(object):
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
- defer.returnValue(
- {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- list(state.values())
- )
- }
- )
+ return {
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(
+ list(state.values())
+ )
+ }
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
@@ -894,11 +890,11 @@ class SyncHandler(object):
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
- defer.returnValue(notifs)
+ return notifs
# There is no new information in this period, so your notification
# count is whatever it was last time.
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def generate_sync_result(self, sync_config, since_token=None, full_state=False):
@@ -989,19 +985,17 @@ class SyncHandler(object):
"Sync result for newly joined room %s: %r", room_id, joined_room
)
- defer.returnValue(
- SyncResult(
- presence=sync_result_builder.presence,
- account_data=sync_result_builder.account_data,
- joined=sync_result_builder.joined,
- invited=sync_result_builder.invited,
- archived=sync_result_builder.archived,
- to_device=sync_result_builder.to_device,
- device_lists=device_lists,
- groups=sync_result_builder.groups,
- device_one_time_keys_count=one_time_key_counts,
- next_batch=sync_result_builder.now_token,
- )
+ return SyncResult(
+ presence=sync_result_builder.presence,
+ account_data=sync_result_builder.account_data,
+ joined=sync_result_builder.joined,
+ invited=sync_result_builder.invited,
+ archived=sync_result_builder.archived,
+ to_device=sync_result_builder.to_device,
+ device_lists=device_lists,
+ groups=sync_result_builder.groups,
+ device_one_time_keys_count=one_time_key_counts,
+ next_batch=sync_result_builder.now_token,
)
@measure_func("_generate_sync_entry_for_groups")
@@ -1124,11 +1118,9 @@ class SyncHandler(object):
# Remove any users that we still share a room with.
newly_left_users -= users_who_share_room
- defer.returnValue(
- DeviceLists(changed=users_that_have_changed, left=newly_left_users)
- )
+ return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
else:
- defer.returnValue(DeviceLists(changed=[], left=[]))
+ return DeviceLists(changed=[], left=[])
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
@@ -1225,7 +1217,7 @@ class SyncHandler(object):
sync_result_builder.account_data = account_data_for_user
- defer.returnValue(account_data_by_room)
+ return account_data_by_room
@defer.inlineCallbacks
def _generate_sync_entry_for_presence(
@@ -1325,7 +1317,7 @@ class SyncHandler(object):
)
if not tags_by_room:
logger.debug("no-oping sync")
- defer.returnValue(([], [], [], []))
+ return ([], [], [], [])
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id
@@ -1388,13 +1380,11 @@ class SyncHandler(object):
newly_left_users -= newly_joined_or_invited_users
- defer.returnValue(
- (
- newly_joined_rooms,
- newly_joined_or_invited_users,
- newly_left_rooms,
- newly_left_users,
- )
+ return (
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
+ newly_left_rooms,
+ newly_left_users,
)
@defer.inlineCallbacks
@@ -1414,13 +1404,13 @@ class SyncHandler(object):
)
if rooms_changed:
- defer.returnValue(True)
+ return True
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
@defer.inlineCallbacks
def _get_rooms_changed(self, sync_result_builder, ignored_users):
@@ -1637,7 +1627,7 @@ class SyncHandler(object):
)
room_entries.append(entry)
- defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms))
+ return (room_entries, invited, newly_joined_rooms, newly_left_rooms)
@defer.inlineCallbacks
def _get_all_rooms(self, sync_result_builder, ignored_users):
@@ -1711,7 +1701,7 @@ class SyncHandler(object):
)
)
- defer.returnValue((room_entries, invited, []))
+ return (room_entries, invited, [])
@defer.inlineCallbacks
def _generate_room_entry(
@@ -1912,7 +1902,7 @@ class SyncHandler(object):
joined_room_ids.add(room_id)
joined_room_ids = frozenset(joined_room_ids)
- defer.returnValue(joined_room_ids)
+ return joined_room_ids
def _action_has_highlight(actions):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c3e0c8fc7e..6b661aa93d 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -140,7 +140,7 @@ class TypingHandler(object):
if was_present:
# No point sending another notification
- defer.returnValue(None)
+ return None
self._push_update(member=member, typing=True)
@@ -173,7 +173,7 @@ class TypingHandler(object):
def _stopped_typing(self, member):
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
- defer.returnValue(None)
+ return None
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 5de9630950..e53669e40d 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -133,7 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# If still None then the initial background update hasn't happened yet
if self.pos is None:
- defer.returnValue(None)
+ return None
# Loop round handling deltas until we're up to date
while True:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 45d5010952..ccdcb7a770 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,6 +45,7 @@ from synapse.http import (
cancelled_to_request_timed_out_error,
redact_uri,
)
+from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
@@ -182,7 +183,15 @@ class SimpleHttpClient(object):
using HTTP in Matrix
"""
- def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ hs,
+ treq_args={},
+ ip_whitelist=None,
+ ip_blacklist=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
"""
Args:
hs (synapse.server.HomeServer)
@@ -191,6 +200,8 @@ class SimpleHttpClient(object):
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
+ http_proxy (bytes): proxy server to use for http connections. host[:port]
+ https_proxy (bytes): proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -235,11 +246,13 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
- self.agent = Agent(
+ self.agent = ProxyAgent(
self.reactor,
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
+ http_proxy=http_proxy,
+ https_proxy=https_proxy,
)
if self._ip_blacklist:
@@ -294,7 +307,7 @@ class SimpleHttpClient(object):
logger.info(
"Received response to %s %s: %s", method, redact_uri(uri), response.code
)
- defer.returnValue(response)
+ return response
except Exception as e:
incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
@@ -345,7 +358,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -385,7 +398,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -410,7 +423,7 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
body = yield self.get_raw(uri, args, headers=headers)
- defer.returnValue(json.loads(body))
+ return json.loads(body)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None):
@@ -453,7 +466,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -488,7 +501,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(body)
+ return body
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -545,13 +558,11 @@ class SimpleHttpClient(object):
except Exception as e:
raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e)
- defer.returnValue(
- (
- length,
- resp_headers,
- response.request.absoluteURI.decode("ascii"),
- response.code,
- )
+ return (
+ length,
+ resp_headers,
+ response.request.absoluteURI.decode("ascii"),
+ response.code,
)
@@ -627,10 +638,10 @@ class CaptchaServerHttpClient(SimpleHttpClient):
try:
body = yield make_deferred_yieldable(readBody(response))
- defer.returnValue(body)
+ return body
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
- defer.returnValue(e.response)
+ return e.response
def encode_urlencode_args(args):
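The new http_proxy/https_proxy arguments let a caller route all of a SimpleHttpClient's traffic through the ProxyAgent added below. A usage sketch (values illustrative; in practice they would come from configuration or the HTTP(S)_PROXY environment):

    client = SimpleHttpClient(
        hs,  # an existing synapse.server.HomeServer
        http_proxy=b"proxy.example.com:8080",
        https_proxy=b"secure-proxy.example.com:8443",
    )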
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
new file mode 100644
index 0000000000..be7b2ceb8e
--- /dev/null
+++ b/synapse/http/connectproxyclient.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer, protocol
+from twisted.internet.error import ConnectError
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.internet.protocol import connectionDone
+from twisted.web import http
+
+logger = logging.getLogger(__name__)
+
+
+class ProxyConnectError(ConnectError):
+ pass
+
+
+@implementer(IStreamClientEndpoint)
+class HTTPConnectProxyEndpoint(object):
+ """An Endpoint implementation which will send a CONNECT request to an http proxy
+
+ Wraps an existing HostnameEndpoint for the proxy.
+
+ When we get the connect() request from the connection pool (via the TLS wrapper),
+ we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
+ CONNECT request. Once that completes, we invoke the protocolFactory which was passed
+ in.
+
+ Args:
+ reactor: the Twisted reactor to use for the connection
+ proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
+ proxy
+ host (bytes): hostname that we want to CONNECT to
+ port (int): port that we want to connect to
+ """
+
+ def __init__(self, reactor, proxy_endpoint, host, port):
+ self._reactor = reactor
+ self._proxy_endpoint = proxy_endpoint
+ self._host = host
+ self._port = port
+
+ def __repr__(self):
+ return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
+
+ def connect(self, protocolFactory):
+ f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ d = self._proxy_endpoint.connect(f)
+ # once the tcp socket connects successfully, we need to wait for the
+ # CONNECT to complete.
+ d.addCallback(lambda conn: f.on_connection)
+ return d
+
+
+class HTTPProxiedClientFactory(protocol.ClientFactory):
+ """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
+
+ Once the CONNECT completes, invokes the original ClientFactory to build the
+ HTTP Protocol object and run the rest of the connection.
+
+ Args:
+ dst_host (bytes): hostname that we want to CONNECT to
+ dst_port (int): port that we want to connect to
+ wrapped_factory (protocol.ClientFactory): The original Factory
+ """
+
+ def __init__(self, dst_host, dst_port, wrapped_factory):
+ self.dst_host = dst_host
+ self.dst_port = dst_port
+ self.wrapped_factory = wrapped_factory
+ self.on_connection = defer.Deferred()
+
+ def startedConnecting(self, connector):
+ return self.wrapped_factory.startedConnecting(connector)
+
+ def buildProtocol(self, addr):
+ wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+
+ return HTTPConnectProtocol(
+ self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ )
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.debug("Connection to proxy failed: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector, reason):
+ logger.debug("Connection to proxy lost: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionLost(connector, reason)
+
+
+class HTTPConnectProtocol(protocol.Protocol):
+ """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
+
+ Args:
+ host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ to put in the CONNECT request
+
+ port (int): The original HTTP(s) port to put in the CONNECT request
+
+ wrapped_protocol (interfaces.IProtocol): the original protocol (probably
+ HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+
+ connected_deferred (Deferred): a Deferred which will be callbacked with
+ wrapped_protocol when the CONNECT completes
+ """
+
+ def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ self.host = host
+ self.port = port
+ self.wrapped_protocol = wrapped_protocol
+ self.connected_deferred = connected_deferred
+ self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.http_setup_client.on_connected.addCallback(self.proxyConnected)
+
+ def connectionMade(self):
+ self.http_setup_client.makeConnection(self.transport)
+
+ def connectionLost(self, reason=connectionDone):
+ if self.wrapped_protocol.connected:
+ self.wrapped_protocol.connectionLost(reason)
+
+ self.http_setup_client.connectionLost(reason)
+
+ if not self.connected_deferred.called:
+ self.connected_deferred.errback(reason)
+
+ def proxyConnected(self, _):
+ self.wrapped_protocol.makeConnection(self.transport)
+
+ self.connected_deferred.callback(self.wrapped_protocol)
+
+ # Get any pending data from the http buf and forward it to the original protocol
+ buf = self.http_setup_client.clearLineBuffer()
+ if buf:
+ self.wrapped_protocol.dataReceived(buf)
+
+ def dataReceived(self, data):
+ # if we've set up the HTTP protocol, we can send the data there
+ if self.wrapped_protocol.connected:
+ return self.wrapped_protocol.dataReceived(data)
+
+ # otherwise, we must still be setting up the connection: send the data to the
+ # setup client
+ return self.http_setup_client.dataReceived(data)
+
+
+class HTTPConnectSetupClient(http.HTTPClient):
+ """HTTPClient protocol to send a CONNECT message for proxies and read the response.
+
+ Args:
+ host (bytes): The hostname to send in the CONNECT message
+ port (int): The port to send in the CONNECT message
+ """
+
+ def __init__(self, host, port):
+ self.host = host
+ self.port = port
+ self.on_connected = defer.Deferred()
+
+ def connectionMade(self):
+ logger.debug("Connected to proxy, sending CONNECT")
+ self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+ self.endHeaders()
+
+ def handleStatus(self, version, status, message):
+ logger.debug("Got Status: %s %s %s", status, message, version)
+ if status != b"200":
+ raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+
+ def handleEndHeaders(self):
+ logger.debug("End Headers")
+ self.on_connected.callback(None)
+
+ def handleResponse(self, body):
+ pass
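These classes implement the standard HTTP CONNECT tunnel: connect to the proxy, send "CONNECT host:port HTTP/1.0", wait for a 200 response, then hand the now-transparent socket to the real protocol (typically TLS). Wiring the endpoint up directly would look roughly like this (hostnames illustrative):

    from twisted.internet import reactor
    from twisted.internet.endpoints import HostnameEndpoint

    # Endpoint for the proxy itself...
    proxy_endpoint = HostnameEndpoint(reactor, b"proxy.example.com", 8080)

    # ...wrapped so that connect() first performs the CONNECT handshake to
    # target.example.com:443 and only then invokes the wrapped factory.
    endpoint = HTTPConnectProxyEndpoint(
        reactor, proxy_endpoint, b"target.example.com", 443
    )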
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 054c321a20..c03ddb724f 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -177,7 +177,7 @@ class MatrixFederationAgent(object):
res = yield make_deferred_yieldable(
agent.request(method, uri, headers, bodyProducer)
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
@@ -205,24 +205,20 @@ class MatrixFederationAgent(object):
port = parsed_uri.port
if port == -1:
port = 8448
- defer.returnValue(
- _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=port,
- )
+ return _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
)
if parsed_uri.port != -1:
# there is an explicit port
- defer.returnValue(
- _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=parsed_uri.port,
- )
+ return _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
)
if lookup_well_known:
@@ -259,7 +255,7 @@ class MatrixFederationAgent(object):
)
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
- defer.returnValue(res)
+ return res
# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
@@ -283,13 +279,11 @@ class MatrixFederationAgent(object):
parsed_uri.host.decode("ascii"),
)
- defer.returnValue(
- _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=target_host,
- target_port=port,
- )
+ return _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
)
@defer.inlineCallbacks
@@ -314,7 +308,7 @@ class MatrixFederationAgent(object):
if cache_period > 0:
self._well_known_cache.set(server_name, result, cache_period)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _do_get_well_known(self, server_name):
@@ -354,7 +348,7 @@ class MatrixFederationAgent(object):
# after startup
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
- defer.returnValue((None, cache_period))
+ return (None, cache_period)
result = parsed_body["m.server"].encode("ascii")
@@ -369,7 +363,7 @@ class MatrixFederationAgent(object):
else:
cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
- defer.returnValue((result, cache_period))
+ return (result, cache_period)
@implementer(IStreamClientEndpoint)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index ecc88f9b96..b32188766d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -120,7 +120,7 @@ class SrvResolver(object):
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
- defer.returnValue(servers)
+ return servers
try:
answers, _, _ = yield make_deferred_yieldable(
@@ -129,7 +129,7 @@ class SrvResolver(object):
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
- defer.returnValue([])
+ return []
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else re-raise
@@ -138,7 +138,7 @@ class SrvResolver(object):
logger.warn(
"Failed to resolve %r, falling back to cache. %r", service_name, e
)
- defer.returnValue(list(cache_entry))
+ return list(cache_entry)
else:
raise e
@@ -169,4 +169,4 @@ class SrvResolver(object):
)
self._cache[service_name] = list(servers)
- defer.returnValue(servers)
+ return servers
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index e60334547e..d07d356464 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -158,7 +158,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
response.code,
response.phrase.decode("ascii", errors="replace"),
)
- defer.returnValue(body)
+ return body
class MatrixFederationHttpClient(object):
@@ -256,7 +256,7 @@ class MatrixFederationHttpClient(object):
response = yield self._send_request(request, **send_request_args)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
def _send_request(
@@ -520,7 +520,7 @@ class MatrixFederationHttpClient(object):
_flatten_response_never_received(e),
)
raise
- defer.returnValue(response)
+ return response
def build_auth_headers(
self, destination, method, url_bytes, content=None, destination_is=None
@@ -644,7 +644,7 @@ class MatrixFederationHttpClient(object):
self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
def post_json(
@@ -713,7 +713,7 @@ class MatrixFederationHttpClient(object):
body = yield _handle_json_response(
self.reactor, _sec_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
def get_json(
@@ -778,7 +778,7 @@ class MatrixFederationHttpClient(object):
self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
def delete_json(
@@ -836,7 +836,7 @@ class MatrixFederationHttpClient(object):
body = yield _handle_json_response(
self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
def get_file(
@@ -902,7 +902,7 @@ class MatrixFederationHttpClient(object):
response.phrase.decode("ascii", errors="replace"),
length,
)
- defer.returnValue((length, headers))
+ return (length, headers)
class _ReadBodyToFileProtocol(protocol.Protocol):
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
new file mode 100644
index 0000000000..332da02a8d
--- /dev/null
+++ b/synapse/http/proxyagent.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.python.failure import Failure
+from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.error import SchemeNotSupported
+from twisted.web.iweb import IAgent
+
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+
+logger = logging.getLogger(__name__)
+
+_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+
+
+@implementer(IAgent)
+class ProxyAgent(_AgentBase):
+ """An Agent implementation which will use an HTTP proxy if one was requested
+
+ Args:
+ reactor: twisted reactor to use to place outgoing
+ connections.
+
+ contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+ verification parameters of OpenSSL. The default is to use a
+ `BrowserLikePolicyForHTTPS`, so unless you have special
+ requirements you can leave this as-is.
+
+ connectTimeout (float): The amount of time that this Agent will wait
+ for the peer to accept a connection.
+
+ bindAddress (bytes): The local address for client sockets to bind to.
+
+ pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+ non-persistent pool instance will be created.
+
+ http_proxy (bytes|None): proxy server to use for http connections.
+ host[:port]
+
+ https_proxy (bytes|None): proxy server to use for https connections.
+ host[:port]
+ """
+
+ def __init__(
+ self,
+ reactor,
+ contextFactory=BrowserLikePolicyForHTTPS(),
+ connectTimeout=None,
+ bindAddress=None,
+ pool=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
+ _AgentBase.__init__(self, reactor, pool)
+
+ self._endpoint_kwargs = {}
+ if connectTimeout is not None:
+ self._endpoint_kwargs["timeout"] = connectTimeout
+ if bindAddress is not None:
+ self._endpoint_kwargs["bindAddress"] = bindAddress
+
+ self.http_proxy_endpoint = _http_proxy_endpoint(
+ http_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self.https_proxy_endpoint = _http_proxy_endpoint(
+ https_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self._policy_for_https = contextFactory
+ self._reactor = reactor
+
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Issue a request to the server indicated by the given uri.
+
+ Supports `http` and `https` schemes.
+
+ An existing connection from the connection pool may be used or a new one may be
+ created.
+
+ See also: twisted.web.iweb.IAgent.request
+
+ Args:
+ method (bytes): The request method to use, such as `GET`, `POST`, etc
+
+ uri (bytes): The location of the resource to request.
+
+ headers (Headers|None): Extra headers to send with the request
+
+ bodyProducer (IBodyProducer|None): An object which can generate bytes to
+ make up the body of this request (for example, the properly encoded
+ contents of a file for a file upload). Or, None if the request is to
+ have no body.
+
+ Returns:
+ Deferred[IResponse]: completes when the header of the response has
+ been received (regardless of the response status code).
+ """
+ uri = uri.strip()
+ if not _VALID_URI.match(uri):
+ raise ValueError("Invalid URI {!r}".format(uri))
+
+ parsed_uri = URI.fromBytes(uri)
+ pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+ request_path = parsed_uri.originForm
+
+ if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ # Cache *all* connections under the same key, since we are only
+ # connecting to a single destination, the proxy:
+ pool_key = ("http-proxy", self.http_proxy_endpoint)
+ endpoint = self.http_proxy_endpoint
+ request_path = uri
+ elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ endpoint = HTTPConnectProxyEndpoint(
+ self._reactor,
+ self.https_proxy_endpoint,
+ parsed_uri.host,
+ parsed_uri.port,
+ )
+ else:
+ # not using a proxy
+ endpoint = HostnameEndpoint(
+ self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
+ )
+
+ logger.debug("Requesting %s via %s", uri, endpoint)
+
+ if parsed_uri.scheme == b"https":
+ tls_connection_creator = self._policy_for_https.creatorForNetloc(
+ parsed_uri.host, parsed_uri.port
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+ elif parsed_uri.scheme == b"http":
+ pass
+ else:
+ return defer.fail(
+ Failure(
+ SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
+ )
+ )
+
+ return self._requestWithEndpoint(
+ pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
+ )
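As a usage sketch (illustrative only, not part of this change; the proxy address and URL are invented), the agent is used like any other twisted IAgent:

    from twisted.internet import reactor
    from twisted.web.http_headers import Headers

    # http requests are sent to the proxy in absolute-URI form; https
    # requests would be tunnelled via CONNECT if https_proxy were set.
    agent = ProxyAgent(reactor, http_proxy=b"localhost:8888")
    d = agent.request(b"GET", b"http://example.com/", headers=Headers())
    d.addCallback(lambda response: print(response.code))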
+
+
+def _http_proxy_endpoint(proxy, reactor, **kwargs):
+ """Parses an http proxy setting and returns an endpoint for the proxy
+
+ Args:
+ proxy (bytes|None): the proxy setting
+ reactor: reactor to be used to connect to the proxy
+ kwargs: other args to be passed to HostnameEndpoint
+
+ Returns:
+ interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
+ or None
+ """
+ if proxy is None:
+ return None
+
+ # currently we only support hostname:port. Some apps also support
+ # protocol://<host>[:port], which allows a way of requiring a TLS connection to the
+ # proxy.
+
+ host, port = parse_host_port(proxy, default_port=1080)
+ return HostnameEndpoint(reactor, host, port, **kwargs)
+
+
+def parse_host_port(hostport, default_port=None):
+ # could have sworn we had one of these somewhere else...
+ if b":" in hostport:
+ host, port = hostport.rsplit(b":", 1)
+ try:
+ port = int(port)
+ return host, port
+ except ValueError:
+ # the thing after the : wasn't a valid port; presumably this is an
+ # IPv6 address.
+ pass
+
+ return hostport, default_port
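For reference, some illustrative inputs (invented values) and what parse_host_port returns under the logic above:

    parse_host_port(b"proxy.example.com:3128")                # (b"proxy.example.com", 3128)
    parse_host_port(b"proxy.example.com", default_port=1080)  # (b"proxy.example.com", 1080)

    # If the text after the last ":" is not a valid port -- e.g. an IPv6
    # address ending in a hex group -- the whole string becomes the host:
    parse_host_port(b"fe80::abcd", default_port=1080)         # (b"fe80::abcd", 1080)

    # Caveat: an IPv6 address whose final group is purely numeric is still
    # mis-split (b"2001:db8::1" -> (b"2001:db8:", 1)), since bracketed
    # forms are not handled yet.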
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 3da33d7826..d2c209c471 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -11,7 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.import opentracing
+# limitations under the License.
# NOTE
@@ -89,7 +89,7 @@ the function becomes the operation name for the span.
# We start
yield we_wait
# we finish
- defer.returnValue(something_usual_and_useful)
+ return something_usual_and_useful
Operation names can be explicitly set for functions by using
``trace_using_operation_name`` and
@@ -113,7 +113,7 @@ Operation names can be explicitly set for functions by using
# We start
yield we_wait
# we finish
- defer.returnValue(something_usual_and_useful)
+ return something_usual_and_useful
Contexts and carriers
---------------------
@@ -150,10 +150,13 @@ Gotchas
"""
import contextlib
+import inspect
import logging
import re
from functools import wraps
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.config import ConfigError
@@ -173,36 +176,12 @@ except ImportError:
logger = logging.getLogger(__name__)
-class _DumTagNames(object):
- """wrapper of opentracings tags. We need to have them if we
- want to reference them without opentracing around. Clearly they
- should never actually show up in a trace. `set_tags` overwrites
- these with the correct ones."""
+# Block everything by default
+# A regex which matches the server_names to expose traces for.
+# None means 'block everything'.
+_homeserver_whitelist = None
- INVALID_TAG = "invalid-tag"
- COMPONENT = INVALID_TAG
- DATABASE_INSTANCE = INVALID_TAG
- DATABASE_STATEMENT = INVALID_TAG
- DATABASE_TYPE = INVALID_TAG
- DATABASE_USER = INVALID_TAG
- ERROR = INVALID_TAG
- HTTP_METHOD = INVALID_TAG
- HTTP_STATUS_CODE = INVALID_TAG
- HTTP_URL = INVALID_TAG
- MESSAGE_BUS_DESTINATION = INVALID_TAG
- PEER_ADDRESS = INVALID_TAG
- PEER_HOSTNAME = INVALID_TAG
- PEER_HOST_IPV4 = INVALID_TAG
- PEER_HOST_IPV6 = INVALID_TAG
- PEER_PORT = INVALID_TAG
- PEER_SERVICE = INVALID_TAG
- SAMPLING_PRIORITY = INVALID_TAG
- SERVICE = INVALID_TAG
- SPAN_KIND = INVALID_TAG
- SPAN_KIND_CONSUMER = INVALID_TAG
- SPAN_KIND_PRODUCER = INVALID_TAG
- SPAN_KIND_RPC_CLIENT = INVALID_TAG
- SPAN_KIND_RPC_SERVER = INVALID_TAG
+# Util methods
def only_if_tracing(func):
@@ -219,11 +198,13 @@ def only_if_tracing(func):
return _only_if_tracing_inner
-# A regex which matches the server_names to expose traces for.
-# None means 'block everything'.
-_homeserver_whitelist = None
+@contextlib.contextmanager
+def _noop_context_manager(*args, **kwargs):
+ """Does exactly what it says on the tin"""
+ yield
+
-tags = _DumTagNames
+# Setup
def init_tracer(config):
@@ -247,26 +228,55 @@ def init_tracer(config):
# Include the worker name
name = config.worker_name if config.worker_name else "master"
+ # Pull out the jaeger config if it was given. Otherwise set it to something sensible.
+ # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
+
set_homeserver_whitelist(config.opentracer_whitelist)
- jaeger_config = JaegerConfig(
- config={"sampler": {"type": "const", "param": 1}, "logging": True},
+
+ JaegerConfig(
+ config=config.jaeger_config,
service_name="{} {}".format(config.server_name, name),
scope_manager=LogContextScopeManager(config),
- )
- jaeger_config.initialize_tracer()
+ ).initialize_tracer()
# Set up tags to be opentracing's tags
global tags
tags = opentracing.tags
-@contextlib.contextmanager
-def _noop_context_manager(*args, **kwargs):
- """Does absolutely nothing really well. Can be entered and exited arbitrarily.
- Good substitute for an opentracing scope."""
- yield
+# Whitelisting
+
+
+@only_if_tracing
+def set_homeserver_whitelist(homeserver_whitelist):
+ """Sets the homeserver whitelist
+
+ Args:
+ homeserver_whitelist (Iterable[str]): regexes matching whitelisted homeservers
+ """
+ global _homeserver_whitelist
+ if homeserver_whitelist:
+ # Makes a single regex which accepts all passed in regexes in the list
+ _homeserver_whitelist = re.compile(
+ "({})".format(")|(".join(homeserver_whitelist))
+ )
+
+
+@only_if_tracing
+def whitelisted_homeserver(destination):
+ """Checks if a destination matches the whitelist
+
+ Args:
+ destination (str)
+ """
+ if _homeserver_whitelist:
+ return _homeserver_whitelist.match(destination)
+ return False
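A worked example with invented server names (assuming opentracing is installed, since both functions are no-ops otherwise):

    set_homeserver_whitelist([r"good\.example\.com", r".*\.internal"])
    # compiles to the single pattern (good\.example\.com)|(.*\.internal)

    whitelisted_homeserver("good.example.com")  # truthy: first alternative
    whitelisted_homeserver("synapse.internal")  # truthy: second alternative
    whitelisted_homeserver("evil.example.org")  # False

    # Note: re.match only anchors at the start of the string, so
    # "good.example.com.evil.org" would also match unless the configured
    # regexes end with "$".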
+# Start spans and scopes
+
# Could use kwargs but I want these to be explicit
def start_active_span(
operation_name,
@@ -285,8 +295,10 @@ def start_active_span(
Returns:
scope (Scope) or noop_context_manager
"""
+
if opentracing is None:
return _noop_context_manager()
+
else:
# We need to enter the scope here for the logcontext to become active
return opentracing.tracer.start_active_span(
@@ -300,63 +312,13 @@ def start_active_span(
)
-@only_if_tracing
-def close_active_span():
- """Closes the active span. This will close it's logcontext if the context
- was made for the span"""
- opentracing.tracer.scope_manager.active.__exit__(None, None, None)
-
-
-@only_if_tracing
-def set_tag(key, value):
- """Set's a tag on the active span"""
- opentracing.tracer.active_span.set_tag(key, value)
-
-
-@only_if_tracing
-def log_kv(key_values, timestamp=None):
- """Log to the active span"""
- opentracing.tracer.active_span.log_kv(key_values, timestamp)
-
-
-# Note: we don't have a get baggage items because we're trying to hide all
-# scope and span state from synapse. I think this method may also be useless
-# as a result
-@only_if_tracing
-def set_baggage_item(key, value):
- """Attach baggage to the active span"""
- opentracing.tracer.active_span.set_baggage_item(key, value)
-
-
-@only_if_tracing
-def set_operation_name(operation_name):
- """Sets the operation name of the active span"""
- opentracing.tracer.active_span.set_operation_name(operation_name)
-
-
-@only_if_tracing
-def set_homeserver_whitelist(homeserver_whitelist):
- """Sets the whitelist
-
- Args:
- homeserver_whitelist (iterable of strings): regex of whitelisted homeservers
- """
- global _homeserver_whitelist
- if homeserver_whitelist:
- # Makes a single regex which accepts all passed in regexes in the list
- _homeserver_whitelist = re.compile(
- "({})".format(")|(".join(homeserver_whitelist))
- )
-
-
-@only_if_tracing
-def whitelisted_homeserver(destination):
- """Checks if a destination matches the whitelist
- Args:
- destination (String)"""
- if _homeserver_whitelist:
- return _homeserver_whitelist.match(destination)
- return False
+def start_active_span_follows_from(operation_name, contexts):
+ if opentracing is None:
+ return _noop_context_manager()
+ else:
+ references = [opentracing.follows_from(context) for context in contexts]
+ scope = start_active_span(operation_name, references=references)
+ return scope
def start_active_span_from_context(
@@ -372,12 +334,16 @@ def start_active_span_from_context(
Extracts a span context from Twisted Headers.
args:
headers (twisted.web.http_headers.Headers)
+
+ For the other args see opentracing.tracer
+
returns:
scope (Scope) or noop_context_manager
"""
# Twisted encodes the values as lists whereas opentracing doesn't.
# So, we take the first item in the list.
# Also, twisted uses byte arrays while opentracing expects strings.
+
if opentracing is None:
return _noop_context_manager()
@@ -395,17 +361,90 @@ def start_active_span_from_context(
)
+def start_active_span_from_edu(
+ edu_content,
+ operation_name,
+ references=[],
+ tags=None,
+ start_time=None,
+ ignore_active_span=False,
+ finish_on_close=True,
+):
+ """
+ Extracts a span context from an edu and uses it to start a new active span
+
+ Args:
+ edu_content (dict): an edu_content with a `context` field whose value is
+ canonical json for a dict which contains opentracing information.
+
+ For the other args see opentracing.tracer
+ """
+
+ if opentracing is None:
+ return _noop_context_manager()
+
+ carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+ context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+ _references = [
+ opentracing.child_of(span_context_from_string(x))
+ for x in carrier.get("references", [])
+ ]
+
+ # For some reason jaeger decided not to support the visualization of multiple parent
+ # spans or to explicitly show references. We include the span context as a tag here
+ # as an aid to people debugging, but it's really not an ideal solution.
+
+ # use `+` rather than `+=` so that we don't mutate the default argument
+ references = references + _references
+
+ scope = opentracing.tracer.start_active_span(
+ operation_name,
+ child_of=context,
+ references=references,
+ tags=tags,
+ start_time=start_time,
+ ignore_active_span=ignore_active_span,
+ finish_on_close=finish_on_close,
+ )
+
+ scope.span.set_tag("references", carrier.get("references", []))
+ return scope
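To make the expected shape of the `context` field concrete, here is a sketch of both sides (payload fields invented; `json` is canonicaljson's, as imported at the top of this module):

    # Sending side: build a carrier with the helper defined further down.
    carrier = {}
    inject_active_span_text_map(carrier)  # no-op unless tracing is enabled

    edu_content = {
        "typing": True,  # invented payload
        "context": json.dumps({"opentracing": carrier}),
    }

    # Receiving side: continue the trace from the edu.
    with start_active_span_from_edu(edu_content, "handle_edu"):
        pass  # handle the edu under the extracted span context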
+
+
+# Opentracing setters for tags, logs, etc
+
+
+@only_if_tracing
+def set_tag(key, value):
+ """Sets a tag on the active span"""
+ opentracing.tracer.active_span.set_tag(key, value)
+
+
+@only_if_tracing
+def log_kv(key_values, timestamp=None):
+ """Log to the active span"""
+ opentracing.tracer.active_span.log_kv(key_values, timestamp)
+
+
+@only_if_tracing
+def set_operation_name(operation_name):
+ """Sets the operation name of the active span"""
+ opentracing.tracer.active_span.set_operation_name(operation_name)
+
+
+# Injection and extraction
+
+
@only_if_tracing
def inject_active_span_twisted_headers(headers, destination):
"""
- Injects a span context into twisted headers inplace
+ Injects a span context into twisted headers in-place
Args:
headers (twisted.web.http_headers.Headers)
span (opentracing.Span)
Returns:
- Inplace modification of headers
+ In-place modification of headers
Note:
The headers set by the tracer are custom to the tracer implementation which
@@ -437,7 +476,7 @@ def inject_active_span_byte_dict(headers, destination):
span (opentracing.Span)
Returns:
- Inplace modification of headers
+ In-place modification of headers
Note:
The headers set by the tracer are custom to the tracer implementation which
@@ -458,15 +497,195 @@ def inject_active_span_byte_dict(headers, destination):
headers[key.encode()] = [value.encode()]
+@only_if_tracing
+def inject_active_span_text_map(carrier, destination=None):
+ """
+ Injects a span context into a dict
+
+ Args:
+ carrier (dict)
+ destination (str): the name of the remote server. The span context
+ will only be injected if the destination matches the homeserver_whitelist
+ or destination is None.
+
+ Returns:
+ In-place modification of carrier
+
+ Note:
+ The headers set by the tracer are custom to the tracer implementation which
+ should be unique enough that they don't interfere with any headers set by
+ synapse or twisted. If we're still using jaeger these headers would be those
+ here:
+ https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
+ """
+
+ if destination and not whitelisted_homeserver(destination):
+ return
+
+ opentracing.tracer.inject(
+ opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+ )
+
+
+def active_span_context_as_string():
+ """
+ Returns:
+ The active span context encoded as a string.
+ """
+ carrier = {}
+ if opentracing:
+ opentracing.tracer.inject(
+ opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+ )
+ return json.dumps(carrier)
+
+
+@only_if_tracing
+def span_context_from_string(carrier):
+ """
+ Returns:
+ The active span context decoded from a string.
+ """
+ carrier = json.loads(carrier)
+ return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+
+
+@only_if_tracing
+def extract_text_map(carrier):
+ """
+ Wrapper method for opentracing's tracer.extract for TEXT_MAP.
+ Args:
+ carrier (dict): a dict possibly containing a span context.
+
+ Returns:
+ The active span context extracted from carrier.
+ """
+ return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
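These helpers are intended to round-trip; a sketch, assuming a tracer has been initialised and that start_active_span forwards opentracing's usual child_of keyword:

    # Serialise the current span context to an opaque string...
    context_str = active_span_context_as_string()

    # ...ship it elsewhere, then resume tracing under the same parent:
    span_context = span_context_from_string(context_str)
    with start_active_span("continued-work", child_of=span_context):
        pass  # runs as a child of the original span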
+
+
+# Tracing decorators
+
+
+def trace(func):
+ """
+ Decorator to trace a function.
+ Sets the operation name to the function's name.
+ """
+ if opentracing is None:
+ return func
+
+ @wraps(func)
+ def _trace_inner(self, *args, **kwargs):
+ if opentracing is None:
+ return func(self, *args, **kwargs)
+
+ scope = start_active_span(func.__name__)
+ scope.__enter__()
+
+ try:
+ result = func(self, *args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result):
+ scope.__exit__(None, None, None)
+ return result
+
+ def err_back(result):
+ scope.span.set_tag(tags.ERROR, True)
+ scope.__exit__(None, None, None)
+ return result
+
+ result.addCallbacks(call_back, err_back)
+
+ else:
+ scope.__exit__(None, None, None)
+
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _trace_inner
+
+
+def trace_using_operation_name(operation_name):
+ """Decorator to trace a function. Explicitely sets the operation_name."""
+
+ def trace(func):
+ """
+ Decorator to trace a function.
+ Sets the operation name to the explicitly-provided operation_name.
+ """
+ if opentracing is None:
+ return func
+
+ @wraps(func)
+ def _trace_inner(self, *args, **kwargs):
+ if opentracing is None:
+ return func(self, *args, **kwargs)
+
+ scope = start_active_span(operation_name)
+ scope.__enter__()
+
+ try:
+ result = func(self, *args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result):
+ scope.__exit__(None, None, None)
+ return result
+
+ def err_back(result):
+ scope.span.set_tag(tags.ERROR, True)
+ scope.__exit__(None, None, None)
+ return result
+
+ result.addCallbacks(call_back, err_back)
+ else:
+ scope.__exit__(None, None, None)
+
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _trace_inner
+
+ return trace
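Usage of the two decorators, for illustration (the handler class and the store methods are invented):

    class ExampleHandler(object):
        @trace
        @defer.inlineCallbacks
        def fetch_thing(self):
            # the span is named "fetch_thing" and is closed when the
            # Deferred returned by inlineCallbacks fires
            thing = yield self.store.lookup_thing()
            return thing

        @trace_using_operation_name("fetch-other-thing")
        @defer.inlineCallbacks
        def fetch_other_thing(self):
            # the span is named "fetch-other-thing"
            other = yield self.store.lookup_other_thing()
            return other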
+
+
+def tag_args(func):
+ """
+ Tags all of the args to the active span.
+ """
+
+ if not opentracing:
+ return func
+
+ @wraps(func)
+ def _tag_args_inner(self, *args, **kwargs):
+ argspec = inspect.getargspec(func)
+ # `args` doesn't include `self`, so only walk the names that were
+ # actually passed positionally, and offset the leftover slice by one.
+ for i, arg in enumerate(argspec.args[1 : len(args) + 1]):
+ set_tag("ARG_" + arg, args[i])
+ set_tag("args", args[len(argspec.args) - 1 :])
+ set_tag("kwargs", kwargs)
+ return func(self, *args, **kwargs)
+
+ return _tag_args_inner
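With the off-by-one corrected above, tag_args tags each named positional argument individually and records any leftovers under "args" (store and call invented for illustration):

    class ExampleStore(object):
        @tag_args
        def get_event(self, event_id, check_redacted, *rest):
            pass

    # ExampleStore().get_event("$ev:example.com", True, "extra") sets, on
    # the active span:
    #   ARG_event_id       -> "$ev:example.com"
    #   ARG_check_redacted -> True
    #   args               -> ("extra",)
    #   kwargs             -> {}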
+
+
def trace_servlet(servlet_name, func):
"""Decorator which traces a serlet. It starts a span with some servlet specific
tags such as the servlet_name and request information"""
+ if not opentracing:
+ return func
@wraps(func)
@defer.inlineCallbacks
def _trace_servlet_inner(request, *args, **kwargs):
- with start_active_span_from_context(
- request.requestHeaders,
+ with start_active_span(
"incoming-client-request",
tags={
"request_id": request.get_request_id(),
@@ -478,6 +697,44 @@ def trace_servlet(servlet_name, func):
},
):
result = yield defer.maybeDeferred(func, request, *args, **kwargs)
- defer.returnValue(result)
+ return result
return _trace_servlet_inner
+
+
+# Helper class
+
+
+class _DummyTagNames(object):
+ """wrapper of opentracings tags. We need to have them if we
+ want to reference them without opentracing around. Clearly they
+ should never actually show up in a trace. `set_tags` overwrites
+ these with the correct ones."""
+
+ INVALID_TAG = "invalid-tag"
+ COMPONENT = INVALID_TAG
+ DATABASE_INSTANCE = INVALID_TAG
+ DATABASE_STATEMENT = INVALID_TAG
+ DATABASE_TYPE = INVALID_TAG
+ DATABASE_USER = INVALID_TAG
+ ERROR = INVALID_TAG
+ HTTP_METHOD = INVALID_TAG
+ HTTP_STATUS_CODE = INVALID_TAG
+ HTTP_URL = INVALID_TAG
+ MESSAGE_BUS_DESTINATION = INVALID_TAG
+ PEER_ADDRESS = INVALID_TAG
+ PEER_HOSTNAME = INVALID_TAG
+ PEER_HOST_IPV4 = INVALID_TAG
+ PEER_HOST_IPV6 = INVALID_TAG
+ PEER_PORT = INVALID_TAG
+ PEER_SERVICE = INVALID_TAG
+ SAMPLING_PRIORITY = INVALID_TAG
+ SERVICE = INVALID_TAG
+ SPAN_KIND = INVALID_TAG
+ SPAN_KIND_CONSUMER = INVALID_TAG
+ SPAN_KIND_PRODUCER = INVALID_TAG
+ SPAN_KIND_RPC_CLIENT = INVALID_TAG
+ SPAN_KIND_RPC_SERVER = INVALID_TAG
+
+
+tags = _DummyTagNames
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 8c661302c9..4eed4f2338 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -131,7 +131,7 @@ class _LogContextScope(Scope):
def close(self):
if self.manager.active is not self:
- logger.error("Tried to close a none active scope!")
+ logger.error("Tried to close a non-active scope!")
return
if self._finish_on_close:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 7bb020cb45..41147d4292 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -101,7 +101,7 @@ class ModuleApi(object):
)
user_id = yield self.register_user(localpart, displayname, emails)
_, access_token = yield self.register_device(user_id)
- defer.returnValue((user_id, access_token))
+ return (user_id, access_token)
def register_user(self, localpart, displayname=None, emails=[]):
"""Registers a new user with given localpart and optional displayname, emails.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 918ef64897..bd80c801b6 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -365,7 +365,7 @@ class Notifier(object):
current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def get_events_for(
@@ -400,7 +400,7 @@ class Notifier(object):
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
- defer.returnValue(EventStreamResult([], (from_token, from_token)))
+ return EventStreamResult([], (from_token, from_token))
events = []
end_token = from_token
@@ -440,7 +440,7 @@ class Notifier(object):
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
- defer.returnValue(EventStreamResult(events, (from_token, end_token)))
+ return EventStreamResult(events, (from_token, end_token))
user_id_for_stream = user.to_string()
if is_peeking:
@@ -465,18 +465,18 @@ class Notifier(object):
from_token=from_token,
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _get_room_ids(self, user, explicit_room_id):
joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
- defer.returnValue(([explicit_room_id], True))
+ return ([explicit_room_id], True)
if (yield self._is_world_readable(explicit_room_id)):
- defer.returnValue(([explicit_room_id], False))
+ return ([explicit_room_id], False)
raise AuthError(403, "Non-joined access not allowed")
- defer.returnValue((joined_room_ids, True))
+ return (joined_room_ids, True)
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
@@ -484,9 +484,9 @@ class Notifier(object):
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
- defer.returnValue(state.content["history_visibility"] == "world_readable")
+ return state.content["history_visibility"] == "world_readable"
else:
- defer.returnValue(False)
+ return False
@log_function
def remove_expired_streams(self):
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c8a5b381da..c831975635 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -95,7 +95,7 @@ class BulkPushRuleEvaluator(object):
invited
)
- defer.returnValue(rules_by_user)
+ return rules_by_user
@cached()
def _get_rules_for_room(self, room_id):
@@ -134,7 +134,7 @@ class BulkPushRuleEvaluator(object):
pl_event = auth_events.get(POWER_KEY)
- defer.returnValue((pl_event.content if pl_event else {}, sender_level))
+ return (pl_event.content if pl_event else {}, sender_level)
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
@@ -283,13 +283,13 @@ class RulesForRoom(object):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
- defer.returnValue(self.rules_by_user)
+ return self.rules_by_user
with (yield self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
- defer.returnValue(self.rules_by_user)
+ return self.rules_by_user
self.room_push_rule_cache_metrics.inc_misses()
@@ -366,7 +366,7 @@ class RulesForRoom(object):
logger.debug(
"Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys()
)
- defer.returnValue(ret_rules_by_user)
+ return ret_rules_by_user
@defer.inlineCallbacks
def _update_rules_with_member_event_ids(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 4e7b6a5531..454297e6a9 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -101,7 +101,7 @@ class HttpPusher(object):
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
@@ -258,17 +258,17 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _process_one(self, push_action):
if "notify" not in push_action["actions"]:
- defer.returnValue(True)
+ return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
event = yield self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
- defer.returnValue(True) # It's been redacted
+ return True # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge)
if rejected is False:
- defer.returnValue(False)
+ return False
if isinstance(rejected, list) or isinstance(rejected, tuple):
for pk in rejected:
@@ -282,7 +282,7 @@ class HttpPusher(object):
else:
logger.info("Pushkey %s was rejected: removing", pk)
yield self.hs.remove_pusher(self.app_id, pk, self.user_id)
- defer.returnValue(True)
+ return True
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
@@ -302,7 +302,7 @@ class HttpPusher(object):
],
}
}
- defer.returnValue(d)
+ return d
ctx = yield push_tools.get_context_for_event(
self.store, self.state_handler, event, self.user_id
@@ -345,13 +345,13 @@ class HttpPusher(object):
if "name" in ctx and len(ctx["name"]) > 0:
d["notification"]["room_name"] = ctx["name"]
- defer.returnValue(d)
+ return d
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks, badge):
notification_dict = yield self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
- defer.returnValue([])
+ return []
try:
resp = yield self.http_client.post_json_get_json(
self.url, notification_dict
@@ -364,11 +364,11 @@ class HttpPusher(object):
type(e),
e,
)
- defer.returnValue(False)
+ return False
rejected = []
if "rejected" in resp:
rejected = resp["rejected"]
- defer.returnValue(rejected)
+ return rejected
@defer.inlineCallbacks
def _send_badge(self, badge):
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 521c6e2cd7..4245ce26f3 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -316,7 +316,7 @@ class Mailer(object):
if not merge:
room_vars["notifs"].append(notifvars)
- defer.returnValue(room_vars)
+ return room_vars
@defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
@@ -343,7 +343,7 @@ class Mailer(object):
if messagevars is not None:
ret["messages"].append(messagevars)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_message_vars(self, notif, event, room_state_ids):
@@ -379,7 +379,7 @@ class Mailer(object):
if "body" in event.content:
ret["body_text_plain"] = event.content["body"]
- defer.returnValue(ret)
+ return ret
def add_text_message_vars(self, messagevars, event):
msgformat = event.content.get("format")
@@ -428,19 +428,16 @@ class Mailer(object):
inviter_name = name_from_member_event(inviter_member_event)
if room_name is None:
- defer.returnValue(
- INVITE_FROM_PERSON
- % {"person": inviter_name, "app": self.app_name}
- )
+ return INVITE_FROM_PERSON % {
+ "person": inviter_name,
+ "app": self.app_name,
+ }
else:
- defer.returnValue(
- INVITE_FROM_PERSON_TO_ROOM
- % {
- "person": inviter_name,
- "room": room_name,
- "app": self.app_name,
- }
- )
+ return INVITE_FROM_PERSON_TO_ROOM % {
+ "person": inviter_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
sender_name = None
if len(notifs_by_room[room_id]) == 1:
@@ -454,26 +451,21 @@ class Mailer(object):
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
- defer.returnValue(
- MESSAGE_FROM_PERSON_IN_ROOM
- % {
- "person": sender_name,
- "room": room_name,
- "app": self.app_name,
- }
- )
+ return MESSAGE_FROM_PERSON_IN_ROOM % {
+ "person": sender_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
elif sender_name is not None:
- defer.returnValue(
- MESSAGE_FROM_PERSON
- % {"person": sender_name, "app": self.app_name}
- )
+ return MESSAGE_FROM_PERSON % {
+ "person": sender_name,
+ "app": self.app_name,
+ }
else:
# There's more than one notification for this room, so just
# say there are several
if room_name is not None:
- defer.returnValue(
- MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name}
- )
+ return MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name}
else:
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
@@ -493,24 +485,19 @@ class Mailer(object):
]
)
- defer.returnValue(
- MESSAGES_FROM_PERSON
- % {
- "person": descriptor_from_member_events(
- member_events.values()
- ),
- "app": self.app_name,
- }
- )
+ return MESSAGES_FROM_PERSON % {
+ "person": descriptor_from_member_events(member_events.values()),
+ "app": self.app_name,
+ }
else:
# Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail
if reason["room_name"] is not None:
- defer.returnValue(
- MESSAGES_IN_ROOM_AND_OTHERS
- % {"room": reason["room_name"], "app": self.app_name}
- )
+ return MESSAGES_IN_ROOM_AND_OTHERS % {
+ "room": reason["room_name"],
+ "app": self.app_name,
+ }
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
@@ -527,13 +514,10 @@ class Mailer(object):
[room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
)
- defer.returnValue(
- MESSAGES_FROM_PERSON_AND_OTHERS
- % {
- "person": descriptor_from_member_events(member_events.values()),
- "app": self.app_name,
- }
- )
+ return MESSAGES_FROM_PERSON_AND_OTHERS % {
+ "person": descriptor_from_member_events(member_events.values()),
+ "app": self.app_name,
+ }
def make_room_link(self, room_id):
if self.hs.config.email_riot_base_url:
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 06056fbf4f..16a7e8e31d 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -55,7 +55,7 @@ def calculate_room_name(
room_state_ids[("m.room.name", "")], allow_none=True
)
if m_room_name and m_room_name.content and m_room_name.content["name"]:
- defer.returnValue(m_room_name.content["name"])
+ return m_room_name.content["name"]
# does it have a canonical alias?
if ("m.room.canonical_alias", "") in room_state_ids:
@@ -68,7 +68,7 @@ def calculate_room_name(
and canon_alias.content["alias"]
and _looks_like_an_alias(canon_alias.content["alias"])
):
- defer.returnValue(canon_alias.content["alias"])
+ return canon_alias.content["alias"]
# at this point we're going to need to search the state by all state keys
# for an event type, so rearrange the data structure
@@ -82,10 +82,10 @@ def calculate_room_name(
if alias_event and alias_event.content.get("aliases"):
the_aliases = alias_event.content["aliases"]
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
- defer.returnValue(the_aliases[0])
+ return the_aliases[0]
if not fallback_to_members:
- defer.returnValue(None)
+ return None
my_member_event = None
if ("m.room.member", user_id) in room_state_ids:
@@ -104,14 +104,13 @@ def calculate_room_name(
)
if inviter_member_event:
if fallback_to_single_member:
- defer.returnValue(
- "Invite from %s"
- % (name_from_member_event(inviter_member_event),)
+ return "Invite from %s" % (
+ name_from_member_event(inviter_member_event),
)
else:
return
else:
- defer.returnValue("Room Invite")
+ return "Room Invite"
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
@@ -154,17 +153,17 @@ def calculate_room_name(
# return "Inviting %s" % (
# descriptor_from_member_events(third_party_invites)
# )
- defer.returnValue("Inviting email address")
+ return "Inviting email address"
else:
- defer.returnValue(ALL_ALONE)
+ return ALL_ALONE
else:
- defer.returnValue(name_from_member_event(all_members[0]))
+ return name_from_member_event(all_members[0])
else:
- defer.returnValue(ALL_ALONE)
+ return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member:
return
else:
- defer.returnValue(descriptor_from_member_events(other_members))
+ return descriptor_from_member_events(other_members)
def descriptor_from_member_events(member_events):
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index e37269cdb9..a54051a726 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -39,7 +39,7 @@ def get_badge_count(store, user_id):
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
- defer.returnValue(badge)
+ return badge
@defer.inlineCallbacks
@@ -61,4 +61,4 @@ def get_context_for_event(store, state_handler, ev, user_id):
sender_state_event = yield store.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event)
- defer.returnValue(ctx)
+ return ctx
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index df6f670740..08e840fdc2 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -123,7 +123,7 @@ class PusherPool:
)
pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id)
- defer.returnValue(pusher)
+ return pusher
@defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(
@@ -224,7 +224,7 @@ class PusherPool:
if pusher_dict:
pusher = yield self._start_pusher(pusher_dict)
- defer.returnValue(pusher)
+ return pusher
@defer.inlineCallbacks
def _start_pushers(self):
@@ -293,7 +293,7 @@ class PusherPool:
p.on_started(have_notifs)
- defer.returnValue(p)
+ return p
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c6465c0386..195a7a70c8 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -72,6 +72,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
+ "sdnotify>=0.3",
]
CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 43c89e36dd..2e0594e581 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -185,7 +185,7 @@ class ReplicationEndpoint(object):
except RequestSendFailed as e:
raise_from(SynapseError(502, "Failed to talk to master"), e)
- defer.returnValue(result)
+ return result
return send_request
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 61eafbe708..fed4f08820 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -80,7 +80,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
payload = {"events": event_payloads, "backfilled": backfilled}
- defer.returnValue(payload)
+ return payload
@defer.inlineCallbacks
def _handle_request(self, request):
@@ -113,7 +113,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts, backfilled
)
- defer.returnValue((200, {}))
+ return (200, {})
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
@@ -156,7 +156,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
result = yield self.registry.on_edu(edu_type, origin, edu_content)
- defer.returnValue((200, result))
+ return (200, result)
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
@@ -204,7 +204,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
result = yield self.registry.on_query(query_type, args)
- defer.returnValue((200, result))
+ return (200, result)
class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
@@ -238,7 +238,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
def _handle_request(self, request, room_id):
yield self.store.clean_room_for_join(room_id)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 7c1197e5dd..f17d3a2da4 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -64,7 +64,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
user_id, device_id, initial_display_name, is_guest
)
- defer.returnValue((200, {"device_id": device_id, "access_token": access_token}))
+ return (200, {"device_id": device_id, "access_token": access_token})
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 2d9cbbaefc..4217335d88 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -83,7 +83,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
remote_room_hosts, room_id, user_id, event_content
)
- defer.returnValue((200, {}))
+ return (200, {})
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -153,7 +153,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
yield self.store.locally_reject_invite(user_id, room_id)
ret = {}
- defer.returnValue((200, ret))
+ return (200, ret)
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 2bf2173895..3341320a87 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -90,7 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
address=content["address"],
)
- defer.returnValue((200, {}))
+ return (200, {})
class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
@@ -143,7 +143,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
bind_msisdn=bind_msisdn,
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 034763fe99..eff7bd7305 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -85,7 +85,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"extra_users": [u.to_string() for u in extra_users],
}
- defer.returnValue(payload)
+ return payload
@defer.inlineCallbacks
def _handle_request(self, request, event_id):
@@ -117,7 +117,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7ef67a5a73..c10b85d2ff 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -158,7 +158,7 @@ class Stream(object):
updates, current_token = yield self.get_updates_since(self.last_token)
self.last_token = current_token
- defer.returnValue((updates, current_token))
+ return (updates, current_token)
@defer.inlineCallbacks
def get_updates_since(self, from_token):
@@ -172,14 +172,14 @@ class Stream(object):
sent over the replication steam.
"""
if from_token in ("NOW", "now"):
- defer.returnValue(([], self.upto_token))
+ return ([], self.upto_token)
current_token = self.upto_token
from_token = int(from_token)
if from_token == current_token:
- defer.returnValue(([], current_token))
+ return ([], current_token)
if self._LIMITED:
rows = yield self.update_function(
@@ -198,7 +198,7 @@ class Stream(object):
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
- defer.returnValue((updates, current_token))
+ return (updates, current_token)
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -297,7 +297,7 @@ class PushRulesStream(Stream):
@defer.inlineCallbacks
def update_function(self, from_token, to_token, limit):
rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
- defer.returnValue([(row[0], row[2]) for row in rows])
+ return [(row[0], row[2]) for row in rows]
class PushersStream(Stream):
@@ -424,7 +424,7 @@ class AccountDataStream(Stream):
for stream_id, user_id, account_data_type, content in global_results
)
- defer.returnValue(results)
+ return results
class GroupServerStream(Stream):
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 3d0694bb11..d97669c886 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -134,7 +134,7 @@ class EventsStream(Stream):
all_updates = heapq.merge(event_updates, state_updates)
- defer.returnValue(all_updates)
+ return all_updates
@classmethod
def parse_row(cls, row):
diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
new file mode 100644
index 0000000000..894da030af
--- /dev/null
+++ b/synapse/res/templates/account_renewed.html
@@ -0,0 +1 @@
+<html><body>Your account has been successfully renewed.</body></html>
diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html
new file mode 100644
index 0000000000..6bd2b98364
--- /dev/null
+++ b/synapse/res/templates/invalid_token.html
@@ -0,0 +1 @@
+<html><body>Invalid renewal token.</body></html>
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 1d20b96d03..f161bc51a5 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
keys,
notifications,
openid,
+ password_policy,
read_marker,
receipts,
register,
@@ -117,6 +118,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
# moving to /_synapse/admin
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6888ae5590..0a7d9b81b2 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -84,7 +84,7 @@ class UsersRestServlet(RestServlet):
ret = yield self.handlers.admin_handler.get_users()
- defer.returnValue((200, ret))
+ return (200, ret)
class VersionServlet(RestServlet):
@@ -227,7 +227,7 @@ class UserRegisterServlet(RestServlet):
)
result = yield register._create_registration_details(user_id, body)
- defer.returnValue((200, result))
+ return (200, result)
class WhoisRestServlet(RestServlet):
@@ -252,7 +252,7 @@ class WhoisRestServlet(RestServlet):
ret = yield self.handlers.admin_handler.get_whois(target_user)
- defer.returnValue((200, ret))
+ return (200, ret)
class PurgeMediaCacheRestServlet(RestServlet):
@@ -271,7 +271,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
ret = yield self.media_repository.delete_old_remote_media(before_ts)
- defer.returnValue((200, ret))
+ return (200, ret)
class PurgeHistoryRestServlet(RestServlet):
@@ -356,7 +356,7 @@ class PurgeHistoryRestServlet(RestServlet):
room_id, token, delete_local_events=delete_local_events
)
- defer.returnValue((200, {"purge_id": purge_id}))
+ return (200, {"purge_id": purge_id})
class PurgeHistoryStatusRestServlet(RestServlet):
@@ -381,7 +381,7 @@ class PurgeHistoryStatusRestServlet(RestServlet):
if purge_status is None:
raise NotFoundError("purge id '%s' not found" % purge_id)
- defer.returnValue((200, purge_status.asdict()))
+ return (200, purge_status.asdict())
class DeactivateAccountRestServlet(RestServlet):
@@ -413,7 +413,7 @@ class DeactivateAccountRestServlet(RestServlet):
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
+ return (200, {"id_server_unbind_result": id_server_unbind_result})
class ShutdownRoomRestServlet(RestServlet):
@@ -531,16 +531,14 @@ class ShutdownRoomRestServlet(RestServlet):
room_id, new_room_id, requester_user_id
)
- defer.returnValue(
- (
- 200,
- {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- },
- )
+ return (
+ 200,
+ {
+ "kicked_users": kicked_users,
+ "failed_to_kick_users": failed_to_kick_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ },
)
@@ -564,7 +562,7 @@ class QuarantineMediaInRoom(RestServlet):
room_id, requester.user.to_string()
)
- defer.returnValue((200, {"num_quarantined": num_quarantined}))
+ return (200, {"num_quarantined": num_quarantined})
class ListMediaInRoom(RestServlet):
@@ -585,7 +583,7 @@ class ListMediaInRoom(RestServlet):
local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
- defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+ return (200, {"local": local_mxcs, "remote": remote_mxcs})
class ResetPasswordRestServlet(RestServlet):
@@ -629,7 +627,7 @@ class ResetPasswordRestServlet(RestServlet):
yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
- defer.returnValue((200, {}))
+ return (200, {})
class GetUsersPaginatedRestServlet(RestServlet):
@@ -671,7 +669,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
- defer.returnValue((200, ret))
+ return (200, ret)
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
@@ -699,7 +697,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
- defer.returnValue((200, ret))
+ return (200, ret)
class SearchUsersRestServlet(RestServlet):
@@ -742,7 +740,7 @@ class SearchUsersRestServlet(RestServlet):
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(term)
- defer.returnValue((200, ret))
+ return (200, ret)
class DeleteGroupAdminRestServlet(RestServlet):
@@ -765,7 +763,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
raise SynapseError(400, "Can only delete local groups")
yield self.group_server.delete_group(group_id, requester.user.to_string())
- defer.returnValue((200, {}))
+ return (200, {})
class AccountValidityRenewServlet(RestServlet):
@@ -796,7 +794,7 @@ class AccountValidityRenewServlet(RestServlet):
)
res = {"expiration_ts": expiration_ts}
- defer.returnValue((200, res))
+ return (200, res)
########################################################################################
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index d9c71261f2..656526fea5 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -92,7 +92,7 @@ class SendServerNoticeServlet(RestServlet):
event_content=body["content"],
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return (200, {"event_id": event.event_id})
def on_PUT(self, request, txn_id):
return self.txns.fetch_or_execute_request(
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 57542c2b4b..4284738021 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -54,7 +54,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias)
- defer.returnValue((200, res))
+ return (200, res)
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
@@ -87,7 +87,7 @@ class ClientDirectoryServer(RestServlet):
requester, room_alias, room_id, servers
)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_DELETE(self, request, room_alias):
@@ -102,7 +102,7 @@ class ClientDirectoryServer(RestServlet):
service.url,
room_alias.to_string(),
)
- defer.returnValue((200, {}))
+ return (200, {})
except InvalidClientCredentialsError:
# fallback to default user behaviour if they aren't an AS
pass
@@ -118,7 +118,7 @@ class ClientDirectoryServer(RestServlet):
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
)
- defer.returnValue((200, {}))
+ return (200, {})
class ClientDirectoryListServer(RestServlet):
@@ -136,9 +136,7 @@ class ClientDirectoryListServer(RestServlet):
if room is None:
raise NotFoundError("Unknown room")
- defer.returnValue(
- (200, {"visibility": "public" if room["is_public"] else "private"})
- )
+ return (200, {"visibility": "public" if room["is_public"] else "private"})
@defer.inlineCallbacks
def on_PUT(self, request, room_id):
@@ -151,7 +149,7 @@ class ClientDirectoryListServer(RestServlet):
requester, room_id, visibility
)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_DELETE(self, request, room_id):
@@ -161,7 +159,7 @@ class ClientDirectoryListServer(RestServlet):
requester, room_id, "private"
)
- defer.returnValue((200, {}))
+ return (200, {})
class ClientAppserviceDirectoryListServer(RestServlet):
@@ -195,4 +193,4 @@ class ClientAppserviceDirectoryListServer(RestServlet):
requester.app_service.id, network_id, room_id, visibility
)
- defer.returnValue((200, {}))
+ return (200, {})
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index d6de2b7360..53ebed2203 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -67,7 +67,7 @@ class EventStreamRestServlet(RestServlet):
is_guest=is_guest,
)
- defer.returnValue((200, chunk))
+ return (200, chunk)
def on_OPTIONS(self, request):
return (200, {})
@@ -91,9 +91,9 @@ class EventRestServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
event = yield self._event_serializer.serialize_event(event, time_now)
- defer.returnValue((200, event))
+ return (200, event)
else:
- defer.returnValue((404, "Event not found."))
+ return (404, "Event not found.")
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 0fe5f2d79b..70b8478e90 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -42,7 +42,7 @@ class InitialSyncRestServlet(RestServlet):
include_archived=include_archived,
)
- defer.returnValue((200, content))
+ return (200, content)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0d05945f0a..4ddf194fd7 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -152,7 +152,7 @@ class LoginRestServlet(RestServlet):
well_known_data = self._well_known_builder.get_well_known()
if well_known_data:
result["well_known"] = well_known_data
- defer.returnValue((200, result))
+ return (200, result)
@defer.inlineCallbacks
def _do_other_login(self, login_submission):
@@ -212,7 +212,7 @@ class LoginRestServlet(RestServlet):
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback_3pid
)
- defer.returnValue(result)
+ return result
# No password providers were able to handle this 3pid
# Check local store
@@ -241,7 +241,7 @@ class LoginRestServlet(RestServlet):
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _register_device_with_callback(self, user_id, login_submission, callback=None):
@@ -273,7 +273,7 @@ class LoginRestServlet(RestServlet):
if callback is not None:
yield callback(result)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def do_token_login(self, login_submission):
@@ -284,7 +284,7 @@ class LoginRestServlet(RestServlet):
)
result = yield self._register_device_with_callback(user_id, login_submission)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
@@ -321,7 +321,7 @@ class LoginRestServlet(RestServlet):
result = yield self._register_device_with_callback(
registered_user_id, login_submission
)
- defer.returnValue(result)
+ return result
class BaseSSORedirectServlet(RestServlet):
@@ -378,7 +378,7 @@ class CasTicketServlet(RestServlet):
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -395,7 +395,7 @@ class CasTicketServlet(RestServlet):
# even if that's being used old-http style to signal end-of-data
body = pde.response
result = yield self.handle_cas_response(request, body, client_redirect_url)
- defer.returnValue(result)
+ return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index cd711be519..2769f3a189 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -49,7 +49,7 @@ class LogoutRestServlet(RestServlet):
requester.user.to_string(), requester.device_id
)
- defer.returnValue((200, {}))
+ return (200, {})
class LogoutAllRestServlet(RestServlet):
@@ -75,7 +75,7 @@ class LogoutAllRestServlet(RestServlet):
# .. and then delete any access tokens which weren't associated with
# devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 3e87f0fdb3..1eb1068c98 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -56,7 +56,7 @@ class PresenceStatusRestServlet(RestServlet):
state = yield self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(state, self.clock.time_msec())
- defer.returnValue((200, state))
+ return (200, state)
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -88,7 +88,7 @@ class PresenceStatusRestServlet(RestServlet):
if self.hs.config.use_presence:
yield self.presence_handler.set_state(user, state)
- defer.returnValue((200, {}))
+ return (200, {})
def on_OPTIONS(self, request):
return (200, {})
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 4d8ab1f47e..d09eb16fb0 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,12 +14,16 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+import logging
+
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
+logger = logging.getLogger(__name__)
+
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
@@ -28,6 +32,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
@defer.inlineCallbacks
@@ -48,7 +53,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
if displayname is not None:
ret["displayname"] = displayname
- defer.returnValue((200, ret))
+ return (200, ret)
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -61,15 +66,31 @@ class ProfileDisplaynameRestServlet(RestServlet):
try:
new_name = content["displayname"]
except Exception:
- defer.returnValue((400, "Unable to parse name"))
+ return (400, "Unable to parse name")
yield self.profile_handler.set_displayname(user, requester, new_name, is_admin)
- defer.returnValue((200, {}))
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
+ return (200, {})
def on_OPTIONS(self, request, user_id):
return (200, {})
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -78,6 +99,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
@defer.inlineCallbacks
@@ -98,7 +120,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
if avatar_url is not None:
ret["avatar_url"] = avatar_url
- defer.returnValue((200, ret))
+ return (200, ret)
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -108,17 +130,35 @@ class ProfileAvatarURLRestServlet(RestServlet):
content = parse_json_object_from_request(request)
try:
- new_name = content["avatar_url"]
+ new_avatar_url = content["avatar_url"]
except Exception:
- defer.returnValue((400, "Unable to parse name"))
+ return (400, "Unable to parse avatar_url")
+
+ yield self.profile_handler.set_avatar_url(
+ user, requester, new_avatar_url, is_admin
+ )
- yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
- defer.returnValue((200, {}))
+ return (200, {})
def on_OPTIONS(self, request, user_id):
return (200, {})
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
@@ -150,7 +190,7 @@ class ProfileRestServlet(RestServlet):
if avatar_url is not None:
ret["avatar_url"] = avatar_url
- defer.returnValue((200, ret))
+ return (200, ret)
def register_servlets(hs, http_server):
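
The shadow_* helpers above replicate profile writes to a second homeserver using an application-service token; the `user_id` query parameter is the standard AS identity-assertion mechanism, so the shadow server applies the write as that user. The helpers are invoked without `yield`, so replication is fire-and-forget (hence the retry TODOs). A sketch of the resulting request, with hypothetical config values:

    # Hypothetical shadow_server config values, for illustration only.
    shadow_hs_url = "https://shadow.example.com"
    as_token = "secret_as_token"
    user_id = "@alice:shadow.example.com"

    url = "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s" % (
        shadow_hs_url, user_id, as_token, user_id,
    )
    # equivalent to: yield self.http_client.put_json(url, {"displayname": "Alice"})
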
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e635efb420..c3ae8b98a8 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -69,7 +69,7 @@ class PushRuleRestServlet(RestServlet):
if "attr" in spec:
yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
- defer.returnValue((200, {}))
+ return (200, {})
if spec["rule_id"].startswith("."):
# Rule ids starting with '.' are reserved for server default rules.
@@ -106,7 +106,7 @@ class PushRuleRestServlet(RestServlet):
except RuleNotFoundException as e:
raise SynapseError(400, str(e))
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_DELETE(self, request, path):
@@ -123,7 +123,7 @@ class PushRuleRestServlet(RestServlet):
try:
yield self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id)
- defer.returnValue((200, {}))
+ return (200, {})
except StoreError as e:
if e.code == 404:
raise NotFoundError()
@@ -151,10 +151,10 @@ class PushRuleRestServlet(RestServlet):
)
if path[0] == "":
- defer.returnValue((200, rules))
+ return (200, rules)
elif path[0] == "global":
result = _filter_ruleset_with_path(rules["global"], path[1:])
- defer.returnValue((200, result))
+ return (200, result)
else:
raise UnrecognizedRequestError()
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index e9246018df..ebc3dec516 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -62,7 +62,7 @@ class PushersRestServlet(RestServlet):
if k not in allowed_keys:
del p[k]
- defer.returnValue((200, {"pushers": pushers}))
+ return (200, {"pushers": pushers})
def on_OPTIONS(self, _):
return 200, {}
@@ -94,7 +94,7 @@ class PushersSetRestServlet(RestServlet):
yield self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string()
)
- defer.returnValue((200, {}))
+ return (200, {})
assert_params_in_dict(
content,
@@ -143,7 +143,7 @@ class PushersSetRestServlet(RestServlet):
self.notifier.on_new_replication_data()
- defer.returnValue((200, {}))
+ return (200, {})
def on_OPTIONS(self, _):
return 200, {}
@@ -190,7 +190,7 @@ class PushersRemoveRestServlet(RestServlet):
)
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
finish_request(request)
- defer.returnValue(None)
+ return None
def on_OPTIONS(self, _):
return 200, {}
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 6276e97f89..d42717effd 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -91,7 +91,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
requester, self.get_room_config(request)
)
- defer.returnValue((200, info))
+ return (200, info)
def get_room_config(self, request):
user_supplied_config = parse_json_object_from_request(request)
@@ -173,9 +173,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if format == "event":
event = format_event_for_client_v2(data.get_dict())
- defer.returnValue((200, event))
+ return (200, event)
elif format == "content":
- defer.returnValue((200, data.get_dict()["content"]))
+ return (200, data.get_dict()["content"])
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
@@ -210,7 +210,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
ret = {}
if event:
ret = {"event_id": event.event_id}
- defer.returnValue((200, ret))
+ return (200, ret)
# TODO: Needs unit testing for generic events + feedback
@@ -244,7 +244,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
requester, event_dict, txn_id=txn_id
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return (200, {"event_id": event.event_id})
def on_GET(self, request, room_id, event_type, txn_id):
return (200, "Not implemented")
@@ -307,7 +307,7 @@ class JoinRoomAliasServlet(TransactionRestServlet):
third_party_signed=content.get("third_party_signed", None),
)
- defer.returnValue((200, {"room_id": room_id}))
+ return (200, {"room_id": room_id})
def on_PUT(self, request, room_identifier, txn_id):
return self.txns.fetch_or_execute_request(
@@ -360,7 +360,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit=limit, since_token=since_token
)
- defer.returnValue((200, data))
+ return (200, data)
@defer.inlineCallbacks
def on_POST(self, request):
@@ -405,7 +405,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
network_tuple=network_tuple,
)
- defer.returnValue((200, data))
+ return (200, data)
# TODO: Needs unit testing
@@ -456,7 +456,7 @@ class RoomMemberListRestServlet(RestServlet):
continue
chunk.append(event)
- defer.returnValue((200, {"chunk": chunk}))
+ return (200, {"chunk": chunk})
# deprecated in favour of /members?membership=join?
@@ -477,7 +477,7 @@ class JoinedRoomMemberListRestServlet(RestServlet):
requester, room_id
)
- defer.returnValue((200, {"joined": users_with_profile}))
+ return (200, {"joined": users_with_profile})
# TODO: Needs better unit testing
@@ -510,7 +510,7 @@ class RoomMessageListRestServlet(RestServlet):
event_filter=event_filter,
)
- defer.returnValue((200, msgs))
+ return (200, msgs)
# TODO: Needs unit testing
@@ -531,7 +531,7 @@ class RoomStateRestServlet(RestServlet):
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
)
- defer.returnValue((200, events))
+ return (200, events)
# TODO: Needs unit testing
@@ -550,7 +550,7 @@ class RoomInitialSyncRestServlet(RestServlet):
content = yield self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
- defer.returnValue((200, content))
+ return (200, content)
class RoomEventServlet(RestServlet):
@@ -573,9 +573,9 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
event = yield self._event_serializer.serialize_event(event, time_now)
- defer.returnValue((200, event))
+ return (200, event)
else:
- defer.returnValue((404, "Event not found."))
+ return (404, "Event not found.")
class RoomEventContextServlet(RestServlet):
@@ -625,7 +625,7 @@ class RoomEventContextServlet(RestServlet):
results["state"], time_now
)
- defer.returnValue((200, results))
+ return (200, results)
class RoomForgetRestServlet(TransactionRestServlet):
@@ -644,7 +644,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
- defer.returnValue((200, {}))
+ return (200, {})
def on_PUT(self, request, room_id, txn_id):
return self.txns.fetch_or_execute_request(
@@ -693,8 +693,9 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
+ new_room=False,
)
- defer.returnValue((200, {}))
+ return (200, {})
return
target = requester.user
@@ -721,7 +722,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if membership_action == "join":
return_value["room_id"] = room_id
- defer.returnValue((200, return_value))
+ return (200, return_value)
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
@@ -763,7 +764,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
txn_id=txn_id,
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return (200, {"event_id": event.event_id})
def on_PUT(self, request, room_id, event_id, txn_id):
return self.txns.fetch_or_execute_request(
@@ -808,7 +809,7 @@ class RoomTypingRestServlet(RestServlet):
target_user=target_user, auth_user=requester.user, room_id=room_id
)
- defer.returnValue((200, {}))
+ return (200, {})
class SearchRestServlet(RestServlet):
@@ -830,7 +831,7 @@ class SearchRestServlet(RestServlet):
requester.user, content, batch
)
- defer.returnValue((200, results))
+ return (200, results)
class JoinedRoomsRestServlet(RestServlet):
@@ -846,7 +847,7 @@ class JoinedRoomsRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
- defer.returnValue((200, {"joined_rooms": list(room_ids)}))
+ return (200, {"joined_rooms": list(room_ids)})
def register_txn_path(servlet, regex_string, http_server, with_get=False):
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 41b3171ac8..497cddf8b8 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -60,18 +60,16 @@ class VoipRestServlet(RestServlet):
password = turnPassword
else:
- defer.returnValue((200, {}))
-
- defer.returnValue(
- (
- 200,
- {
- "username": username,
- "password": password,
- "ttl": userLifetime / 1000,
- "uris": turnUris,
- },
- )
+ return (200, {})
+
+ return (
+ 200,
+ {
+ "username": username,
+ "password": password,
+ "ttl": userLifetime / 1000,
+ "uris": turnUris,
+ },
)
def on_OPTIONS(self, request):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index f143d8b85c..2580e2bc63 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
from six.moves import http_client
@@ -31,8 +32,9 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -82,6 +84,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Extract params from body
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
email = body["email"]
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -89,7 +93,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -117,7 +121,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Wrap the session id in a JSON object
ret = {"sid": sid}
- defer.returnValue((200, ret))
+ return (200, ret)
@defer.inlineCallbacks
def send_password_reset(self, email, client_secret, send_attempt, next_link=None):
@@ -149,7 +153,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
# If not, just return a success without sending an email
- defer.returnValue(session_id)
+ return session_id
else:
# A non-validated session does not exist yet.
# Generate a session id
@@ -185,7 +189,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
token_expires,
)
- defer.returnValue(session_id)
+ return session_id
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
@@ -208,20 +212,22 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn)
if existingUid is None:
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class PasswordResetSubmitTokenServlet(RestServlet):
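
check_3pid_allowed() is now asynchronous and returns a Deferred, which explains the new `(yield ...)` wrapping: a bare `if not check_3pid_allowed(...)` would test the Deferred object itself, which is always truthy. A minimal sketch of the pattern, with placeholder arguments:

    @defer.inlineCallbacks
    def example(hs, medium, address):
        allowed = yield check_3pid_allowed(hs, medium, address)  # resolves to a bool
        if not allowed:
            raise SynapseError(403, "3PID denied", Codes.THREEPID_DENIED)
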
@@ -260,6 +266,9 @@ class PasswordResetSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid")
client_secret = parse_string(request, "client_secret")
+
+ assert_valid_client_secret(client_secret)
+
token = parse_string(request, "token")
# Attempt to validate a 3PID session
@@ -279,7 +288,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
request.setResponseCode(302)
request.setHeader("Location", next_link)
finish_request(request)
- defer.returnValue(None)
+ return None
# Otherwise show the success template
html = self.config.email_password_reset_success_html_content
@@ -295,7 +304,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
request.write(html.encode("utf-8"))
finish_request(request)
- defer.returnValue(None)
+ return None
def load_jinja2_template(self, template_dir, template_filename, template_vars):
"""Loads a jinja2 template with variables to insert
@@ -325,12 +334,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["sid", "client_secret", "token"])
- valid, _ = yield self.datastore.validate_threepid_validation_token(
+ assert_valid_client_secret(body["client_secret"])
+
+ valid, _ = yield self.datastore.validate_threepid_session(
body["sid"], body["client_secret"], body["token"], self.clock.time_msec()
)
response_code = 200 if valid else 400
- defer.returnValue((response_code, {"success": valid}))
+ return (response_code, {"success": valid})
class PasswordRestServlet(RestServlet):
@@ -343,6 +354,7 @@ class PasswordRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
@defer.inlineCallbacks
@@ -361,9 +373,13 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
- params = yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
- )
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ params = yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
+ )
user_id = requester.user.to_string()
else:
requester = None
@@ -399,11 +415,29 @@ class PasswordRestServlet(RestServlet):
yield self._set_password_handler.set_password(user_id, new_password, requester)
- defer.returnValue((200, {}))
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_password(params, shadow_user.to_string())
+
+ return (200, {})
def on_OPTIONS(self, _):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -434,7 +468,7 @@ class DeactivateAccountRestServlet(RestServlet):
yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
)
- defer.returnValue((200, {}))
+ return (200, {})
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
@@ -447,7 +481,7 @@ class DeactivateAccountRestServlet(RestServlet):
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
+ return (200, {"id_server_unbind_result": id_server_unbind_result})
class EmailThreepidRequestTokenRestServlet(RestServlet):
@@ -466,13 +500,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
body, ["id_server", "client_secret", "email", "send_attempt"]
)
- if not check_3pid_allowed(self.hs, "email", body["email"]):
+ if not (yield check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existingUid = yield self.datastore.get_user_id_by_threepid(
"email", body["email"]
)
@@ -481,7 +517,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -503,20 +539,22 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn)
if existingUid is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class ThreepidRestServlet(RestServlet):
@@ -528,7 +566,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -536,39 +575,70 @@ class ThreepidRestServlet(RestServlet):
threepids = yield self.datastore.user_get_threepids(requester.user.to_string())
- defer.returnValue((200, {"threepids": threepids}))
+ return (200, {"threepids": threepids})
@defer.inlineCallbacks
def on_POST(self, request):
- body = parse_json_object_from_request(request)
+ if self.hs.config.disable_3pid_changes:
+ raise SynapseError(400, "3PID changes disabled on this server")
- threePidCreds = body.get("threePidCreds")
- threePidCreds = body.get("three_pid_creds", threePidCreds)
- if threePidCreds is None:
- raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
+ body = parse_json_object_from_request(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
+ # skip validation if this is a shadow 3PID from an AS
+ if not requester.app_service:
+ threePidCreds = body.get("threePidCreds")
+ threePidCreds = body.get("three_pid_creds", threePidCreds)
+ if threePidCreds is None:
+ raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
- if not threepid:
- raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED)
+ threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
- for reqd in ["medium", "address", "validated_at"]:
- if reqd not in threepid:
- logger.warn("Couldn't add 3pid: invalid response from ID server")
- raise SynapseError(500, "Invalid response from ID Server")
+ if not threepid:
+ raise SynapseError(
+ 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
+ )
+
+ for reqd in ["medium", "address", "validated_at"]:
+ if reqd not in threepid:
+ logger.warn("Couldn't add 3pid: invalid response from ID server")
+ raise SynapseError(500, "Invalid response from ID Server")
+ else:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
yield self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
- if "bind" in body and body["bind"]:
+ if not requester.app_service and ("bind" in body and body["bind"]):
logger.debug("Binding threepid %s to %s", threepid, user_id)
yield self.identity_handler.bind_threepid(threePidCreds, user_id)
- defer.returnValue((200, {}))
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return (200, {})
+
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
class ThreepidDeleteRestServlet(RestServlet):
@@ -576,11 +646,16 @@ class ThreepidDeleteRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
+ self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_POST(self, request):
+ if self.hs.config.disable_3pid_changes:
+ raise SynapseError(400, "3PID changes disabled on this server")
+
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
@@ -598,12 +673,89 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
+ return (200, {"id_server_unbind_result": id_server_unbind_result})
+
+ @defer.inlineCallbacks
+ def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.lookup_3pid(id_server, medium, address)
+
+ return (200, ret)
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidBulkLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ return (200, ret)
class WhoamiRestServlet(RestServlet):
@@ -617,7 +769,7 @@ class WhoamiRestServlet(RestServlet):
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
- defer.returnValue((200, {"user_id": requester.user.to_string()}))
+ return (200, {"user_id": requester.user.to_string()})
def register_servlets(hs, http_server):
@@ -630,4 +782,6 @@ def register_servlets(hs, http_server):
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
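
assert_valid_client_secret() (imported above from synapse.util.stringutils) is not shown in this patch. A plausible sketch of such a guard, assuming the spec's allowed client_secret alphabet of [0-9a-zA-Z._=-] and a SynapseError on mismatch:

    import re

    from synapse.api.errors import Codes, SynapseError

    CLIENT_SECRET_RE = re.compile(r"^[0-9a-zA-Z.=_-]+$")

    def assert_valid_client_secret(client_secret):
        # Reject anything outside the allowed alphabet (sketch only; the real
        # helper lives in synapse.util.stringutils).
        if CLIENT_SECRET_RE.match(client_secret) is None:
            raise SynapseError(
                400, "Invalid client_secret parameter", Codes.INVALID_PARAM
            )
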
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index f155c26259..be1360eca3 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -40,6 +41,7 @@ class AccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
+ self._profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, account_data_type):
@@ -49,13 +51,18 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ yield self._profile_handler.set_active(user, not hide_profile, True)
+
max_id = yield self.store.add_account_data_for_user(
user_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_GET(self, request, user_id, account_data_type):
@@ -70,7 +77,7 @@ class AccountDataServlet(RestServlet):
if event is None:
raise NotFoundError("Account data not found")
- defer.returnValue((200, event))
+ return (200, event)
class RoomAccountDataServlet(RestServlet):
@@ -112,7 +119,7 @@ class RoomAccountDataServlet(RestServlet):
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id, account_data_type):
@@ -127,7 +134,7 @@ class RoomAccountDataServlet(RestServlet):
if event is None:
raise NotFoundError("Room account data not found")
- defer.returnValue((200, event))
+ return (200, event)
def register_servlets(hs, http_server):
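
The im.vector.hide_profile hook above means a client can toggle profile visibility through ordinary account data; the servlet then calls set_active(user, not hide_profile, True). A hedged client-side sketch (hypothetical homeserver and token; any HTTP client works):

    import requests

    resp = requests.put(
        "https://hs.example.com/_matrix/client/r0/user/%40alice%3Aexample.com"
        "/account_data/im.vector.hide_profile",
        params={"access_token": "MDAx..."},  # hypothetical token
        json={"hide_profile": True},         # triggers set_active(user, False, True)
    )
    resp.raise_for_status()
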
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index d29c10b83d..327df56acc 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -15,6 +15,8 @@
import logging
+from six import ensure_binary
+
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
@@ -42,6 +44,8 @@ class AccountValidityRenewServlet(RestServlet):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
+ self.success_html = hs.config.account_validity.account_renewed_html_content
+ self.failure_html = hs.config.account_validity.invalid_token_html_content
@defer.inlineCallbacks
def on_GET(self, request):
@@ -49,16 +53,23 @@ class AccountValidityRenewServlet(RestServlet):
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
- yield self.account_activity_handler.renew_account(renewal_token.decode("utf8"))
+ token_valid = yield self.account_activity_handler.renew_account(
+ renewal_token.decode("utf8")
+ )
+
+ if token_valid:
+ status_code = 200
+ response = self.success_html
+ else:
+ status_code = 404
+ response = self.failure_html
- request.setResponseCode(200)
+ request.setResponseCode(status_code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(
- b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),)
- )
- request.write(AccountValidityRenewServlet.SUCCESS_HTML)
+ request.setHeader(b"Content-Length", b"%d" % (len(response),))
+ request.write(ensure_binary(response))
finish_request(request)
- defer.returnValue(None)
+ return None
class AccountValiditySendMailServlet(RestServlet):
@@ -87,7 +98,7 @@ class AccountValiditySendMailServlet(RestServlet):
user_id = requester.user.to_string()
yield self.account_activity_handler.send_renewal_email_to_user(user_id)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
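
One subtlety in the renewal handler above: Content-Length is computed from len(response) before ensure_binary() runs, so for HTML containing non-ASCII characters the declared length (in characters) can undercount the body (in bytes). A stricter sketch, reusing the handler's status_code and response, encodes first:

    from six import ensure_binary

    body = ensure_binary(response)  # encode once, up front
    request.setResponseCode(status_code)
    request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
    request.setHeader(b"Content-Length", b"%d" % (len(body),))
    request.write(body)
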
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index bebc2951e7..f21aff39e5 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -207,7 +207,7 @@ class AuthRestServlet(RestServlet):
request.write(html_bytes)
finish_request(request)
- defer.returnValue(None)
+ return None
elif stagetype == LoginType.TERMS:
if ("session" not in request.args or len(request.args["session"])) == 0:
raise SynapseError(400, "No session supplied")
@@ -239,7 +239,7 @@ class AuthRestServlet(RestServlet):
request.write(html_bytes)
finish_request(request)
- defer.returnValue(None)
+ return None
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index fc7e2f4dd5..a4fa45fe11 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -58,7 +58,7 @@ class CapabilitiesRestServlet(RestServlet):
"m.change_password": {"enabled": change_password},
}
}
- defer.returnValue((200, response))
+ return (200, response)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index d279229d74..9adf76cc0c 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -48,7 +48,7 @@ class DevicesRestServlet(RestServlet):
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string()
)
- defer.returnValue((200, {"devices": devices}))
+ return (200, {"devices": devices})
class DeleteDevicesRestServlet(RestServlet):
@@ -91,7 +91,7 @@ class DeleteDevicesRestServlet(RestServlet):
yield self.device_handler.delete_devices(
requester.user.to_string(), body["devices"]
)
- defer.returnValue((200, {}))
+ return (200, {})
class DeviceRestServlet(RestServlet):
@@ -114,7 +114,7 @@ class DeviceRestServlet(RestServlet):
device = yield self.device_handler.get_device(
requester.user.to_string(), device_id
)
- defer.returnValue((200, device))
+ return (200, device)
@interactive_auth_handler
@defer.inlineCallbacks
@@ -137,7 +137,7 @@ class DeviceRestServlet(RestServlet):
)
yield self.device_handler.delete_device(requester.user.to_string(), device_id)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
@@ -147,7 +147,7 @@ class DeviceRestServlet(RestServlet):
yield self.device_handler.update_device(
requester.user.to_string(), device_id, body
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 3f0adf4a21..22be0ee3c5 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -56,7 +56,7 @@ class GetFilterRestServlet(RestServlet):
user_localpart=target_user.localpart, filter_id=filter_id
)
- defer.returnValue((200, filter.get_filter_json()))
+ return (200, filter.get_filter_json())
except (KeyError, StoreError):
raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
@@ -89,7 +89,7 @@ class CreateFilterRestServlet(RestServlet):
user_localpart=target_user.localpart, user_filter=content
)
- defer.returnValue((200, {"filter_id": str(filter_id)}))
+ return (200, {"filter_id": str(filter_id)})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index a312dd2593..e629c4256d 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -47,7 +47,7 @@ class GroupServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, group_description))
+ return (200, group_description)
@defer.inlineCallbacks
def on_POST(self, request, group_id):
@@ -59,7 +59,7 @@ class GroupServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, {}))
+ return (200, {})
class GroupSummaryServlet(RestServlet):
@@ -83,7 +83,7 @@ class GroupSummaryServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, get_group_summary))
+ return (200, get_group_summary)
class GroupSummaryRoomsCatServlet(RestServlet):
@@ -120,7 +120,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
content=content,
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id, room_id):
@@ -131,7 +131,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
- defer.returnValue((200, resp))
+ return (200, resp)
class GroupCategoryServlet(RestServlet):
@@ -157,7 +157,7 @@ class GroupCategoryServlet(RestServlet):
group_id, requester_user_id, category_id=category_id
)
- defer.returnValue((200, category))
+ return (200, category)
@defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id):
@@ -169,7 +169,7 @@ class GroupCategoryServlet(RestServlet):
group_id, requester_user_id, category_id=category_id, content=content
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id):
@@ -180,7 +180,7 @@ class GroupCategoryServlet(RestServlet):
group_id, requester_user_id, category_id=category_id
)
- defer.returnValue((200, resp))
+ return (200, resp)
class GroupCategoriesServlet(RestServlet):
@@ -204,7 +204,7 @@ class GroupCategoriesServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, category))
+ return (200, category)
class GroupRoleServlet(RestServlet):
@@ -228,7 +228,7 @@ class GroupRoleServlet(RestServlet):
group_id, requester_user_id, role_id=role_id
)
- defer.returnValue((200, category))
+ return (200, category)
@defer.inlineCallbacks
def on_PUT(self, request, group_id, role_id):
@@ -240,7 +240,7 @@ class GroupRoleServlet(RestServlet):
group_id, requester_user_id, role_id=role_id, content=content
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, role_id):
@@ -251,7 +251,7 @@ class GroupRoleServlet(RestServlet):
group_id, requester_user_id, role_id=role_id
)
- defer.returnValue((200, resp))
+ return (200, resp)
class GroupRolesServlet(RestServlet):
@@ -275,7 +275,7 @@ class GroupRolesServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, category))
+ return (200, category)
class GroupSummaryUsersRoleServlet(RestServlet):
@@ -312,7 +312,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
content=content,
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, role_id, user_id):
@@ -323,7 +323,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
- defer.returnValue((200, resp))
+ return (200, resp)
class GroupRoomServlet(RestServlet):
@@ -347,7 +347,7 @@ class GroupRoomServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupUsersServlet(RestServlet):
@@ -371,7 +371,7 @@ class GroupUsersServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupInvitedUsersServlet(RestServlet):
@@ -395,7 +395,7 @@ class GroupInvitedUsersServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSettingJoinPolicyServlet(RestServlet):
@@ -420,7 +420,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupCreateServlet(RestServlet):
@@ -450,7 +450,7 @@ class GroupCreateServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupAdminRoomsServlet(RestServlet):
@@ -477,7 +477,7 @@ class GroupAdminRoomsServlet(RestServlet):
group_id, requester_user_id, room_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, room_id):
@@ -488,7 +488,7 @@ class GroupAdminRoomsServlet(RestServlet):
group_id, requester_user_id, room_id
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupAdminRoomsConfigServlet(RestServlet):
@@ -516,7 +516,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
group_id, requester_user_id, room_id, config_key, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupAdminUsersInviteServlet(RestServlet):
@@ -546,7 +546,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
group_id, user_id, requester_user_id, config
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupAdminUsersKickServlet(RestServlet):
@@ -573,7 +573,7 @@ class GroupAdminUsersKickServlet(RestServlet):
group_id, user_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSelfLeaveServlet(RestServlet):
@@ -598,7 +598,7 @@ class GroupSelfLeaveServlet(RestServlet):
group_id, requester_user_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSelfJoinServlet(RestServlet):
@@ -623,7 +623,7 @@ class GroupSelfJoinServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSelfAcceptInviteServlet(RestServlet):
@@ -648,7 +648,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSelfUpdatePublicityServlet(RestServlet):
@@ -672,7 +672,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
publicise = content["publicise"]
yield self.store.update_group_publicity(group_id, requester_user_id, publicise)
- defer.returnValue((200, {}))
+ return (200, {})
class PublicisedGroupsForUserServlet(RestServlet):
@@ -694,7 +694,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
result = yield self.groups_handler.get_publicised_groups_for_user(user_id)
- defer.returnValue((200, result))
+ return (200, result)
class PublicisedGroupsForUsersServlet(RestServlet):
@@ -719,7 +719,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
result = yield self.groups_handler.bulk_get_publicised_groups(user_ids)
- defer.returnValue((200, result))
+ return (200, result)
class GroupsForUserServlet(RestServlet):
@@ -741,7 +741,7 @@ class GroupsForUserServlet(RestServlet):
result = yield self.groups_handler.get_joined_groups(requester_user_id)
- defer.returnValue((200, result))
+ return (200, result)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 45c9928b65..6008adec7c 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -95,7 +95,7 @@ class KeyUploadServlet(RestServlet):
result = yield self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
)
- defer.returnValue((200, result))
+ return (200, result)
class KeyQueryServlet(RestServlet):
@@ -149,7 +149,7 @@ class KeyQueryServlet(RestServlet):
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout)
- defer.returnValue((200, result))
+ return (200, result)
class KeyChangesServlet(RestServlet):
@@ -189,7 +189,7 @@ class KeyChangesServlet(RestServlet):
results = yield self.device_handler.get_user_ids_changed(user_id, from_token)
- defer.returnValue((200, results))
+ return (200, results)
class OneTimeKeyServlet(RestServlet):
@@ -224,7 +224,7 @@ class OneTimeKeyServlet(RestServlet):
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout)
- defer.returnValue((200, result))
+ return (200, result)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 728a52328f..d034863a3c 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -88,9 +88,7 @@ class NotificationsServlet(RestServlet):
returned_push_actions.append(returned_pa)
next_token = str(pa["stream_ordering"])
- defer.returnValue(
- (200, {"notifications": returned_push_actions, "next_token": next_token})
- )
+ return (200, {"notifications": returned_push_actions, "next_token": next_token})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index b1b5385b09..b4925c0f59 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -83,16 +83,14 @@ class IdTokenServlet(RestServlet):
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
- defer.returnValue(
- (
- 200,
- {
- "access_token": token,
- "token_type": "Bearer",
- "matrix_server_name": self.server_name,
- "expires_in": self.EXPIRES_MS / 1000,
- },
- )
+ return (
+ 200,
+ {
+ "access_token": token,
+ "token_type": "Bearer",
+ "matrix_server_name": self.server_name,
+ "expires_in": self.EXPIRES_MS / 1000,
+ },
)
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
new file mode 100644
index 0000000000..968403cca4
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import RestServlet
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyServlet(RestServlet):
+ PATTERNS = client_patterns("/password_policy$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(PasswordPolicyServlet, self).__init__()
+
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ def on_GET(self, request):
+ if not self.enabled or not self.policy:
+ return (200, {})
+
+ policy = {}
+
+ for param in [
+ "minimum_length",
+ "require_digit",
+ "require_symbol",
+ "require_lowercase",
+ "require_uppercase",
+ ]:
+ if param in self.policy:
+ policy["m.%s" % param] = self.policy[param]
+
+ return (200, policy)
+
+
+def register_servlets(hs, http_server):
+ PasswordPolicyServlet(hs).register(http_server)
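
For illustration, assuming hs.config.password_policy is exposed as a plain dict (hypothetical values), on_GET() maps the internal keys to their namespaced m.* equivalents:

    policy = {"minimum_length": 10, "require_digit": True}  # hypothetical config

    response = {
        "m.%s" % param: policy[param]
        for param in (
            "minimum_length",
            "require_digit",
            "require_symbol",
            "require_lowercase",
            "require_uppercase",
        )
        if param in policy
    }
    assert response == {"m.minimum_length": 10, "m.require_digit": True}
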
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index e75664279b..d93d6a9f24 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -59,7 +59,7 @@ class ReadMarkerRestServlet(RestServlet):
event_id=read_marker_event_id,
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 488905626a..98a97b7059 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -52,7 +52,7 @@ class ReceiptRestServlet(RestServlet):
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index f327999e59..c4ab2e53cf 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
import hmac
import logging
+import re
from hashlib import sha1
from six import string_types
@@ -41,6 +43,7 @@ from synapse.http.servlet import (
)
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.stringutils import assert_valid_client_secret
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -80,13 +83,15 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
body, ["id_server", "client_secret", "email", "send_attempt"]
)
- if not check_3pid_allowed(self.hs, "email", body["email"]):
+ if not (yield check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
"email", body["email"]
)
@@ -95,7 +100,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -121,7 +126,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -138,7 +145,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class UsernameAvailabilityRestServlet(RestServlet):
@@ -178,7 +185,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
yield self.registration_handler.check_username(username)
- defer.returnValue((200, {"available": True}))
+ return (200, {"available": True})
class RegisterRestServlet(RestServlet):
@@ -200,6 +207,7 @@ class RegisterRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
+ self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
@interactive_auth_handler
@@ -230,7 +238,7 @@ class RegisterRestServlet(RestServlet):
if kind == b"guest":
ret = yield self._do_guest_registration(body, address=client_addr)
- defer.returnValue(ret)
+ return ret
return
elif kind != b"user":
raise UnrecognizedRequestError(
@@ -246,6 +254,7 @@ class RegisterRestServlet(RestServlet):
or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(body["password"])
desired_password = body["password"]
desired_username = None
@@ -257,6 +266,8 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Invalid username")
desired_username = body["username"]
+ desired_display_name = body.get("display_name")
+
appservice = None
if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
@@ -280,9 +291,13 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username,
+ desired_password,
+ desired_display_name,
+ access_token,
+ body,
)
- defer.returnValue((200, result)) # we throw for non 200 responses
+ return (200, result) # we throw for non 200 responses
return
# for either shared secret or regular registration, downcase the
@@ -301,7 +316,7 @@ class RegisterRestServlet(RestServlet):
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body
)
- defer.returnValue((200, result)) # we throw for non 200 responses
+ return (200, result) # we throw for non 200 responses
return
# == Normal User Registration == (everyone else)
@@ -413,7 +428,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (yield check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -421,6 +436,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existingUid = yield self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case generate the
+ # desired_username
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
+ while True:
+ try:
+ yield self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+ # something else went wrong.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = _map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ yield self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -431,9 +520,16 @@ class RegisterRestServlet(RestServlet):
# NB: This may be from the auth handler and NOT from the POST
assert_params_in_dict(params, ["password"])
- desired_username = params.get("username", None)
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ # otherwise, keep the desired_username derived from the 3pid above
+
guest_access_token = params.get("guest_access_token", None)
- new_password = params.get("password", None)
+
+ # XXX: don't we need to validate these for length etc like we did on
+ # the ones from the JSON body earlier on in the method?
if desired_username is not None:
desired_username = desired_username.lower()
@@ -466,8 +562,9 @@ class RegisterRestServlet(RestServlet):
registered_user_id = yield self.registration_handler.register_user(
localpart=desired_username,
- password=new_password,
+ password=params.get("password", None),
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
)
@@ -479,6 +576,14 @@ class RegisterRestServlet(RestServlet):
):
yield self.store.upsert_monthly_active_user(registered_user_id)
+ if self.hs.config.shadow_server:
+ yield self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
@@ -500,17 +605,37 @@ class RegisterRestServlet(RestServlet):
bind_msisdn=params.get("bind_msisdn"),
)
- defer.returnValue((200, return_dict))
+ return (200, return_dict)
def on_OPTIONS(self, _):
return 200, {}
@defer.inlineCallbacks
- def _do_appservice_registration(self, username, as_token, body):
+ def _do_appservice_registration(
+ self, username, password, display_name, as_token, body
+ ):
+
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
user_id = yield self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password, display_name
)
- defer.returnValue((yield self._create_registration_details(user_id, body)))
+ result = yield self._create_registration_details(user_id, body)
+
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ yield self._register_email_threepid(
+ user_id, threepid, result["access_token"], body.get("bind_email")
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ yield self._register_msisdn_threepid(
+ user_id, threepid, result["access_token"], body.get("bind_msisdn")
+ )
+
+ return result
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, body):
@@ -546,7 +671,7 @@ class RegisterRestServlet(RestServlet):
)
result = yield self._create_registration_details(user_id, body)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
@@ -570,7 +695,7 @@ class RegisterRestServlet(RestServlet):
)
result.update({"access_token": access_token, "device_id": device_id})
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _do_guest_registration(self, params, address=None):
@@ -588,19 +713,71 @@ class RegisterRestServlet(RestServlet):
user_id, device_id, initial_display_name, is_guest=True
)
- defer.returnValue(
- (
- 200,
- {
- "user_id": user_id,
- "device_id": device_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- },
- )
+ return (
+ 200,
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "access_token": access_token,
+ "home_server": self.hs.hostname,
+ },
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+ # Split the name by whitespace then hyphens, capitalizing each part then
+ # joining it back together.
+ capitalized_name = " ".join(
+ "-".join(part.capitalize() for part in space_part.split("-"))
+ for space_part in name.split()
+ )
+ return capitalized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
+ """
+ # Replace each . in the address with a space, then split it at the @
+ # into the local part and the domain
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+ # If this is a ...gouv.fr address, set the org to whatever is before
+ # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their
+ # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
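
For illustration, the behaviour of the displayname-mapping helpers above (a minimal sketch, not part of the patch; it assumes the helpers are importable from the patched module path):

# Hedged illustration of _map_email_to_displayname(); the import path is an
# assumption based on the file being patched.
from synapse.rest.client.v2_alpha.register import _map_email_to_displayname

# gouv.fr addresses take the label before "gouv.fr" as the org
assert _map_email_to_displayname("jane.doe@prefecture.gouv.fr") == "Jane Doe [Prefecture]"
# matrix.org addresses are marked as Tchap admins
assert _map_email_to_displayname("admin@matrix.org") == "Admin [Tchap Admin]"
# anything else uses the second-level domain as the org
assert _map_email_to_displayname("john-doe@example.com") == "John-Doe [Example]"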
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 9e9a639055..1538b247e5 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -118,7 +118,7 @@ class RelationSendServlet(RestServlet):
requester, event_dict=event_dict, txn_id=txn_id
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return (200, {"event_id": event.event_id})
class RelationPaginationServlet(RestServlet):
@@ -198,7 +198,7 @@ class RelationPaginationServlet(RestServlet):
return_value["chunk"] = events
return_value["original_event"] = original_event
- defer.returnValue((200, return_value))
+ return (200, return_value)
class RelationAggregationPaginationServlet(RestServlet):
@@ -270,7 +270,7 @@ class RelationAggregationPaginationServlet(RestServlet):
to_token=to_token,
)
- defer.returnValue((200, pagination_chunk.to_dict()))
+ return (200, pagination_chunk.to_dict())
class RelationAggregationGroupPaginationServlet(RestServlet):
@@ -356,7 +356,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
return_value = result.to_dict()
return_value["chunk"] = events
- defer.returnValue((200, return_value))
+ return (200, return_value)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e7578af804..3fdd4584a3 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -72,7 +72,7 @@ class ReportEventRestServlet(RestServlet):
received_ts=self.clock.time_msec(),
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 8d1b810565..10dec96208 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -135,7 +135,7 @@ class RoomKeysServlet(RestServlet):
body = {"rooms": {room_id: body}}
yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_GET(self, request, room_id, session_id):
@@ -218,7 +218,7 @@ class RoomKeysServlet(RestServlet):
else:
room_keys = room_keys["rooms"][room_id]
- defer.returnValue((200, room_keys))
+ return (200, room_keys)
@defer.inlineCallbacks
def on_DELETE(self, request, room_id, session_id):
@@ -242,7 +242,7 @@ class RoomKeysServlet(RestServlet):
yield self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
)
- defer.returnValue((200, {}))
+ return (200, {})
class RoomKeysNewVersionServlet(RestServlet):
@@ -293,7 +293,7 @@ class RoomKeysNewVersionServlet(RestServlet):
info = parse_json_object_from_request(request)
new_version = yield self.e2e_room_keys_handler.create_version(user_id, info)
- defer.returnValue((200, {"version": new_version}))
+ return (200, {"version": new_version})
# we deliberately don't have a PUT /version, as these things really should
# be immutable to avoid people footgunning
@@ -338,7 +338,7 @@ class RoomKeysVersionServlet(RestServlet):
except SynapseError as e:
if e.code == 404:
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
- defer.returnValue((200, info))
+ return (200, info)
@defer.inlineCallbacks
def on_DELETE(self, request, version):
@@ -358,7 +358,7 @@ class RoomKeysVersionServlet(RestServlet):
user_id = requester.user.to_string()
yield self.e2e_room_keys_handler.delete_version(user_id, version)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_PUT(self, request, version):
@@ -392,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet):
)
yield self.e2e_room_keys_handler.update_version(user_id, version, info)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index d7f7faa029..14ba61a63e 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -80,7 +80,7 @@ class RoomUpgradeRestServlet(RestServlet):
ret = {"replacement_room": new_room_id}
- defer.returnValue((200, ret))
+ return (200, ret)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 78075b8fc0..2613648d82 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -60,7 +60,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
)
response = (200, {})
- defer.returnValue(response)
+ return response
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 02d56dee6c..7b32dd2212 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -174,7 +174,7 @@ class SyncRestServlet(RestServlet):
time_now, sync_result, requester.access_token_id, filter
)
- defer.returnValue((200, response_content))
+ return (200, response_content)
@defer.inlineCallbacks
def encode_response(self, time_now, sync_result, access_token_id, filter):
@@ -205,27 +205,23 @@ class SyncRestServlet(RestServlet):
event_formatter,
)
- defer.returnValue(
- {
- "account_data": {"events": sync_result.account_data},
- "to_device": {"events": sync_result.to_device},
- "device_lists": {
- "changed": list(sync_result.device_lists.changed),
- "left": list(sync_result.device_lists.left),
- },
- "presence": SyncRestServlet.encode_presence(
- sync_result.presence, time_now
- ),
- "rooms": {"join": joined, "invite": invited, "leave": archived},
- "groups": {
- "join": sync_result.groups.join,
- "invite": sync_result.groups.invite,
- "leave": sync_result.groups.leave,
- },
- "device_one_time_keys_count": sync_result.device_one_time_keys_count,
- "next_batch": sync_result.next_batch.to_string(),
- }
- )
+ return {
+ "account_data": {"events": sync_result.account_data},
+ "to_device": {"events": sync_result.to_device},
+ "device_lists": {
+ "changed": list(sync_result.device_lists.changed),
+ "left": list(sync_result.device_lists.left),
+ },
+ "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
+ "rooms": {"join": joined, "invite": invited, "leave": archived},
+ "groups": {
+ "join": sync_result.groups.join,
+ "invite": sync_result.groups.invite,
+ "leave": sync_result.groups.leave,
+ },
+ "device_one_time_keys_count": sync_result.device_one_time_keys_count,
+ "next_batch": sync_result.next_batch.to_string(),
+ }
@staticmethod
def encode_presence(events, time_now):
@@ -273,7 +269,7 @@ class SyncRestServlet(RestServlet):
event_formatter=event_formatter,
)
- defer.returnValue(joined)
+ return joined
@defer.inlineCallbacks
def encode_invited(self, rooms, time_now, token_id, event_formatter):
@@ -309,7 +305,7 @@ class SyncRestServlet(RestServlet):
invited_state.append(invite)
invited[room.room_id] = {"invite_state": {"events": invited_state}}
- defer.returnValue(invited)
+ return invited
@defer.inlineCallbacks
def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter):
@@ -342,7 +338,7 @@ class SyncRestServlet(RestServlet):
event_formatter=event_formatter,
)
- defer.returnValue(joined)
+ return joined
@defer.inlineCallbacks
def encode_room(
@@ -414,7 +410,7 @@ class SyncRestServlet(RestServlet):
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
- defer.returnValue(result)
+ return result
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 07b6ede603..d173544355 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -45,7 +45,7 @@ class TagListServlet(RestServlet):
tags = yield self.store.get_tags_for_room(user_id, room_id)
- defer.returnValue((200, {"tags": tags}))
+ return (200, {"tags": tags})
class TagServlet(RestServlet):
@@ -76,7 +76,7 @@ class TagServlet(RestServlet):
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag):
@@ -88,7 +88,7 @@ class TagServlet(RestServlet):
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 1e66662a05..158e686b01 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -40,7 +40,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols()
- defer.returnValue((200, protocols))
+ return (200, protocols)
class ThirdPartyProtocolServlet(RestServlet):
@@ -60,9 +60,9 @@ class ThirdPartyProtocolServlet(RestServlet):
only_protocol=protocol
)
if protocol in protocols:
- defer.returnValue((200, protocols[protocol]))
+ return (200, protocols[protocol])
else:
- defer.returnValue((404, {"error": "Unknown protocol"}))
+ return (404, {"error": "Unknown protocol"})
class ThirdPartyUserServlet(RestServlet):
@@ -85,7 +85,7 @@ class ThirdPartyUserServlet(RestServlet):
ThirdPartyEntityKind.USER, protocol, fields
)
- defer.returnValue((200, results))
+ return (200, results)
class ThirdPartyLocationServlet(RestServlet):
@@ -108,7 +108,7 @@ class ThirdPartyLocationServlet(RestServlet):
ThirdPartyEntityKind.LOCATION, protocol, fields
)
- defer.returnValue((200, results))
+ return (200, results)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index e19fb6d583..079a823c53 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -15,10 +15,13 @@
import logging
+from signedjson.sign import sign_json
+
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -37,6 +40,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -60,10 +64,20 @@ class UserDirectorySearchRestServlet(RestServlet):
user_id = requester.user.to_string()
if not self.hs.config.user_directory_search_enabled:
- defer.returnValue((200, {"limited": False, "results": []}))
+ return (200, {"limited": False, "results": []})
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = yield self.http_client.post_json_get_json(url, signed_body)
+ return (200, resp)
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,8 +90,90 @@ class UserDirectorySearchRestServlet(RestServlet):
user_id, search_term, limit
)
- defer.returnValue((200, results))
+ return (200, results)
+
+
+class UserInfoServlet(RestServlet):
+ """
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.notifier = hs.get_notifier()
+ self.clock = hs.get_clock()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ yield self.auth.get_user_by_req(request, allow_guest=False)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = yield self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ return (200, res)
+
+ res = yield self._get_user_info(user_id)
+ return (200, res)
+
+ @defer.inlineCallbacks
+ def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ res = yield self._get_user_info(user_id)
+ return res
+
+ @defer.inlineCallbacks
+ def _get_user_info(self, user_id):
+ """Retrieve information about a given user
+
+ Args:
+ user_id (str): The User ID of a given user on this homeserver
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ # Check whether user is deactivated
+ is_deactivated = yield self.store.get_user_deactivated_status(user_id)
+
+ # Check whether user is expired
+ expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
+ is_expired = (
+ expiration_ts is not None and self.clock.time_msec() >= expiration_ts
+ )
+
+ res = {"expired": is_expired, "deactivated": is_deactivated}
+ return res
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
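
For illustration, a client-side sketch of the new info endpoint (the r0 path prefix, hostname, and token are assumptions for the example, not part of the patch):

import requests
from urllib.parse import quote

def get_user_info(base_url, access_token, user_id):
    # Returns {"expired": bool, "deactivated": bool}, per _get_user_info above
    resp = requests.get(
        "%s/_matrix/client/r0/user/%s/info" % (base_url, quote(user_id, safe="")),
        headers={"Authorization": "Bearer %s" % access_token},
    )
    resp.raise_for_status()
    return resp.json()

print(get_user_info("https://hs.example.com", "<access_token>", "@alice:example.com"))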
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 65afffbb42..92beefa176 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -171,7 +171,7 @@ class MediaRepository(object):
yield self._generate_thumbnails(None, media_id, media_id, media_type)
- defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
+ return "mxc://%s/%s" % (self.server_name, media_id)
@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
@@ -282,7 +282,7 @@ class MediaRepository(object):
with responder:
pass
- defer.returnValue(media_info)
+ return media_info
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
@@ -317,14 +317,14 @@ class MediaRepository(object):
responder = yield self.media_storage.fetch_media(file_info)
if responder:
- defer.returnValue((responder, media_info))
+ return (responder, media_info)
# Failed to find the file anywhere, lets download it.
media_info = yield self._download_remote_file(server_name, media_id, file_id)
responder = yield self.media_storage.fetch_media(file_info)
- defer.returnValue((responder, media_info))
+ return (responder, media_info)
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
@@ -421,7 +421,7 @@ class MediaRepository(object):
yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
- defer.returnValue(media_info)
+ return media_info
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
@@ -500,7 +500,7 @@ class MediaRepository(object):
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(output_path)
+ return output_path
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(
@@ -554,7 +554,7 @@ class MediaRepository(object):
t_len,
)
- defer.returnValue(output_path)
+ return output_path
@defer.inlineCallbacks
def _generate_thumbnails(
@@ -667,7 +667,7 @@ class MediaRepository(object):
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue({"width": m_width, "height": m_height})
+ return {"width": m_width, "height": m_height}
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
@@ -704,7 +704,7 @@ class MediaRepository(object):
yield self.store.delete_remote_media(origin, media_id)
deleted += 1
- defer.returnValue({"deleted": deleted})
+ return {"deleted": deleted}
class MediaRepositoryResource(Resource):
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 25e5ac2848..3b87717a5a 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -69,7 +69,7 @@ class MediaStorage(object):
)
yield finish_cb()
- defer.returnValue(fname)
+ return fname
@contextlib.contextmanager
def store_into_file(self, file_info):
@@ -143,14 +143,14 @@ class MediaStorage(object):
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
- defer.returnValue(FileResponder(open(local_path, "rb")))
+ return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
if res:
- defer.returnValue(res)
+ return res
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def ensure_media_is_in_local_cache(self, file_info):
@@ -166,7 +166,7 @@ class MediaStorage(object):
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
- defer.returnValue(local_path)
+ return local_path
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
@@ -181,7 +181,7 @@ class MediaStorage(object):
)
yield res.write_to_consumer(consumer)
yield consumer.wait()
- defer.returnValue(local_path)
+ return local_path
raise Exception("file could not be found")
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 5871737bfd..da4ee52a4d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
@@ -182,7 +184,7 @@ class PreviewUrlResource(DirectServeResource):
og = cache_result["og"]
if isinstance(og, six.text_type):
og = og.encode("utf8")
- defer.returnValue(og)
+ return og
return
media_info = yield self._download_url(url, user)
@@ -284,7 +286,7 @@ class PreviewUrlResource(DirectServeResource):
media_info["created_ts"],
)
- defer.returnValue(jsonog)
+ return jsonog
@defer.inlineCallbacks
def _download_url(self, url, user):
@@ -354,22 +356,20 @@ class PreviewUrlResource(DirectServeResource):
# therefore not expire it.
raise
- defer.returnValue(
- {
- "media_type": media_type,
- "media_length": length,
- "download_name": download_name,
- "created_ts": time_now_ms,
- "filesystem_id": file_id,
- "filename": fname,
- "uri": uri,
- "response_code": code,
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- "expires": 60 * 60 * 1000,
- "etag": headers["ETag"][0] if "ETag" in headers else None,
- }
- )
+ return {
+ "media_type": media_type,
+ "media_length": length,
+ "download_name": download_name,
+ "created_ts": time_now_ms,
+ "filesystem_id": file_id,
+ "filename": fname,
+ "uri": uri,
+ "response_code": code,
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ "expires": 60 * 60 * 1000,
+ "etag": headers["ETag"][0] if "ETag" in headers else None,
+ }
def _start_expire_url_cache_data(self):
return run_as_background_process(
diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
new file mode 100644
index 0000000000..6f2a1931c5
--- /dev/null
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.config._base import ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DomainRuleChecker(object):
+ """
+ A re-implementation of the SpamChecker that prevents users in one domain from
+ inviting users in other domains to rooms, based on a configuration.
+
+ Takes a config in the format:
+
+ spam_checker:
+ module: "rulecheck.DomainRuleChecker"
+ config:
+ domain_mapping:
+ "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
+ "other_inviter_domain": [ "invitee_domain_permitted" ]
+ default: False
+
+ # Only let local users join rooms if they were explicitly invited.
+ can_only_join_rooms_with_invite: false
+
+ # Only let local users create rooms if they are inviting only one
+ # other user, and that user matches the rules above.
+ can_only_create_one_to_one_rooms: false
+
+ # Only let local users invite during room creation, regardless of the
+ # domain mapping rules above.
+ can_only_invite_during_room_creation: false
+
+ # Prevent local users from inviting users from certain domains to
+ # rooms published in the room directory.
+ domains_prevented_from_being_invited_to_published_rooms: []
+
+ # Allow third party invites
+ can_invite_by_third_party_id: true
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config):
+ self.domain_mapping = config["domain_mapping"] or {}
+ self.default = config["default"]
+
+ self.can_only_join_rooms_with_invite = config.get(
+ "can_only_join_rooms_with_invite", False
+ )
+ self.can_only_create_one_to_one_rooms = config.get(
+ "can_only_create_one_to_one_rooms", False
+ )
+ self.can_only_invite_during_room_creation = config.get(
+ "can_only_invite_during_room_creation", False
+ )
+ self.can_invite_by_third_party_id = config.get(
+ "can_invite_by_third_party_id", True
+ )
+ self.domains_prevented_from_being_invited_to_published_rooms = config.get(
+ "domains_prevented_from_being_invited_to_published_rooms", []
+ )
+
+ def check_event_for_spam(self, event):
+ """Implements synapse.events.SpamChecker.check_event_for_spam
+ """
+ return False
+
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
+ """Implements synapse.events.SpamChecker.user_may_invite
+ """
+ if self.can_only_invite_during_room_creation and not new_room:
+ return False
+
+ if not self.can_invite_by_third_party_id and third_party_invite:
+ return False
+
+ # This is a third party invite (without a bound mxid), so unless we have
+ # banned all third party invites (above) we allow it.
+ if not invitee_userid:
+ return True
+
+ inviter_domain = self._get_domain_from_id(inviter_userid)
+ invitee_domain = self._get_domain_from_id(invitee_userid)
+
+ if inviter_domain not in self.domain_mapping:
+ return self.default
+
+ if (
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
+ ):
+ return False
+
+ return invitee_domain in self.domain_mapping[inviter_domain]
+
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
+ """Implements synapse.events.SpamChecker.user_may_create_room
+ """
+
+ if cloning:
+ return True
+
+ if not self.can_invite_by_third_party_id and third_party_invite_list:
+ return False
+
+ number_of_invites = len(invite_list) + len(third_party_invite_list)
+
+ if self.can_only_create_one_to_one_rooms and number_of_invites != 1:
+ return False
+
+ return True
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias
+ """
+ return True
+
+ def user_may_publish_room(self, userid, room_id):
+ """Implements synapse.events.SpamChecker.user_may_publish_room
+ """
+ return True
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Implements synapse.events.SpamChecker.user_may_join_room
+ """
+ if self.can_only_join_rooms_with_invite and not is_invited:
+ return False
+
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ """Implements synapse.events.SpamChecker.parse_config
+ """
+ if "default" in config:
+ return config
+ else:
+ raise ConfigError("No default set for spam_config DomainRuleChecker")
+
+ @staticmethod
+ def _get_domain_from_id(mxid):
+ """Parses a string and returns the domain part of the mxid.
+
+ Args:
+ mxid (str): a valid mxid
+
+ Returns:
+ str: the domain part of the mxid
+
+ """
+ idx = mxid.find(":")
+ if idx == -1:
+ raise Exception("Invalid ID: %r" % (mxid,))
+ return mxid[idx + 1 :]
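
For illustration, a minimal usage sketch of the new checker (assumed config values; not part of the patch):

from synapse.rulecheck.domain_rule_checker import DomainRuleChecker

checker = DomainRuleChecker(
    DomainRuleChecker.parse_config(
        {"domain_mapping": {"example.com": ["other.example.com"]}, "default": False}
    )
)

# example.com users may invite other.example.com users into an existing room...
assert checker.user_may_invite(
    "@alice:example.com", "@bob:other.example.com", None, "!room:example.com", False
)
# ...but not users on domains absent from the mapping
assert not checker.user_may_invite(
    "@alice:example.com", "@eve:elsewhere.org", None, "!room:example.com", False
)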
diff --git a/synapse/server.py b/synapse/server.py
index 9e28dba2b1..23be3560fa 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,6 +23,7 @@
# Imports required for the default HomeServer() implementation
import abc
import logging
+import os
from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
@@ -65,6 +66,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
from synapse.handlers.pagination import PaginationHandler
+from synapse.handlers.password_policy import PasswordPolicyHandler
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
@@ -166,6 +168,7 @@ class HomeServer(object):
"event_builder_factory",
"filtering",
"http_client_context_factory",
+ "proxied_http_client",
"simple_http_client",
"media_repository",
"media_repository_resource",
@@ -196,6 +199,7 @@ class HomeServer(object):
"account_validity_handler",
"saml_handler",
"event_client_serializer",
+ "password_policy_handler",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -304,6 +308,13 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
+ def build_proxied_http_client(self):
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
+ )
+
def build_room_creation_handler(self):
return RoomCreationHandler(self)
@@ -533,6 +544,9 @@ class HomeServer(object):
def build_event_client_serializer(self):
return EventClientSerializer(self)
+ def build_password_policy_handler(self):
+ return PasswordPolicyHandler(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 16f8f6b573..56f9cd06e5 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -12,6 +12,7 @@ import synapse.handlers.message
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
+import synapse.http.client
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -38,6 +39,14 @@ class HomeServer(object):
pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
+ def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ or support any HTTP_PROXY settings"""
+ pass
+ def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ but does support HTTP_PROXY settings"""
+ pass
def get_deactivate_account_handler(
self
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index f183743f31..729c097e6d 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -193,4 +193,4 @@ class ResourceLimitsServerNotices(object):
if event_id in referenced_events:
referenced_events.remove(event.event_id)
- defer.returnValue((currently_blocked, referenced_events))
+ return (currently_blocked, referenced_events)
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 71e7e75320..2dac90578c 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -86,7 +86,7 @@ class ServerNoticesManager(object):
res = yield self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
- defer.returnValue(res)
+ return res
@cachedInlineCallbacks()
def get_notice_room_for_user(self, user_id):
@@ -120,7 +120,7 @@ class ServerNoticesManager(object):
# we found a room which our user shares with the system notice
# user
logger.info("Using room %s", room.room_id)
- defer.returnValue(room.room_id)
+ return room.room_id
# apparently no existing notice room: create a new one
logger.info("Creating server notices room for %s", user_id)
@@ -158,4 +158,4 @@ class ServerNoticesManager(object):
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
logger.info("Created server notices room %s for %s", room_id, user_id)
- defer.returnValue(room_id)
+ return room_id
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9f708fa205..a0d34f16ea 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -135,7 +135,7 @@ class StateHandler(object):
event = None
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
- defer.returnValue(event)
+ return event
return
state_map = yield self.store.get_events(
@@ -145,7 +145,7 @@ class StateHandler(object):
key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
}
- defer.returnValue(state)
+ return state
@defer.inlineCallbacks
def get_current_state_ids(self, room_id, latest_event_ids=None):
@@ -169,7 +169,7 @@ class StateHandler(object):
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
- defer.returnValue(state)
+ return state
@defer.inlineCallbacks
def get_current_users_in_room(self, room_id, latest_event_ids=None):
@@ -189,7 +189,7 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
- defer.returnValue(joined_users)
+ return joined_users
@defer.inlineCallbacks
def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
@@ -198,7 +198,7 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
- defer.returnValue(joined_hosts)
+ return joined_hosts
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
@@ -241,7 +241,7 @@ class StateHandler(object):
prev_state_ids=prev_state_ids,
)
- defer.returnValue(context)
+ return context
if old_state:
# We already have the state, so we don't need to calculate it.
@@ -275,7 +275,7 @@ class StateHandler(object):
prev_state_ids=prev_state_ids,
)
- defer.returnValue(context)
+ return context
logger.debug("calling resolve_state_groups from compute_event_context")
@@ -343,7 +343,7 @@ class StateHandler(object):
delta_ids=delta_ids,
)
- defer.returnValue(context)
+ return context
@defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids):
@@ -368,19 +368,17 @@ class StateHandler(object):
state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
if len(state_groups_ids) == 0:
- defer.returnValue(_StateCacheEntry(state={}, state_group=None))
+ return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
- defer.returnValue(
- _StateCacheEntry(
- state=state_list,
- state_group=name,
- prev_group=prev_group,
- delta_ids=delta_ids,
- )
+ return _StateCacheEntry(
+ state=state_list,
+ state_group=name,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
)
room_version = yield self.store.get_room_version(room_id)
@@ -392,7 +390,7 @@ class StateHandler(object):
None,
state_res_store=StateResolutionStore(self.store),
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
@@ -415,7 +413,7 @@ class StateHandler(object):
new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
- defer.returnValue(new_state)
+ return new_state
class StateResolutionHandler(object):
@@ -479,7 +477,7 @@ class StateResolutionHandler(object):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
- defer.returnValue(cache)
+ return cache
logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
@@ -525,7 +523,7 @@ class StateResolutionHandler(object):
if self._state_cache is not None:
self._state_cache[group_names] = cache
- defer.returnValue(cache)
+ return cache
def _make_state_cache_entry(new_state, state_groups_ids):
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 88acd4817e..a2f92d9ff9 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -55,7 +55,7 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
- defer.returnValue(state_sets[0])
+ return state_sets[0]
unconflicted_state, conflicted_state = _seperate(state_sets)
@@ -97,10 +97,8 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
- defer.returnValue(
- _resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
- )
+ return _resolve_with_state(
+ unconflicted_state, conflicted_state, auth_events, state_map
)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index db969e8997..b327c86f40 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -63,7 +63,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
unconflicted_state, conflicted_state = _seperate(state_sets)
if not conflicted_state:
- defer.returnValue(unconflicted_state)
+ return unconflicted_state
logger.debug("%d conflicted state entries", len(conflicted_state))
logger.debug("Calculating auth chain difference")
@@ -137,7 +137,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
logger.debug("done")
- defer.returnValue(resolved_state)
+ return resolved_state
@defer.inlineCallbacks
@@ -168,18 +168,18 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
aev = yield _get_event(aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.Create, ""):
if aev.content.get("creator") == event.sender:
- defer.returnValue(100)
+ return 100
break
- defer.returnValue(0)
+ return 0
level = pl.content.get("users", {}).get(event.sender)
if level is None:
level = pl.content.get("users_default", 0)
if level is None:
- defer.returnValue(0)
+ return 0
else:
- defer.returnValue(int(level))
+ return int(level)
@defer.inlineCallbacks
@@ -224,7 +224,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
union = set().union(*auth_sets)
- defer.returnValue(union - intersection)
+ return union - intersection
def _seperate(state_sets):
@@ -343,7 +343,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
it = lexicographical_topological_sort(graph, key=_get_power_order)
sorted_events = list(it)
- defer.returnValue(sorted_events)
+ return sorted_events
@defer.inlineCallbacks
@@ -396,7 +396,7 @@ def _iterative_auth_checks(
except AuthError:
pass
- defer.returnValue(resolved_state)
+ return resolved_state
@defer.inlineCallbacks
@@ -439,7 +439,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor
event_ids.sort(key=lambda ev_id: order_map[ev_id])
- defer.returnValue(event_ids)
+ return event_ids
@defer.inlineCallbacks
@@ -462,7 +462,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
while event:
depth = mainline_map.get(event.event_id)
if depth is not None:
- defer.returnValue(depth)
+ return depth
auth_events = event.auth_event_ids()
event = None
@@ -474,7 +474,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
break
# Didn't find a power level auth event, so we just return 0
- defer.returnValue(0)
+ return 0
@defer.inlineCallbacks
@@ -493,7 +493,7 @@ def _get_event(event_id, event_map, state_res_store):
if event_id not in event_map:
events = yield state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
- defer.returnValue(event_map[event_id])
+ return event_map[event_id]
def lexicographical_topological_sort(graph, key):
diff --git a/synapse/static/index.html b/synapse/static/index.html
index d3f1c7dce0..bf46df9097 100644
--- a/synapse/static/index.html
+++ b/synapse/static/index.html
@@ -48,13 +48,13 @@
</div>
<h1>It works! Synapse is running</h1>
<p>Your Synapse server is listening on this port and is ready for messages.</p>
- <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank">a Matrix client</a>.
+ <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank" rel="noopener noreferrer">a Matrix client</a>.
</p>
<p>Welcome to the Matrix universe :)</p>
<hr>
<p>
<small>
- <a href="https://matrix.org" target="_blank">
+ <a href="https://matrix.org" target="_blank" rel="noopener noreferrer">
matrix.org
</a>
</small>
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 6b0ca80087..e7f6ea7286 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -469,7 +469,7 @@ class DataStore(
return self._simple_select_list(
table="users",
keyvalues={},
- retcols=["name", "password_hash", "is_guest", "admin"],
+ retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="get_users",
)
@@ -494,11 +494,11 @@ class DataStore(
orderby=order,
start=start,
limit=limit,
- retcols=["name", "password_hash", "is_guest", "admin"],
+ retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
)
count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
retval = {"users": users, "total": count}
- defer.returnValue(retval)
+ return retval
def search_users(self, term):
"""Function to search users list for one or more users with
@@ -514,7 +514,7 @@ class DataStore(
table="users",
term=term,
col="name",
- retcols=["name", "password_hash", "is_guest", "admin"],
+ retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="search_users",
)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2f940dbae6..56a54267af 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -86,7 +86,21 @@ _CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
- method."""
+ method.
+
+ Args:
+ txn: The database transaction object to wrap.
+ name (str): The name of this transaction, for logging.
+ database_engine (Sqlite3Engine|PostgresEngine)
+ after_callbacks(list|None): A list to which callbacks added via
+ `call_after` will be appended; they should be run on successful
+ completion of the transaction. None indicates that no callbacks
+ should be allowed to be scheduled to run.
+ exception_callbacks(list|None): A list to which callbacks added via
+ `call_on_exception` will be appended; they should be run if the
+ transaction ends with an error. None indicates that no callbacks
+ should be allowed to be scheduled to run.
+ """
__slots__ = [
"txn",
@@ -97,7 +111,7 @@ class LoggingTransaction(object):
]
def __init__(
- self, txn, name, database_engine, after_callbacks, exception_callbacks
+ self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
@@ -346,14 +360,11 @@ class SQLBaseStore(object):
expiration_ts,
)
- self._simple_insert_txn(
+ self._simple_upsert_txn(
txn,
"account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- "email_sent": False,
- },
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
def start_profiling(self):
@@ -499,7 +510,7 @@ class SQLBaseStore(object):
after_callback(*after_args, **after_kwargs)
raise
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
@@ -539,7 +550,7 @@ class SQLBaseStore(object):
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
- defer.returnValue(result)
+ return result
@staticmethod
def cursor_to_dict(cursor):
@@ -601,8 +612,8 @@ class SQLBaseStore(object):
# a cursor after we receive an error from the db.
if not or_ignore:
raise
- defer.returnValue(False)
- defer.returnValue(True)
+ return False
+ return True
@staticmethod
def _simple_insert_txn(txn, table, values):
@@ -694,7 +705,7 @@ class SQLBaseStore(object):
insertion_values,
lock=lock,
)
- defer.returnValue(result)
+ return result
except self.database_engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@@ -1107,7 +1118,7 @@ class SQLBaseStore(object):
results = []
if not iterable:
- defer.returnValue(results)
+ return results
# iterables can not be sliced, so convert it to a list first
it_list = list(iterable)
@@ -1128,7 +1139,7 @@ class SQLBaseStore(object):
results.extend(rows)
- defer.returnValue(results)
+ return results
@classmethod
def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
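
For illustration, the after_callbacks / exception_callbacks contract documented above, as a self-contained toy model (hypothetical names; not Synapse code):

class ToyTransaction:
    def __init__(self):
        self.after_callbacks = []
        self.exception_callbacks = []

    def call_after(self, callback, *args, **kwargs):
        # Queue a callback to run only if the transaction completes
        self.after_callbacks.append((callback, args, kwargs))

def run_interaction(func):
    txn = ToyTransaction()
    try:
        result = func(txn)
    except Exception:
        # Failure: only the exception callbacks fire
        for cb, args, kwargs in txn.exception_callbacks:
            cb(*args, **kwargs)
        raise
    # Success: the queued after-callbacks fire exactly once
    for cb, args, kwargs in txn.after_callbacks:
        cb(*args, **kwargs)
    return result

run_interaction(lambda txn: txn.call_after(print, "committed"))  # prints "committed"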
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 8394389073..9fa5b4f3d6 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -111,9 +111,9 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
- defer.returnValue(json.loads(result))
+ return json.loads(result)
else:
- defer.returnValue(None)
+ return None
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
@@ -264,11 +264,9 @@ class AccountDataWorkerStore(SQLBaseStore):
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
- defer.returnValue(False)
+ return False
- defer.returnValue(
- ignored_user_id in ignored_account_data.get("ignored_users", {})
- )
+ return ignored_user_id in ignored_account_data.get("ignored_users", {})
class AccountDataStore(AccountDataWorkerStore):
@@ -332,7 +330,7 @@ class AccountDataStore(AccountDataWorkerStore):
)
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def add_account_data_for_user(self, user_id, account_data_type, content):
@@ -373,7 +371,7 @@ class AccountDataStore(AccountDataWorkerStore):
)
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
def _update_max_stream_id(self, next_id):
"""Update the max stream_id
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index eb329ebd8b..1e9977e1bc 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -35,7 +35,7 @@ def _make_exclusive_regex(services_cache):
exclusive_user_regexes = [
regex.pattern
for service in services_cache
- for regex in service.get_exlusive_user_regexes()
+ for regex in service.get_exclusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
@@ -145,7 +145,7 @@ class ApplicationServiceTransactionWorkerStore(
for service in as_list:
if service.id == res["as_id"]:
services.append(service)
- defer.returnValue(services)
+ return services
@defer.inlineCallbacks
def get_appservice_state(self, service):
@@ -164,9 +164,9 @@ class ApplicationServiceTransactionWorkerStore(
desc="get_appservice_state",
)
if result:
- defer.returnValue(result.get("state"))
+ return result.get("state")
return
- defer.returnValue(None)
+ return None
def set_appservice_state(self, service, state):
"""Set the application service state.
@@ -298,15 +298,13 @@ class ApplicationServiceTransactionWorkerStore(
)
if not entry:
- defer.returnValue(None)
+ return None
event_ids = json.loads(entry["event_ids"])
events = yield self.get_events_as_list(event_ids)
- defer.returnValue(
- AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
- )
+ return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
def _get_last_txn(self, txn, service_id):
txn.execute(
@@ -360,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore(
events = yield self.get_events_as_list(event_ids)
- defer.returnValue((upper_bound, events))
+ return (upper_bound, events)
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 50f913a414..e5f0668f09 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -115,7 +115,7 @@ class BackgroundUpdateStore(SQLBaseStore):
" Unscheduling background update task."
)
self._all_done = True
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def has_completed_background_updates(self):
@@ -127,11 +127,11 @@ class BackgroundUpdateStore(SQLBaseStore):
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
- defer.returnValue(True)
+ return True
# obviously, if we have things in our queue, we're not done.
if self._background_update_queue:
- defer.returnValue(False)
+ return False
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
@@ -144,9 +144,9 @@ class BackgroundUpdateStore(SQLBaseStore):
)
if not updates:
self._all_done = True
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms):
@@ -173,14 +173,14 @@ class BackgroundUpdateStore(SQLBaseStore):
if not self._background_update_queue:
# no work left to do
- defer.returnValue(None)
+ return None
# pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
res = yield self._do_background_update(update_name, desired_duration_ms)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms):
@@ -231,7 +231,7 @@ class BackgroundUpdateStore(SQLBaseStore):
performance.update(items_updated, duration_ms)
- defer.returnValue(len(self._background_update_performance))
+ return len(self._background_update_performance)
def register_background_update_handler(self, update_name, update_handler):
"""Register a handler for doing a background update.
@@ -266,7 +266,7 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
- defer.returnValue(1)
+ return 1
self.register_background_update_handler(update_name, noop_update)
@@ -370,7 +370,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info("Adding index %s to %s", index_name, table)
yield self.runWithConnection(runner)
yield self._end_background_update(update_name)
- defer.returnValue(1)
+ return 1
self.register_background_update_handler(update_name, updater)
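
The handlers registered in this file all follow the same contract: take the persisted progress dict and a batch size, process up to batch_size items, persist new progress, and return the number of items handled, calling _end_background_update once nothing is left. A hedged sketch of a conforming handler; the get_rows_after helper and the update name are invented for illustration:

    from twisted.internet import defer


    def register_example_update(store):
        @defer.inlineCallbacks
        def reindex(progress, batch_size):
            last_id = progress.get("last_id", 0)
            # Hypothetical helper returning up to batch_size rows past last_id.
            rows = yield store.get_rows_after(last_id, batch_size)
            if not rows:
                yield store._end_background_update("example_reindex")
                return 0
            yield store._background_update_progress(
                "example_reindex", {"last_id": rows[-1]["id"]}
            )
            return len(rows)

        store.register_background_update_handler("example_reindex", reindex)
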
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index bda68de5be..6db8c54077 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -104,7 +104,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
yield self.runWithConnection(f)
yield self._end_background_update("user_ips_drop_nonunique_index")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _analyze_user_ip(self, progress, batch_size):
@@ -121,7 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
yield self._end_background_update("user_ips_analyze")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
@@ -291,7 +291,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
if last:
yield self._end_background_update("user_ips_remove_dupes")
- defer.returnValue(batch_size)
+ return batch_size
@defer.inlineCallbacks
def insert_client_ip(
@@ -401,7 +401,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id": did,
"last_seen": last_seen,
}
- defer.returnValue(ret)
+ return ret
@classmethod
def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
@@ -461,14 +461,12 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
- defer.returnValue(
- list(
- {
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "last_seen": last_seen,
- }
- for (access_token, ip), (user_agent, last_seen) in iteritems(results)
- )
+ return list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 4ea0deea4f..79bb0ea46d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -92,7 +92,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_id, last_deleted_stream_id
)
if not has_changed:
- defer.returnValue(0)
+ return 0
def delete_messages_for_device_txn(txn):
sql = (
@@ -115,7 +115,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
last_deleted_stream_id, up_to_stream_id
)
- defer.returnValue(count)
+ return count
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
@@ -263,7 +263,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
destination, stream_id
)
- defer.returnValue(self._device_inbox_id_gen.get_current_token())
+ return self._device_inbox_id_gen.get_current_token()
@defer.inlineCallbacks
def add_messages_from_remote_to_device_inbox(
@@ -312,7 +312,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
- defer.returnValue(stream_id)
+ return stream_id
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
@@ -426,4 +426,4 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
- defer.returnValue(1)
+ return 1
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d2b113a4e7..8f72d92895 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -71,7 +71,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_devices_by_user",
)
- defer.returnValue({d["device_id"]: d for d in devices})
+ return {d["device_id"]: d for d in devices}
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
@@ -88,7 +88,7 @@ class DeviceWorkerStore(SQLBaseStore):
destination, int(from_stream_id)
)
if not has_changed:
- defer.returnValue((now_stream_id, []))
+ return (now_stream_id, [])
# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
@@ -111,7 +111,7 @@ class DeviceWorkerStore(SQLBaseStore):
# Return an empty list if there are no updates
if not updates:
- defer.returnValue((now_stream_id, []))
+ return (now_stream_id, [])
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
@@ -147,13 +147,13 @@ class DeviceWorkerStore(SQLBaseStore):
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map:
- defer.returnValue((stream_id_cutoff, []))
+ return (stream_id_cutoff, [])
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
- defer.returnValue((now_stream_id, results))
+ return (now_stream_id, results)
def _get_devices_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
@@ -232,7 +232,7 @@ class DeviceWorkerStore(SQLBaseStore):
results.append(result)
- defer.returnValue(results)
+ return results
def _get_last_device_update_for_remote_user(
self, destination, user_id, from_stream_id
@@ -330,7 +330,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
- defer.returnValue((user_ids_not_in_cache, results))
+ return (user_ids_not_in_cache, results)
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
@@ -340,7 +340,7 @@ class DeviceWorkerStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(db_to_json(content))
+ return db_to_json(content)
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -350,9 +350,9 @@ class DeviceWorkerStore(SQLBaseStore):
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
- defer.returnValue(
- {device["device_id"]: db_to_json(device["content"]) for device in devices}
- )
+ return {
+ device["device_id"]: db_to_json(device["content"]) for device in devices
+ }
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
@@ -482,7 +482,7 @@ class DeviceWorkerStore(SQLBaseStore):
results = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})
- defer.returnValue(results)
+ return results
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
@@ -543,7 +543,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
key = (user_id, device_id)
if self.device_id_exists_cache.get(key, None):
- defer.returnValue(False)
+ return False
try:
inserted = yield self._simple_insert(
@@ -557,7 +557,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
or_ignore=True,
)
self.device_id_exists_cache.prefill(key, True)
- defer.returnValue(inserted)
+ return inserted
except Exception as e:
logger.error(
"store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -780,7 +780,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
hosts,
stream_id,
)
- defer.returnValue(stream_id)
+ return stream_id
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
@@ -889,4 +889,4 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
- defer.returnValue(1)
+ return 1
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 201bbd430c..e966a73f3d 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -46,7 +46,7 @@ class DirectoryWorkerStore(SQLBaseStore):
)
if not room_id:
- defer.returnValue(None)
+ return None
return
servers = yield self._simple_select_onecol(
@@ -57,10 +57,10 @@ class DirectoryWorkerStore(SQLBaseStore):
)
if not servers:
- defer.returnValue(None)
+ return None
return
- defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
+ return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
@@ -125,7 +125,7 @@ class DirectoryStore(DirectoryWorkerStore):
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
@@ -133,7 +133,7 @@ class DirectoryStore(DirectoryWorkerStore):
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
- defer.returnValue(room_id)
+ return room_id
def _delete_room_alias_txn(self, txn, room_alias):
txn.execute(
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index f40ef2ab64..99128f2df7 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -61,7 +61,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
row["session_data"] = json.loads(row["session_data"])
- defer.returnValue(row)
+ return row
@defer.inlineCallbacks
def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
@@ -118,7 +118,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
try:
version = int(version)
except ValueError:
- defer.returnValue({"rooms": {}})
+ return {"rooms": {}}
keyvalues = {"user_id": user_id, "version": version}
if room_id:
@@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"session_data": json.loads(row["session_data"]),
}
- defer.returnValue(sessions)
+ return sessions
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2fabb9e2cb..1e07474e70 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -41,7 +41,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
dict containing "key_json", "device_display_name".
"""
if not query_list:
- defer.returnValue({})
+ return {}
results = yield self.runInteraction(
"get_e2e_device_keys",
@@ -55,7 +55,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for device_id, device_info in iteritems(device_keys):
device_info["keys"] = db_to_json(device_info.pop("key_json"))
- defer.returnValue(results)
+ return results
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
@@ -130,9 +130,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
desc="add_e2e_one_time_keys_check",
)
- defer.returnValue(
- {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
- )
+ return {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index cb4478342f..4f500d893e 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -131,9 +131,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
if not rows:
- defer.returnValue(0)
+ return 0
else:
- defer.returnValue(max(row["depth"] for row in rows))
+ return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
@@ -169,7 +169,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# make sure that we don't completely ignore the older events.
res = res[0:5] + random.sample(res[5:], 5)
- defer.returnValue(res)
+ return res
def get_latest_event_ids_and_hashes_in_room(self, room_id):
"""
@@ -411,7 +411,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit,
)
events = yield self.get_events_as_list(ids)
- defer.returnValue(events)
+ return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -463,7 +463,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_successor_events",
)
- defer.returnValue([row["event_id"] for row in rows])
+ return [row["event_id"] for row in rows]
class EventFederationStore(EventFederationWorkerStore):
@@ -654,4 +654,4 @@ class EventFederationStore(EventFederationWorkerStore):
if not result:
yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
- defer.returnValue(batch_size)
+ return batch_size
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index eca77069fd..22025effbc 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -79,8 +79,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
- after_callbacks=[],
- exception_callbacks=[],
)
self._find_stream_orderings_for_times_txn(cur)
cur.close()
@@ -102,7 +100,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
user_id,
last_read_event_id,
)
- defer.returnValue(ret)
+ return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
@@ -180,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_http(
@@ -281,7 +279,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
- defer.returnValue(notifs[:limit])
+ return notifs[:limit]
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_email(
@@ -382,7 +380,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
notifs.sort(key=lambda r: -(r["received_ts"] or 0))
# Now return the first `limit`
- defer.returnValue(notifs[:limit])
+ return notifs[:limit]
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
"""A fast check to see if there might be something to push for the
@@ -479,7 +477,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
- defer.returnValue(res)
+ return res
except Exception:
# this method is called from an exception handler, so propagating
# another exception here really isn't helpful - there's nothing
@@ -734,7 +732,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
push_actions = yield self.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- defer.returnValue(push_actions)
+ return push_actions
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
@@ -751,7 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return txn.fetchone()
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
- defer.returnValue(result[0] if result else None)
+ return result[0] if result else None
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
@@ -760,7 +758,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return txn.fetchone()
result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
- defer.returnValue(result[0] or 0)
+ return result[0] or 0
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index b486ca50eb..511f0f251f 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -223,7 +223,7 @@ def _retry_on_integrity_error(func):
except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs)
- defer.returnValue(res)
+ return res
return f
@@ -309,7 +309,7 @@ class EventsStore(
max_persisted_id = yield self._stream_id_gen.get_current_token()
- defer.returnValue(max_persisted_id)
+ return max_persisted_id
@defer.inlineCallbacks
@log_function
@@ -334,7 +334,7 @@ class EventsStore(
yield make_deferred_yieldable(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token()
- defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
+ return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
@@ -595,7 +595,7 @@ class EventsStore(
stale = latest_event_ids & result
stale_forward_extremities_counter.observe(len(stale))
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _get_events_which_are_prevs(self, event_ids):
@@ -633,7 +633,7 @@ class EventsStore(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def _get_prevs_before_rejected(self, event_ids):
@@ -695,7 +695,7 @@ class EventsStore(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
- defer.returnValue(existing_prevs)
+ return existing_prevs
@defer.inlineCallbacks
def _get_new_state_after_events(
@@ -796,7 +796,7 @@ class EventsStore(
# If the old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
- defer.returnValue((None, None))
+ return (None, None)
if len(new_state_groups) == 1 and len(old_state_groups) == 1:
# If we're going from one state group to another, lets check if
@@ -813,7 +813,7 @@ class EventsStore(
# the current state in memory then lets also return that,
# but it doesn't matter if we don't.
new_state = state_groups_map.get(new_state_group)
- defer.returnValue((new_state, delta_ids))
+ return (new_state, delta_ids)
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
@@ -825,7 +825,7 @@ class EventsStore(
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- defer.returnValue((state_groups_map[new_state_groups.pop()], None))
+ return (state_groups_map[new_state_groups.pop()], None)
# Ok, we need to defer to the state handler to resolve our state sets.
@@ -854,7 +854,7 @@ class EventsStore(
state_res_store=StateResolutionStore(self),
)
- defer.returnValue((res.state, None))
+ return (res.state, None)
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, current_state):
@@ -877,7 +877,7 @@ class EventsStore(
if ev_id != existing_state.get(key)
}
- defer.returnValue((to_delete, to_insert))
+ return (to_delete, to_insert)
@log_function
def _persist_events_txn(
@@ -918,8 +918,6 @@ class EventsStore(
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
-
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -993,6 +991,10 @@ class EventsStore(
backfilled=backfilled,
)
+ # We call this last as it assumes we've inserted the events into
+ # room_memberships, where applicable.
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
to_delete, to_insert = current_state_tuple
@@ -1062,16 +1064,16 @@ class EventsStore(
),
)
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
+ # We include the membership in the current state table, hence we do
+ # a lookup when we insert. This assumes that all events have already
+ # been inserted into room_memberships.
+ txn.executemany(
+ """INSERT INTO current_state_events
+ (room_id, type, state_key, event_id, membership)
+ VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ """,
+ [
+ (room_id, key[0], key[1], ev_id, ev_id)
for key, ev_id in iteritems(to_insert)
],
)
@@ -1455,6 +1457,9 @@ class EventsStore(
elif event.type == EventTypes.GuestAccess:
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
+ elif event.type == EventTypes.Retention:
+ # Update the room_retention table.
+ self._store_retention_policy_for_room_txn(txn, event)
self._handle_event_relations(txn, event)
@@ -1562,7 +1567,7 @@ class EventsStore(
return count
ret = yield self.runInteraction("count_messages", _count_messages)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def count_daily_sent_messages(self):
@@ -1583,7 +1588,7 @@ class EventsStore(
return count
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def count_daily_active_rooms(self):
@@ -1598,7 +1603,7 @@ class EventsStore(
return count
ret = yield self.runInteraction("count_daily_active_rooms", _count)
- defer.returnValue(ret)
+ return ret
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
@@ -2181,7 +2186,7 @@ class EventsStore(
"""
to_1, so_1 = yield self._get_event_ordering(event_id1)
to_2, so_2 = yield self._get_event_ordering(event_id2)
- defer.returnValue((to_1, so_1) > (to_2, so_2))
+ return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
@@ -2195,9 +2200,7 @@ class EventsStore(
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- defer.returnValue(
- (int(res["topological_ordering"]), int(res["stream_ordering"]))
- )
+ return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
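
Beyond the return-style changes, this file now writes a membership value into current_state_events, resolved from room_memberships with a correlated subquery at insert time (which is why _update_current_state_txn must run after the membership rows have been written). A self-contained toy reproduction of that insert pattern, using sqlite3 and invented rows:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE room_memberships (event_id TEXT PRIMARY KEY, membership TEXT);
        CREATE TABLE current_state_events (
            room_id TEXT, type TEXT, state_key TEXT,
            event_id TEXT, membership TEXT
        );
        INSERT INTO room_memberships VALUES ('$ev1', 'join');
        """
    )

    to_insert = {("m.room.member", "@alice:example.com"): "$ev1"}
    conn.executemany(
        """INSERT INTO current_state_events
           (room_id, type, state_key, event_id, membership)
           VALUES (?, ?, ?, ?,
                   (SELECT membership FROM room_memberships WHERE event_id = ?))
        """,
        [
            ("!room:example.com", key[0], key[1], ev_id, ev_id)
            for key, ev_id in to_insert.items()
        ],
    )
    # The membership was resolved from room_memberships at insert time.
    print(conn.execute("SELECT membership FROM current_state_events").fetchone())
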
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py
index 1ce21d190c..6587f31e2b 100644
--- a/synapse/storage/events_bg_updates.py
+++ b/synapse/storage/events_bg_updates.py
@@ -135,7 +135,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if not result:
yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
@@ -212,7 +212,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if not result:
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _cleanup_extremities_bg_update(self, progress, batch_size):
@@ -396,4 +396,4 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
- defer.returnValue(num_handled)
+ return num_handled
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 06379281b6..79680ee856 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -139,8 +139,11 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- Deferred : A FrozenEvent.
+ Deferred[EventBase|None]
"""
+ if not isinstance(event_id, str):
+            raise TypeError("Invalid event_id %r" % (event_id,))
+
events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
@@ -157,7 +160,7 @@ class EventsWorkerStore(SQLBaseStore):
if event is None and not allow_none:
raise NotFoundError("Could not find event %s" % (event_id,))
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def get_events(
@@ -187,7 +190,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected=allow_rejected,
)
- defer.returnValue({e.event_id: e for e in events})
+ return {e.event_id: e for e in events}
@defer.inlineCallbacks
def get_events_as_list(
@@ -217,7 +220,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
if not event_ids:
- defer.returnValue([])
+ return []
# there may be duplicates so we cast the list to a set
event_entry_map = yield self._get_events_from_cache_or_db(
@@ -313,7 +316,7 @@ class EventsWorkerStore(SQLBaseStore):
event.unsigned["prev_content"] = prev.content
event.unsigned["prev_sender"] = prev.sender
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
@@ -460,7 +463,7 @@ class EventsWorkerStore(SQLBaseStore):
without having to create a new transaction for each request for events.
"""
if not events:
- defer.returnValue({})
+ return {}
events_d = defer.Deferred()
with self._event_fetch_lock:
@@ -504,7 +507,7 @@ class EventsWorkerStore(SQLBaseStore):
)
)
- defer.returnValue({e.event.event_id: e for e in res if e})
+ return {e.event.event_id: e for e in res if e}
def _fetch_event_rows(self, txn, event_ids):
"""Fetch event rows from the database
@@ -617,7 +620,7 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
- defer.returnValue(cache_entry)
+ return cache_entry
@defer.inlineCallbacks
def _maybe_redact_event_row(self, original_ev, redactions):
@@ -710,7 +713,7 @@ class EventsWorkerStore(SQLBaseStore):
desc="have_events_in_timeline",
)
- defer.returnValue(set(r["event_id"] for r in rows))
+ return set(r["event_id"] for r in rows)
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
@@ -736,7 +739,7 @@ class EventsWorkerStore(SQLBaseStore):
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
- defer.returnValue(results)
+ return results
def get_seen_events_with_rejections(self, event_ids):
"""Given a list of event ids, check if we rejected them.
@@ -847,4 +850,4 @@ class EventsWorkerStore(SQLBaseStore):
# it.
complexity_v1 = round(state_events / 500, 2)
- defer.returnValue({"v1": complexity_v1})
+ return {"v1": complexity_v1}
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index b195dc66a0..23b48f6cea 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -15,8 +15,6 @@
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -41,7 +39,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(db_to_json(def_json))
+ return db_to_json(def_json)
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index 73e6fc6de2..15b01c6958 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -307,15 +307,13 @@ class GroupServerStore(SQLBaseStore):
desc="get_group_categories",
)
- defer.returnValue(
- {
- row["category_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
- }
- for row in rows
+ return {
+ row["category_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
}
- )
+ for row in rows
+ }
@defer.inlineCallbacks
def get_group_category(self, group_id, category_id):
@@ -328,7 +326,7 @@ class GroupServerStore(SQLBaseStore):
category["profile"] = json.loads(category["profile"])
- defer.returnValue(category)
+ return category
def upsert_group_category(self, group_id, category_id, profile, is_public):
"""Add/update room category for group
@@ -370,15 +368,13 @@ class GroupServerStore(SQLBaseStore):
desc="get_group_roles",
)
- defer.returnValue(
- {
- row["role_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
- }
- for row in rows
+ return {
+ row["role_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
}
- )
+ for row in rows
+ }
@defer.inlineCallbacks
def get_group_role(self, group_id, role_id):
@@ -391,7 +387,7 @@ class GroupServerStore(SQLBaseStore):
role["profile"] = json.loads(role["profile"])
- defer.returnValue(role)
+ return role
def upsert_group_role(self, group_id, role_id, profile, is_public):
"""Add/remove user role
@@ -960,7 +956,7 @@ class GroupServerStore(SQLBaseStore):
_register_user_group_membership_txn,
next_id,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def create_group(
@@ -1057,9 +1053,9 @@ class GroupServerStore(SQLBaseStore):
now = int(self._clock.time_msec())
if row and now < row["valid_until_ms"]:
- defer.returnValue(json.loads(row["attestation_json"]))
+ return json.loads(row["attestation_json"])
- defer.returnValue(None)
+ return None
def get_joined_groups(self, user_id):
return self._simple_select_onecol(
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 081564360f..752e9788a2 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -173,7 +173,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
)
if user_id:
count = count + 1
- defer.returnValue(count)
+ return count
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 7c4e1dc7ec..d20eacda59 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 55
+SCHEMA_VERSION = 56
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 42ec8c6bb8..1a0f2d5768 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -90,9 +90,7 @@ class PresenceStore(SQLBaseStore):
presence_states,
)
- defer.returnValue(
- (stream_orderings[-1], self._presence_id_gen.get_current_token())
- )
+ return (stream_orderings[-1], self._presence_id_gen.get_current_token())
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
@@ -180,7 +178,7 @@ class PresenceStore(SQLBaseStore):
for row in rows:
row["currently_active"] = bool(row["currently_active"])
- defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows})
+ return {row["user_id"]: UserPresenceState(**row) for row in rows}
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 0ff392bdb4..96a7e32fca 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +19,11 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage.roommember import ProfileInfo
+from . import background_updates
from ._base import SQLBaseStore
+BATCH_SIZE = 100
+
class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
@@ -34,15 +38,13 @@ class ProfileWorkerStore(SQLBaseStore):
except StoreError as e:
if e.code == 404:
# no match
- defer.returnValue(ProfileInfo(None, None))
+ return ProfileInfo(None, None)
return
else:
raise
- defer.returnValue(
- ProfileInfo(
- avatar_url=profile["avatar_url"], display_name=profile["displayname"]
- )
+ return ProfileInfo(
+ avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
def get_profile_displayname(self, user_localpart):
@@ -61,6 +63,54 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.runInteraction("get_latest_profile_replication_batch_number", f)
+
+ def get_profile_batch(self, batchnum):
+ return self._simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ def assign_profile_batch(self):
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return self.runInteraction("assign_profile_batch", f)
+
+ def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.runInteraction("get_replication_hosts", f)
+
+ def update_replication_batch_for_host(self, host, last_synced_batch):
+ return self._simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
def get_from_remote_profile_cache(self, user_id):
return self._simple_select_one(
table="remote_profile_cache",
@@ -70,29 +120,53 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
- def create_profile(self, user_localpart):
- return self._simple_insert(
- table="profiles", values={"user_id": user_localpart}, desc="create_profile"
- )
-
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self._simple_update_one(
+ def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
+ return self._simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self._simple_update_one(
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
+ return self._simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ def set_profile_active(self, user_localpart, active, hide, batchnum):
+ values = {"active": int(active), "batch": batchnum}
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile.
+ values["avatar_url"] = None
+ values["displayname"] = None
+ return self._simple_upsert(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ values=values,
+ desc="set_profile_active",
+ lock=False, # we can do this because user_id has a unique index
)
-class ProfileStore(ProfileWorkerStore):
+class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore):
+    def __init__(self, db_conn, hs):
+        super(ProfileStore, self).__init__(db_conn, hs)
+
+ self.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
@@ -168,7 +242,7 @@ class ProfileStore(ProfileWorkerStore):
)
if res:
- defer.returnValue(True)
+ return True
res = yield self._simple_select_one_onecol(
table="group_invites",
@@ -179,4 +253,4 @@ class ProfileStore(ProfileWorkerStore):
)
if res:
- defer.returnValue(True)
+ return True
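
The new profile methods above implement a batch-numbering scheme for replicating profiles to other hosts: rows are stamped with a batch number, and each host records the last batch it has synced. A hedged sketch of a push loop built on those methods; push_batch_to_host is an assumed transport hook, not part of the patch:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def replicate_profiles_once(store, push_batch_to_host):
        # Stamp any not-yet-batched profile rows with the next batch number.
        yield store.assign_profile_batch()
        latest = yield store.get_latest_profile_replication_batch_number()
        if latest is None:
            return  # no profiles at all yet
        hosts = yield store.get_replication_hosts()
        for host, last_synced in hosts.items():
            if last_synced is None:
                last_synced = -1
            while last_synced < latest:
                last_synced += 1
                rows = yield store.get_profile_batch(last_synced)
                yield push_batch_to_host(host, rows)
                yield store.update_replication_batch_for_host(host, last_synced)
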
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 98cec8c82b..a6517c4cf3 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -120,7 +120,7 @@ class PushRulesWorkerStore(
rules = _load_rules(rows, enabled_map)
- defer.returnValue(rules)
+ return rules
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
@@ -130,9 +130,7 @@ class PushRulesWorkerStore(
retcols=("user_name", "rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
)
- defer.returnValue(
- {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
- )
+ return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
@@ -160,7 +158,7 @@ class PushRulesWorkerStore(
)
def bulk_get_push_rules(self, user_ids):
if not user_ids:
- defer.returnValue({})
+ return {}
results = {user_id: [] for user_id in user_ids}
@@ -182,7 +180,7 @@ class PushRulesWorkerStore(
for user_id, rules in results.items():
results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
@@ -253,7 +251,7 @@ class PushRulesWorkerStore(
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
- defer.returnValue(result)
+ return result
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(
@@ -312,7 +310,7 @@ class PushRulesWorkerStore(
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
- defer.returnValue(rules_by_user)
+ return rules_by_user
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
@@ -322,7 +320,7 @@ class PushRulesWorkerStore(
)
def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
- defer.returnValue({})
+ return {}
results = {user_id: {} for user_id in user_ids}
@@ -336,7 +334,7 @@ class PushRulesWorkerStore(
for row in rows:
enabled = bool(row["enabled"])
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
- defer.returnValue(results)
+ return results
class PushRuleStore(PushRulesWorkerStore):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index cfe0a94330..be3d4d9ded 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -63,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
ret = yield self._simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
- defer.returnValue(ret is not None)
+ return ret is not None
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
@@ -95,7 +95,7 @@ class PusherWorkerStore(SQLBaseStore):
],
desc="get_pushers_by",
)
- defer.returnValue(self._decode_pushers_rows(ret))
+ return self._decode_pushers_rows(ret)
@defer.inlineCallbacks
def get_all_pushers(self):
@@ -106,7 +106,7 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(rows)
rows = yield self.runInteraction("get_all_pushers", get_pushers)
- defer.returnValue(rows)
+ return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
@@ -205,7 +205,7 @@ class PusherWorkerStore(SQLBaseStore):
result = {user_id: False for user_id in user_ids}
result.update({r["user_name"]: True for r in rows})
- defer.returnValue(result)
+ return result
class PusherStore(PusherWorkerStore):
@@ -343,7 +343,7 @@ class PusherStore(PusherWorkerStore):
"throttle_ms": row["throttle_ms"],
}
- defer.returnValue(params_by_room)
+ return params_by_room
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index b477da12b1..6aa6d98ebb 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
- defer.returnValue(set(r["user_id"] for r in receipts))
+ return set(r["user_id"] for r in receipts)
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@@ -92,7 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
desc="get_receipts_for_user",
)
- defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
+ return {row["room_id"]: row["event_id"] for row in rows}
@defer.inlineCallbacks
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
@@ -110,16 +110,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
return txn.fetchall()
rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
- defer.returnValue(
- {
- row[0]: {
- "event_id": row[1],
- "topological_ordering": row[2],
- "stream_ordering": row[3],
- }
- for row in rows
+ return {
+ row[0]: {
+ "event_id": row[1],
+ "topological_ordering": row[2],
+ "stream_ordering": row[3],
}
- )
+ for row in rows
+ }
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -147,7 +145,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_ids, to_key, from_key=from_key
)
- defer.returnValue([ev for res in results.values() for ev in res])
+ return [ev for res in results.values() for ev in res]
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
@@ -197,7 +195,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
- defer.returnValue([])
+ return []
content = {}
for row in rows:
@@ -205,9 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
row["user_id"]
] = json.loads(row["data"])
- defer.returnValue(
- [{"type": "m.receipt", "room_id": room_id, "content": content}]
- )
+ return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@cachedList(
cached_method_name="_get_linearized_receipts_for_room",
@@ -217,7 +213,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
- defer.returnValue({})
+ return {}
def f(txn):
if from_key:
@@ -264,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_id: [results[room_id]] if room_id in results else []
for room_id in room_ids
}
- defer.returnValue(results)
+ return results
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
@@ -468,7 +464,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
if event_ts is None:
- defer.returnValue(None)
+ return None
now = self._clock.time_msec()
logger.debug(
@@ -482,7 +478,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
max_persisted_id = self._receipts_id_gen.get_current_token()
- defer.returnValue((stream_id, max_persisted_id))
+ return (stream_id, max_persisted_id)
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
return self.runInteraction(
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 8b2c2a97ab..ea5b2be0f7 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -75,12 +75,12 @@ class RegistrationWorkerStore(SQLBaseStore):
info = yield self.get_user_by_id(user_id)
if not info:
- defer.returnValue(False)
+ return False
now = self.clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
- defer.returnValue(is_trial)
+ return is_trial
@cached()
def get_user_by_access_token(self, token):
@@ -115,7 +115,7 @@ class RegistrationWorkerStore(SQLBaseStore):
allow_none=True,
desc="get_expiration_ts_for_user",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def set_account_validity_for_user(
@@ -155,6 +155,28 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def get_expired_users(self):
+ """Get IDs of all expired users
+
+ Returns:
+ Deferred[list[str]]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ sql = """
+            SELECT user_id FROM account_validity
+ WHERE expiration_ts_ms <= ?
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+ return [row[0] for row in rows]
+
+ res = yield self.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
"""Defines a renewal token for a given user.
@@ -190,7 +212,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_from_renewal_token",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_renewal_token_for_user(self, user_id):
@@ -209,7 +231,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_renewal_token_for_user",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_users_expiring_soon(self):
@@ -237,7 +259,7 @@ class RegistrationWorkerStore(SQLBaseStore):
self.config.account_validity.renew_at,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def set_renewal_mail_status(self, user_id, email_sent):
@@ -280,7 +302,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="is_server_admin",
)
- defer.returnValue(res if res else False)
+ return res if res else False
def _query_for_auth(self, txn, token):
sql = (
@@ -311,7 +333,7 @@ class RegistrationWorkerStore(SQLBaseStore):
res = yield self.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
- defer.returnValue(res)
+ return res
def is_support_user_txn(self, txn, user_id):
res = self._simple_select_one_onecol_txn(
@@ -349,7 +371,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return 0
ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ return ret
def count_daily_user_type(self):
"""
@@ -395,7 +417,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return count
ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
@@ -425,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore):
if i not in found:
return i
- defer.returnValue(
+ return (
(
yield self.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
@@ -447,7 +469,7 @@ class RegistrationWorkerStore(SQLBaseStore):
user_id = yield self.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
- defer.returnValue(user_id)
+ return user_id
def get_user_id_by_threepid_txn(self, txn, medium, address):
"""Returns user id from threepid
@@ -487,7 +509,7 @@ class RegistrationWorkerStore(SQLBaseStore):
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
)
- defer.returnValue(ret)
+ return ret
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
@@ -677,7 +699,7 @@ class RegistrationStore(
if end:
yield self._end_background_update("users_set_deactivated_flag")
- defer.returnValue(batch_size)
+ return batch_size
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -957,7 +979,7 @@ class RegistrationStore(
desc="is_guest",
)
- defer.returnValue(res if res else False)
+ return res if res else False
def add_user_pending_deactivation(self, user_id):
"""
@@ -1024,7 +1046,7 @@ class RegistrationStore(
yield self._end_background_update("user_threepids_grandfather")
- defer.returnValue(1)
+ return 1
def get_threepid_validation_session(
self, medium, client_secret, address=None, sid=None, validated=True
@@ -1337,4 +1359,4 @@ class RegistrationStore(
)
# Convert the integer into a boolean.
- defer.returnValue(res == 1)
+ return res == 1
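
A hedged sketch of how the new get_expired_users above might be consumed by a periodic account-validity job; deactivate_account is an assumed callback, not an API from this patch:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def deactivate_expired_accounts(store, deactivate_account):
        expired_user_ids = yield store.get_expired_users()
        for user_id in expired_user_ids:
            # In practice this would live in a handler, with logging and
            # per-user error handling around each call.
            yield deactivate_account(user_id)
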
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 9954bc094f..fcb5f2f23a 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,8 +17,6 @@ import logging
import attr
-from twisted.internet import defer
-
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
@@ -363,7 +361,7 @@ class RelationsWorkerStore(SQLBaseStore):
return
edit_event = yield self.get_event(edit_id, allow_none=True)
- defer.returnValue(edit_event)
+ return edit_event
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
"""Check if a user has already annotated an event with the same key
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index fe9d79d792..1ca01a4c70 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,10 +17,13 @@ import collections
import logging
import re
+from six import integer_types
+
from canonicaljson import json
from twisted.internet import defer
+from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore
@@ -171,6 +174,24 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @defer.inlineCallbacks
+ def is_room_published(self, room_id):
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id (str)
+ Returns:
+            Deferred[bool]: Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = yield self.get_room(room_id)
+ if not room_info:
+ defer.returnValue(False)
+
+ # Check the is_public value
+ defer.returnValue(room_info.get("is_public", False))
+
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
@@ -193,17 +214,153 @@ class RoomWorkerStore(SQLBaseStore):
)
if row:
- defer.returnValue(
- RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- )
+ return RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
)
else:
- defer.returnValue(None)
+ return None
+
+ @cachedInlineCallbacks()
+ def get_retention_policy_for_room(self, room_id):
+ """Get the retention policy for a given room.
+
+        If the room has no retention policy in its current state, returns the
+        server's configured default policy instead (which has None as both the
+        'min_lifetime' and the 'max_lifetime' if no default policy has been
+        configured).
+
+ Args:
+ room_id (str): The ID of the room to get the retention policy of.
+
+ Returns:
+            Deferred[dict[str, int|None]]: A dict with "min_lifetime" and
+                "max_lifetime" keys for this room.
+ """
+        # If the room retention feature is disabled, return a policy with no
+        # minimum nor maximum, so that no events get filtered out when sending
+        # them to the client.
+ if not self.config.retention_enabled:
+ defer.returnValue({"min_lifetime": None, "max_lifetime": None})
+
+ def get_retention_policy_for_room_txn(txn):
+ txn.execute(
+ """
+ SELECT min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ WHERE room_id = ?;
+ """,
+ (room_id,),
+ )
+
+ return self.cursor_to_dict(txn)
+
+ ret = yield self.runInteraction(
+ "get_retention_policy_for_room", get_retention_policy_for_room_txn
+ )
+
+        # If we don't know this room ID, ret will be an empty list; in this case,
+        # return the default policy.
+ if not ret:
+ defer.returnValue(
+ {
+ "min_lifetime": self.config.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention_default_max_lifetime,
+ }
+ )
+
+ row = ret[0]
+
+ # If one of the room's policy's attributes isn't defined, use the matching
+ # attribute from the default policy.
+ # The default values will be None if no default policy has been defined, or if one
+ # of the attributes is missing from the default policy.
+ if row["min_lifetime"] is None:
+ row["min_lifetime"] = self.config.retention_default_min_lifetime
+
+ if row["max_lifetime"] is None:
+ row["max_lifetime"] = self.config.retention_default_max_lifetime
+
+ defer.returnValue(row)
class RoomStore(RoomWorkerStore, SearchStore):
+ def __init__(self, db_conn, hs):
+ super(RoomStore, self).__init__(db_conn, hs)
+
+ self.config = hs.config
+
+ self.register_background_update_handler(
+ "insert_room_retention", self._background_insert_retention
+ )
+
+ @defer.inlineCallbacks
+ def _background_insert_retention(self, progress, batch_size):
+ """Retrieves a list of all rooms within a range and inserts an entry for each of
+ them into the room_retention table.
+ NULLs the property's columns if missing from the retention event in the room's
+ state (or NULLs all of them if there's no retention event in the room's state),
+ so that we fall back to the server's retention policy.
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _background_insert_retention_txn(txn):
+ txn.execute(
+ """
+ SELECT state.room_id, state.event_id, events.json
+ FROM current_state_events as state
+ LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
+ WHERE state.room_id > ? AND state.type = '%s'
+ ORDER BY state.room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Retention,
+ (last_room, batch_size),
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ for row in rows:
+ if not row["json"]:
+ retention_policy = {}
+ else:
+                    ev = json.loads(row["json"])
+                    # Keep the content as a dict: the lifetime values are read
+                    # out of it with .get() below.
+                    retention_policy = ev["content"]
+
+ self._simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": row["room_id"],
+ "event_id": row["event_id"],
+ "min_lifetime": retention_policy.get("min_lifetime"),
+ "max_lifetime": retention_policy.get("max_lifetime"),
+ },
+ )
+
+ logger.info("Inserted %d rows into room_retention", len(rows))
+
+ self._background_update_progress_txn(
+ txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ )
+
+            # We're done once a batch comes back smaller than the batch size.
+            return batch_size > len(rows)
+
+ end = yield self.runInteraction(
+ "insert_room_retention", _background_insert_retention_txn
+ )
+
+ if end:
+ yield self._end_background_update("insert_room_retention")
+
+ defer.returnValue(batch_size)
+
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room.
@@ -439,6 +596,35 @@ class RoomStore(RoomWorkerStore, SearchStore):
)
txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
+ def _store_retention_policy_for_room_txn(self, txn, event):
+ if hasattr(event, "content") and (
+ "min_lifetime" in event.content or "max_lifetime" in event.content
+ ):
+ if (
+ "min_lifetime" in event.content
+ and not isinstance(event.content.get("min_lifetime"), integer_types)
+ ) or (
+ "max_lifetime" in event.content
+ and not isinstance(event.content.get("max_lifetime"), integer_types)
+ ):
+                # Ignore the event if one of the values isn't an integer.
+ return
+
+ self._simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "min_lifetime": event.content.get("min_lifetime"),
+ "max_lifetime": event.content.get("max_lifetime"),
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_retention_policy_for_room, (event.room_id,)
+ )
+
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
@@ -620,3 +806,89 @@ class RoomStore(RoomWorkerStore, SearchStore):
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
+
+ @defer.inlineCallbacks
+ def get_rooms_for_retention_period_in_range(
+ self, min_ms, max_ms, include_null=False
+ ):
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+            min_ms (int|None): Duration in milliseconds that defines the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms (int|None): Duration in milliseconds that defines the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null (bool): Whether to include rooms whose retention policy is
+                NULL in the returned set.
+
+ Returns:
+            dict[str, dict]: The rooms within this range, along with their retention
+                policy. Each key is a room ID, mapping to a dict that describes that
+                room's retention policy with the keys "min_lifetime" (int|None) and
+                "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
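+            # Only tack the NULL check onto the WHERE clause when there is one;
+            # without any range conditions the query already returns every row,
+            # including those with a NULL max_lifetime.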
+            if range_conditions:
+                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+                if include_null:
+                    sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ rooms = yield self.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
+ defer.returnValue(rooms)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 32cfd010a5..cb88e49b51 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -24,6 +24,8 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import LoggingTransaction
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import Linearizer
@@ -53,9 +55,51 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
MemberSummary = namedtuple("MemberSummary", ("members", "count"))
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
+ def __init__(self, db_conn, hs):
+ super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+
+ # Is the current_state_events.membership up to date? Or is the
+ # background update still running?
+ self._current_state_events_membership_up_to_date = False
+
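+        # Read the background update's status synchronously at startup, so that we
+        # know which form of query to use before serving any requests.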
+ txn = LoggingTransaction(
+ db_conn.cursor(),
+ name="_check_safe_current_state_events_membership_updated",
+ database_engine=self.database_engine,
+ )
+ self._check_safe_current_state_events_membership_updated_txn(txn)
+ txn.close()
+
+ def _check_safe_current_state_events_membership_updated_txn(self, txn):
+ """Checks if it is safe to assume the new current_state_events
+ membership column is up to date
+ """
+
+ pending_update = self._simple_select_one_txn(
+ txn,
+ table="background_updates",
+ keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
+ retcols=["update_name"],
+ allow_none=True,
+ )
+
+ self._current_state_events_membership_up_to_date = not pending_update
+
+        # If the update is still running, reschedule this check to run again later.
+ if pending_update:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "_check_safe_current_state_events_membership_updated",
+ self.runInteraction,
+ "_check_safe_current_state_events_membership_updated",
+ self._check_safe_current_state_events_membership_updated_txn,
+ )
+
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room
@@ -64,19 +108,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
room_id, on_invalidate=cache_context.invalidate
)
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
- defer.returnValue(hosts)
+ return hosts
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
- sql = (
- "SELECT m.user_id FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id "
- " AND m.room_id = c.room_id "
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
- )
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+ """
+ else:
+ sql = """
+ SELECT state_key FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+ """
txn.execute(sql, (room_id, Membership.JOIN))
return [to_ascii(r[0]) for r in txn]
@@ -98,15 +151,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# first get counts.
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
- sql = """
- SELECT count(*), m.membership FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
+
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT count(*), membership FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ GROUP BY membership
+ """
+ else:
+ sql = """
+ SELECT count(*), m.membership FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
txn.execute(sql, (room_id,))
res = {}
@@ -189,8 +253,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
invites = yield self.get_invited_rooms_for_user(user_id)
for invite in invites:
if invite.room_id == room_id:
- defer.returnValue(invite)
- defer.returnValue(None)
+ return invite
+ return None
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -224,7 +288,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
results = []
if membership_list:
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
- " OR ".join(["membership = ?" for _ in membership_list]),
+ " OR ".join(["m.membership = ?" for _ in membership_list]),
)
args = [user_id]
@@ -283,11 +347,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN]
)
- defer.returnValue(
- frozenset(
- GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
- for r in rooms
- )
+ return frozenset(
+ GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+ for r in rooms
)
@defer.inlineCallbacks
@@ -297,7 +359,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
)
- defer.returnValue(frozenset(r.room_id for r in rooms))
+ return frozenset(r.room_id for r in rooms)
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
@@ -314,7 +376,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
user_who_share_room.update(user_ids)
- defer.returnValue(user_who_share_room)
+ return user_who_share_room
@defer.inlineCallbacks
def get_joined_users_from_context(self, event, context):
@@ -330,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
- defer.returnValue(result)
+ return result
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
@@ -444,7 +506,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
avatar_url=to_ascii(event.content.get("avatar_url", None)),
)
- defer.returnValue(users_in_room)
+ return users_in_room
@cachedInlineCallbacks(max_entries=10000)
def is_host_joined(self, room_id, host):
@@ -453,8 +515,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
sql = """
SELECT state_key FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
- WHERE membership = 'join'
+ INNER JOIN room_memberships AS m USING (event_id)
+ WHERE m.membership = 'join'
AND type = 'm.room.member'
AND c.room_id = ?
AND state_key LIKE ?
@@ -469,14 +531,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
if not rows:
- defer.returnValue(False)
+ return False
user_id = rows[0][0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")
- defer.returnValue(True)
+ return True
@cachedInlineCallbacks()
def was_host_joined(self, room_id, host):
@@ -509,14 +571,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
if not rows:
- defer.returnValue(False)
+ return False
user_id = rows[0][0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")
- defer.returnValue(True)
+ return True
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
@@ -543,7 +605,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache = self._get_joined_hosts_cache(room_id)
joined_hosts = yield cache.get_destinations(state_entry)
- defer.returnValue(joined_hosts)
+ return joined_hosts
@cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id):
@@ -573,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return rows[0][0]
count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
+ return count == 0
@defer.inlineCallbacks
def get_rooms_user_has_been_in(self, user_id):
@@ -602,6 +664,10 @@ class RoomMemberStore(RoomMemberWorkerStore):
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
+ self.register_background_update_handler(
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+ self._background_current_state_membership,
+ )
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
@@ -779,7 +845,65 @@ class RoomMemberStore(RoomMemberWorkerStore):
if not result:
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
- defer.returnValue(result)
+ return result
+
+ @defer.inlineCallbacks
+ def _background_current_state_membership(self, progress, batch_size):
+ """Update the new membership column on current_state_events.
+
+        This works by iterating over all rooms in alphabetical order.
+ """
+
+ def _background_current_state_membership_txn(txn, last_processed_room):
+ processed = 0
+ while processed < batch_size:
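+                # Find the first room (in lexicographic order) after the last one
+                # we processed.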
+ txn.execute(
+ """
+ SELECT MIN(room_id) FROM rooms WHERE room_id > ?
+ """,
+ (last_processed_room,),
+ )
+ row = txn.fetchone()
+ if not row or not row[0]:
+ return processed, True
+
+ next_room, = row
+
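+                # For every current_state_events row in this room, copy the
+                # membership over from the matching room_memberships row; rows for
+                # non-membership events are left as NULL.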
+ sql = """
+ UPDATE current_state_events AS c
+ SET membership = (
+ SELECT membership FROM room_memberships
+ WHERE event_id = c.event_id
+ )
+ WHERE room_id = ?
+ """
+ txn.execute(sql, (next_room,))
+ processed += txn.rowcount
+
+ last_processed_room = next_room
+
+ self._background_update_progress_txn(
+ txn,
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+ {"last_processed_room": last_processed_room},
+ )
+
+ return processed, False
+
+ # If we haven't got a last processed room then just use the empty
+ # string, which will compare before all room IDs correctly.
+ last_processed_room = progress.get("last_processed_room", "")
+
+ row_count, finished = yield self.runInteraction(
+ "_background_current_state_membership_update",
+ _background_current_state_membership_txn,
+ last_processed_room,
+ )
+
+ if finished:
+ yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
+
+ return row_count
class _JoinedHostsCache(object):
@@ -807,7 +931,7 @@ class _JoinedHostsCache(object):
state_entry(synapse.state._StateCacheEntry)
"""
if state_entry.state_group == self.state_group:
- defer.returnValue(frozenset(self.hosts_to_joined_users))
+ return frozenset(self.hosts_to_joined_users)
with (yield self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
@@ -844,7 +968,7 @@ class _JoinedHostsCache(object):
else:
self.state_group = object()
self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
- defer.returnValue(frozenset(self.hosts_to_joined_users))
+ return frozenset(self.hosts_to_joined_users)
def __len__(self):
return self._len
diff --git a/synapse/storage/schema/delta/48/profiles_batch.sql b/synapse/storage/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..c8893ecbe8
--- /dev/null
+++ b/synapse/storage/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * really ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
diff --git a/synapse/storage/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..18a0f7e10c
--- /dev/null
+++ b/synapse/storage/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('profile_replication_status_host_index', '{}');
diff --git a/synapse/storage/schema/delta/55/room_retention.sql b/synapse/storage/schema/delta/55/room_retention.sql
new file mode 100644
index 0000000000..ee6cdf7a14
--- /dev/null
+++ b/synapse/storage/schema/delta/55/room_retention.sql
@@ -0,0 +1,33 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks the retention policy of a room.
+-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in
+-- the room's retention policy state event.
+-- If a room doesn't have a retention policy state event in its state, both max_lifetime
+-- and min_lifetime are NULL.
+CREATE TABLE IF NOT EXISTS room_retention(
+ room_id TEXT,
+ event_id TEXT,
+ min_lifetime BIGINT,
+ max_lifetime BIGINT,
+
+ PRIMARY KEY(room_id, event_id)
+);
+
+CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime);
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('insert_room_retention', '{}');
diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/schema/delta/56/current_state_events_membership.sql
new file mode 100644
index 0000000000..b2e08cd85d
--- /dev/null
+++ b/synapse/storage/schema/delta/56/current_state_events_membership.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- The column will be NULL for non-membership events, and will hold the value of
+-- the content.membership key for membership events. (It will also be NULL for
+-- membership events until the background update job has finished.)
+ALTER TABLE current_state_events ADD membership TEXT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('current_state_events_membership', '{}');
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/schema/full_schemas/54/full.sql.postgres
index 098434356f..01a2b0e024 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/schema/full_schemas/54/full.sql.postgres
@@ -667,10 +667,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1842,6 +1851,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/schema/full_schemas/54/full.sql.sqlite
index be9295e4c9..f1a71627f0 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -208,6 +208,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index f3b1cec933..df87ab6a6d 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -166,7 +166,7 @@ class SearchStore(BackgroundUpdateStore):
if not result:
yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _background_reindex_gin_search(self, progress, batch_size):
@@ -209,7 +209,7 @@ class SearchStore(BackgroundUpdateStore):
yield self.runWithConnection(create_index)
yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _background_reindex_search_order(self, progress, batch_size):
@@ -287,7 +287,7 @@ class SearchStore(BackgroundUpdateStore):
if not finished:
yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
- defer.returnValue(num_rows)
+ return num_rows
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
@@ -454,17 +454,15 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue(
- {
- "results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- }
- )
+ return {
+ "results": [
+ {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
@defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
@@ -599,22 +597,20 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue(
- {
- "results": [
- {
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
- }
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- }
- )
+ return {
+ "results": [
+ {
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ "pagination_token": "%s,%s"
+ % (r["origin_server_ts"], r["stream_ordering"]),
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 6bd81e84ad..fb83218f90 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -59,7 +59,7 @@ class SignatureWorkerStore(SQLBaseStore):
for e_id, h in hashes.items()
}
- defer.returnValue(list(hashes.items()))
+ return list(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0bfe1b4550..1980a87108 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -422,7 +422,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id)
- defer.returnValue(create_event.content.get("room_version", "1"))
+ return create_event.content.get("room_version", "1")
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
@@ -442,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = yield self.get_create_event_for_room(room_id)
# Return predecessor if present
- defer.returnValue(create_event.content.get("predecessor", None))
+ return create_event.content.get("predecessor", None)
@defer.inlineCallbacks
def get_create_event_for_room(self, room_id):
@@ -466,7 +466,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
- defer.returnValue(create_event)
+ return create_event
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
@@ -510,6 +510,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event ID.
"""
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ if not where_clause:
+ # We delegate to the cached version
+ return self.get_current_state_ids(room_id)
+
def _get_filtered_current_state_ids_txn(txn):
results = {}
sql = """
@@ -517,8 +523,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
WHERE room_id = ?
"""
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
if where_clause:
sql += " AND (%s)" % (where_clause,)
@@ -559,7 +563,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event:
return
- defer.returnValue(event.content.get("canonical_alias"))
+ return event.content.get("canonical_alias")
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
@@ -609,14 +613,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
- defer.returnValue({})
+ return {}
event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups)
- defer.returnValue(group_to_state)
+ return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
@@ -630,7 +634,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
group_to_state = yield self._get_state_for_groups((state_group,))
- defer.returnValue(group_to_state[state_group])
+ return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
@@ -641,7 +645,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
dict of state_group_id -> list of state events.
"""
if not event_ids:
- defer.returnValue({})
+ return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
@@ -654,16 +658,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
get_prev_content=False,
)
- defer.returnValue(
- {
- group: [
- state_event_map[v]
- for v in itervalues(event_id_map)
- if v in state_event_map
- ]
- for group, event_id_map in iteritems(group_to_ids)
- }
- )
+ return {
+ group: [
+ state_event_map[v]
+ for v in itervalues(event_id_map)
+ if v in state_event_map
+ ]
+ for group, event_id_map in iteritems(group_to_ids)
+ }
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter):
@@ -690,7 +692,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
results.update(res)
- defer.returnValue(results)
+ return results
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
@@ -825,7 +827,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for event_id, group in iteritems(event_to_groups)
}
- defer.returnValue({event: event_to_state[event] for event in event_ids})
+ return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -851,7 +853,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for event_id, group in iteritems(event_to_groups)
}
- defer.returnValue({event: event_to_state[event] for event in event_ids})
+ return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -867,7 +869,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
- defer.returnValue(state_map[event_id])
+ return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -883,7 +885,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
- defer.returnValue(state_map[event_id])
+ return state_map[event_id]
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
@@ -913,7 +915,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events",
)
- defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
+ return {row["event_id"]: row["state_group"] for row in rows}
def _get_state_for_group_using_cache(self, cache, group, state_filter):
"""Checks if group is in cache. See `_get_state_for_groups`
@@ -993,7 +995,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
incomplete_groups = incomplete_groups_m | incomplete_groups_nm
if not incomplete_groups:
- defer.returnValue(state)
+ return state
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
@@ -1020,7 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
- defer.returnValue(state)
+ return state
def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
"""Gets the state at each of a list of state groups, optionally
@@ -1494,7 +1496,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
- defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
+ return result * BATCH_SIZE_SCALE_FACTOR
@defer.inlineCallbacks
def _background_index_state(self, progress, batch_size):
@@ -1524,4 +1526,4 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
- defer.returnValue(1)
+ return 1
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
index 1cec84ee2e..e13efed417 100644
--- a/synapse/storage/stats.py
+++ b/synapse/storage/stats.py
@@ -66,7 +66,7 @@ class StatsStore(StateDeltasStore):
if not self.stats_enabled:
yield self._end_background_update("populate_stats_createtables")
- defer.returnValue(1)
+ return 1
# Get all the rooms that we want to process.
def _make_staging_area(txn):
@@ -120,7 +120,7 @@ class StatsStore(StateDeltasStore):
self.get_earliest_token_for_room_stats.invalidate_all()
yield self._end_background_update("populate_stats_createtables")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _populate_stats_cleanup(self, progress, batch_size):
@@ -129,7 +129,7 @@ class StatsStore(StateDeltasStore):
"""
if not self.stats_enabled:
yield self._end_background_update("populate_stats_cleanup")
- defer.returnValue(1)
+ return 1
position = yield self._simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
@@ -143,14 +143,14 @@ class StatsStore(StateDeltasStore):
yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
yield self._end_background_update("populate_stats_cleanup")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _populate_stats_process_rooms(self, progress, batch_size):
if not self.stats_enabled:
yield self._end_background_update("populate_stats_process_rooms")
- defer.returnValue(1)
+ return 1
# If we don't have any progress recorded, delete everything.
if not progress:
@@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore):
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
yield self._end_background_update("populate_stats_process_rooms")
- defer.returnValue(1)
+ return 1
logger.info(
"Processing the next %d rooms of %d remaining",
@@ -211,16 +211,18 @@ class StatsStore(StateDeltasStore):
avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
+ event_ids = [
+ join_rules_id,
+ history_visibility_id,
+ encryption_id,
+ name_id,
+ topic_id,
+ avatar_id,
+ canonical_alias_id,
+ ]
+
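+        # Filter out the None entries so we don't ask get_events for state events
+        # the room doesn't have (e.g. a room with no topic set).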
state_events = yield self.get_events(
- [
- join_rules_id,
- history_visibility_id,
- encryption_id,
- name_id,
- topic_id,
- avatar_id,
- canonical_alias_id,
- ]
+ [ev for ev in event_ids if ev is not None]
)
def _get_or_none(event_id, arg):
@@ -303,9 +305,9 @@ class StatsStore(StateDeltasStore):
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
- defer.returnValue(processed_event_count)
+ return processed_event_count
- defer.returnValue(processed_event_count)
+ return processed_event_count
def delete_all_stats(self):
"""
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index a0465484df..856c2ee8d8 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -300,7 +300,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
if not room_ids:
- defer.returnValue({})
+ return {}
results = {}
room_ids = list(room_ids)
@@ -323,7 +323,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
results.update(dict(zip(rm_ids, res)))
- defer.returnValue(results)
+ return results
def get_rooms_that_changed(self, room_ids, from_key):
"""Given a list of rooms and a token, return rooms where there may have
@@ -364,7 +364,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the chunk of events returned.
"""
if from_key == to_key:
- defer.returnValue(([], from_key))
+ return ([], from_key)
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -374,7 +374,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
if not has_changed:
- defer.returnValue(([], from_key))
+ return ([], from_key)
def f(txn):
sql = (
@@ -407,7 +407,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# get.
key = from_key
- defer.returnValue((ret, key))
+ return (ret, key)
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
@@ -415,14 +415,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
- defer.returnValue([])
+ return []
if from_id:
has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id)
)
if not has_changed:
- defer.returnValue([])
+ return []
def f(txn):
sql = (
@@ -447,7 +447,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(ret, rows, topo_order=False)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token):
@@ -477,7 +477,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
- defer.returnValue((events, token))
+ return (events, token)
@defer.inlineCallbacks
def get_recent_event_ids_for_room(self, room_id, limit, end_token):
@@ -496,7 +496,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
# Allow a zero limit here, and no-op.
if limit == 0:
- defer.returnValue(([], end_token))
+ return ([], end_token)
end_token = RoomStreamToken.parse(end_token)
@@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# We want to return the results in ascending order.
rows.reverse()
- defer.returnValue((rows, token))
+ return (rows, token)
def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
"""Gets details of the first event in a room at or after a stream ordering
@@ -549,12 +549,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
token = yield self.get_room_max_stream_ordering()
if room_id is None:
- defer.returnValue("s%d" % (token,))
+ return "s%d" % (token,)
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- defer.returnValue("t%d-%d" % (topo, token))
+ return "t%d-%d" % (topo, token)
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
@@ -674,14 +674,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
[e for e in results["after"]["event_ids"]], get_prev_content=True
)
- defer.returnValue(
- {
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- }
- )
+ return {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": results["before"]["token"],
+ "end": results["after"]["token"],
+ }
def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter
@@ -785,7 +783,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
events = yield self.get_events_as_list(event_ids)
- defer.returnValue((upper_bound, events))
+ return (upper_bound, events)
def get_federation_out_pos(self, typ):
return self._simple_select_one_onecol(
@@ -939,7 +937,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
- defer.returnValue((events, token))
+ return (events, token)
class StreamStore(StreamWorkerStore):
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index e88f8ea35f..20dd6bd53d 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -66,7 +66,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
room_id string, tag string and content string.
"""
if last_id == current_id:
- defer.returnValue([])
+ return []
def get_all_updated_tags_txn(txn):
sql = (
@@ -107,7 +107,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
results.extend(tags)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
@@ -135,7 +135,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
user_id, int(stream_id)
)
if not changed:
- defer.returnValue({})
+ return {}
room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
@@ -145,7 +145,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
- defer.returnValue(results)
+ return results
def get_tags_for_room(self, user_id, room_id):
"""Get all the tags for the given room
@@ -194,7 +194,7 @@ class TagsStore(TagsWorkerStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
@@ -217,7 +217,7 @@ class TagsStore(TagsWorkerStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index fd18619178..b3c3bf55bc 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -147,7 +147,7 @@ class TransactionStore(SQLBaseStore):
result = self._destination_retry_cache.get(destination, SENTINEL)
if result is not SENTINEL:
- defer.returnValue(result)
+ return result
result = yield self.runInteraction(
"get_destination_retry_timings",
@@ -158,7 +158,7 @@ class TransactionStore(SQLBaseStore):
# We don't hugely care about race conditions between getting and
# invalidating the cache, since we time out fairly quickly anyway.
self._destination_retry_cache[destination] = result
- defer.returnValue(result)
+ return result
def _get_destination_retry_timings(self, txn, destination):
result = self._simple_select_one_txn(
@@ -196,6 +196,26 @@ class TransactionStore(SQLBaseStore):
def _set_destination_retry_timings(
self, txn, destination, retry_last_ts, retry_interval
):
+
+ if self.database_engine.can_native_upsert:
+ # Upsert retry time interval if retry_interval is zero (i.e. we're
+ # resetting it) or greater than the existing retry interval.
+
+ sql = """
+ INSERT INTO destinations (destination, retry_last_ts, retry_interval)
+ VALUES (?, ?, ?)
+ ON CONFLICT (destination) DO UPDATE SET
+ retry_last_ts = EXCLUDED.retry_last_ts,
+ retry_interval = EXCLUDED.retry_interval
+ WHERE
+ EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval < EXCLUDED.retry_interval
+ """
+
+ txn.execute(sql, (destination, retry_last_ts, retry_interval))
+
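+            # The native upsert has done all the work, so skip the lock-based
+            # read-modify-write path below.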
+ return
+
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 83466e25d9..b5188d9bee 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -109,7 +109,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
yield self._end_background_update("populate_user_directory_createtables")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _populate_user_directory_cleanup(self, progress, batch_size):
@@ -131,7 +131,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
)
yield self._end_background_update("populate_user_directory_cleanup")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _populate_user_directory_process_rooms(self, progress, batch_size):
@@ -177,7 +177,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
yield self._end_background_update("populate_user_directory_process_rooms")
- defer.returnValue(1)
+ return 1
logger.info(
"Processing the next %d rooms of %d remaining"
@@ -257,9 +257,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
- defer.returnValue(processed_event_count)
+ return processed_event_count
- defer.returnValue(processed_event_count)
+ return processed_event_count
@defer.inlineCallbacks
def _populate_user_directory_process_users(self, progress, batch_size):
@@ -268,7 +268,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
if not self.hs.config.user_directory_search_all_users:
yield self._end_background_update("populate_user_directory_process_users")
- defer.returnValue(1)
+ return 1
def _get_next_batch(txn):
sql = "SELECT user_id FROM %s LIMIT %s" % (
@@ -298,7 +298,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
# No more users -- complete the transaction.
if not users_to_work_on:
yield self._end_background_update("populate_user_directory_process_users")
- defer.returnValue(1)
+ return 1
logger.info(
"Processing the next %d users of %d remaining"
@@ -322,7 +322,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
progress,
)
- defer.returnValue(len(users_to_work_on))
+ return len(users_to_work_on)
@defer.inlineCallbacks
def is_room_world_readable_or_publicly_joinable(self, room_id):
@@ -344,16 +344,16 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
- defer.returnValue(True)
+ return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
"""
@@ -499,7 +499,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
- defer.returnValue(user_ids)
+ return user_ids
def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first
@@ -609,7 +609,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
users = set(pub_rows)
users.update(rows)
- defer.returnValue(list(users))
+ return list(users)
@defer.inlineCallbacks
def get_rooms_in_common_for_users(self, user_id, other_user_id):
@@ -618,15 +618,15 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
sql = """
SELECT room_id FROM (
SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
+ INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
- AND membership = 'join'
+ AND m.membership = 'join'
AND state_key = ?
) AS f1 INNER JOIN (
SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
+ INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
- AND membership = 'join'
+ AND m.membership = 'join'
AND state_key = ?
) f2 USING (room_id)
"""
@@ -635,7 +635,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
- defer.returnValue([room_id for room_id, in rows])
+ return [room_id for room_id, in rows]
def delete_all_from_user_dir(self):
"""Delete the entire user directory
@@ -782,7 +782,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
limited = len(results) > limit
- defer.returnValue({"limited": limited, "results": results})
+ return {"limited": limited, "results": results}
def _parse_query_sqlite(search_term):
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
index 1815fdc0dd..05cabc2282 100644
--- a/synapse/storage/user_erasure_store.py
+++ b/synapse/storage/user_erasure_store.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-from twisted.internet import defer
+import operator
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -67,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
res = dict((u, u in erased_users) for u in user_ids)
- defer.returnValue(res)
+ return res
class UserErasureStore(UserErasureWorkerStore):
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 488c49747a..b91fb2db7b 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -56,7 +56,7 @@ class EventSources(object):
device_list_key=device_list_key,
groups_key=groups_key,
)
- defer.returnValue(token)
+ return token
@defer.inlineCallbacks
def get_current_token_for_pagination(self):
@@ -80,4 +80,4 @@ class EventSources(object):
device_list_key=0,
groups_key=0,
)
- defer.returnValue(token)
+ return token
diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
new file mode 100644
index 0000000000..253bba664b
--- /dev/null
+++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,586 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.config._base import ConfigError
+from synapse.types import get_domain_from_id
+
+ACCESS_RULES_TYPE = "im.vector.room.access_rules"
+ACCESS_RULE_RESTRICTED = "restricted"
+ACCESS_RULE_UNRESTRICTED = "unrestricted"
+ACCESS_RULE_DIRECT = "direct"
+
+VALID_ACCESS_RULES = (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+)
+
+# Rules to which we need to apply the power levels restrictions.
+#
+# These are all of the rules that neither:
+# * forbid users from joining based on a server blacklist (which means that there
+# is no need to apply power level restrictions), nor
+# * target direct chats (since we allow both users to be room admins in this case).
+#
+# The power-level restrictions, when they are applied, prevent the following:
+# * the default power level for users (users_default) being set to anything other than 0.
+# * a non-default power level being assigned to any user which would be forbidden from
+# joining a restricted room.
+RULES_WITH_RESTRICTED_POWER_LEVELS = (ACCESS_RULE_UNRESTRICTED,)
+
+
+class RoomAccessRules(object):
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+    Don't forget to consider whether you can invite users from your own domain.
+ """
+
+ def __init__(self, config, http_client):
+ self.http_client = http_client
+
+ self.id_server = config["id_server"]
+
+ self.domains_forbidden_when_restricted = config.get(
+ "domains_forbidden_when_restricted", []
+ )
+
+ @staticmethod
+ def parse_config(config):
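+        # "id_server" is required; fail fast at startup if it's missing.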
+ if "id_server" in config:
+ return config
+ else:
+            raise ConfigError("No IS for event rules RoomAccessRules")
+
+ def on_create_room(self, requester, config, is_requester_admin):
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room
+
+        Checks if an im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+ """
+ is_direct = config.get("is_direct")
+ preset = config.get("preset")
+ access_rule = None
+ join_rule = None
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and deny the request if not.
+ for event in config.get("initial_state", []):
+ if event["type"] == ACCESS_RULES_TYPE:
+ access_rule = event["content"].get("rule")
+
+ # Make sure the event has a valid content.
+ if access_rule is None:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule name is valid.
+ if access_rule not in VALID_ACCESS_RULES:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule is "direct" if the room is a direct chat.
+ if (is_direct and access_rule != ACCESS_RULE_DIRECT) or (
+ access_rule == ACCESS_RULE_DIRECT and not is_direct
+ ):
+ raise SynapseError(400, "Invalid access rule")
+
+ if event["type"] == EventTypes.JoinRules:
+ join_rule = event["content"].get("join_rule")
+
+ if access_rule is None:
+ # If there's no access rules event in the initial state, create one with the
+ # default setting.
+ if is_direct:
+ default_rule = ACCESS_RULE_DIRECT
+ else:
+                # If the default value for non-direct chats ever changes, we should
+                # add another case here for rooms created with either a "public"
+                # join_rule or the "public_chat" preset, to make sure those keep
+                # defaulting to "restricted".
+ default_rule = ACCESS_RULE_RESTRICTED
+
+ if not config.get("initial_state"):
+ config["initial_state"] = []
+
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
+ }
+ )
+
+ access_rule = default_rule
+
+ # Check that the preset or the join rule in use is compatible with the access
+ # rule, whether it's a user-defined one or the default one (i.e. if it involves
+ # a "public" join rule, the access rule must be "restricted").
+ if (
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule != ACCESS_RULE_RESTRICTED:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed(
+ config.get("power_level_content_override", {}), access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content override")
+
+ # Second loop, for events that can only be processed once we know the access rule.
+ for event in config.get("initial_state", []):
+ if event["type"] == EventTypes.PowerLevels:
+ allowed = self._is_power_level_content_allowed(
+ event["content"], access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content")
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(self, medium, address, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+ """
+ rule = self._get_rule_from_state(state_events)
+
+ if medium != "email":
+ defer.returnValue(False)
+
+ if rule != ACCESS_RULE_RESTRICTED:
+ # Only "restricted" requires filtering 3PID invites. We don't need to do
+ # anything for "direct" here, because only "restricted" requires filtering
+ # based on the HS the address is mapped to.
+ defer.returnValue(True)
+
+ parsed_address = email.utils.parseaddr(address)[1]
+ if parsed_address != address:
+ # Avoid reproducing the security issue described here:
+ # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2
+ # It's probably not worth it but let's just be overly safe here.
+ defer.returnValue(False)
+
+ # Get the HS this address belongs to from the identity server.
+ res = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
+ {"medium": medium, "address": address},
+ )
+
+ # Look for a domain that's not forbidden from being invited.
+ if not res.get("hs"):
+ defer.returnValue(False)
+ if res.get("hs") in self.domains_forbidden_when_restricted:
+ defer.returnValue(False)
+
+ defer.returnValue(True)
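A sketch of the identity-server exchange this method performs; the response shape matches what the code above consumes, and the values are illustrative:

```python
# GET https://<id_server>/_matrix/identity/api/v1/info?medium=email&address=...
response = {"hs": "example.com"}  # illustrative /info response

def invite_allowed(res, forbidden_domains):
    # Mirrors the final checks above: the address must map to a HS, and that
    # HS must not be in the forbidden list.
    hs = res.get("hs")
    return bool(hs) and hs not in forbidden_domains

assert invite_allowed(response, ["forbidden.example.com"])
assert not invite_allowed({}, [])
```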
+
+ def check_event_allowed(self, event, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+ """
+ if event.type == ACCESS_RULES_TYPE:
+ return self._on_rules_change(event, state_events)
+
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ if event.type == EventTypes.PowerLevels:
+ return self._is_power_level_content_allowed(event.content, rule)
+
+ if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
+ return self._on_membership_or_invite(event, rule, state_events)
+
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
+ if event.type == EventTypes.RoomAvatar:
+ return self._on_room_avatar_change(event, rule)
+
+ if event.type == EventTypes.Name:
+ return self._on_room_name_change(event, rule)
+
+ if event.type == EventTypes.Topic:
+ return self._on_room_topic_change(event, rule)
+
+ return True
+
+ def _on_rules_change(self, event, state_events):
+ """Implement the checks and behaviour specified on allowing or forbidding a new
+ im.vector.room.access_rules event.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ new_rule = event.content.get("rule")
+
+ # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES:
+ return False
+
+ # We must not allow rooms with the "public" join rule to be given any other access
+ # rule than "restricted".
+ join_rule = self._get_join_rule_from_state(state_events)
+ if join_rule == JoinRules.PUBLIC and new_rule != ACCESS_RULE_RESTRICTED:
+ return False
+
+ # Make sure we don't apply "direct" if the room has more than two members.
+ if new_rule == ACCESS_RULE_DIRECT:
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ if len(existing_members) > 2 or len(threepid_tokens) > 1:
+ return False
+
+ prev_rules_event = state_events.get((ACCESS_RULES_TYPE, ""))
+
+ # Now that we know the new rule doesn't break the "direct" case, we can allow any
+ # new rule in rooms that had none before.
+ if prev_rules_event is None:
+ return True
+
+ prev_rule = prev_rules_event.content.get("rule")
+
+ # Currently, we can only go from "restricted" to "unrestricted".
+ if prev_rule == ACCESS_RULE_RESTRICTED and new_rule == ACCESS_RULE_UNRESTRICTED:
+ return True
+
+ return False
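The transition rules implemented above boil down to a single permitted change for rooms that already have a rule; a stand-alone sketch (ignoring the join-rule and member-count checks handled above):

```python
# The only rule change permitted in a room that already has a rule.
ALLOWED_TRANSITIONS = {("restricted", "unrestricted")}

def rule_change_allowed(prev_rule, new_rule):
    return (prev_rule, new_rule) in ALLOWED_TRANSITIONS

assert rule_change_allowed("restricted", "unrestricted")
assert not rule_change_allowed("unrestricted", "restricted")
assert not rule_change_allowed("restricted", "direct")
```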
+
+ def _on_membership_or_invite(self, event, rule, state_events):
+ """Applies the correct rule for incoming m.room.member and
+ m.room.third_party_invite events.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if rule == ACCESS_RULE_RESTRICTED:
+ ret = self._on_membership_or_invite_restricted(event)
+ elif rule == ACCESS_RULE_UNRESTRICTED:
+ ret = self._on_membership_or_invite_unrestricted()
+ elif rule == ACCESS_RULE_DIRECT:
+ ret = self._on_membership_or_invite_direct(event, state_events)
+ else:
+ # We currently apply the default (restricted) if we don't know the rule; we
+ # might want to change that in the future.
+ ret = self._on_membership_or_invite_restricted(event)
+
+ return ret
+
+ def _on_membership_or_invite_restricted(self, event):
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ "restricted" currently means that users can only invite users if their server is
+ included in a limited list of domains.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # We're not applying the rules on m.room.third_party_invite events here because
+ # the filtering on threepids is done in check_threepid_can_be_invited, which is
+ # called before check_event_allowed.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return True
+
+ # We only need to process "join" and "invite" memberships, in order to be backward
+ # compatible, e.g. if a user from a blacklisted server joined a restricted room
+ # before the rules started being enforced on the server, that user must be able to
+ # leave it.
+ if event.membership not in [Membership.JOIN, Membership.INVITE]:
+ return True
+
+ invitee_domain = get_domain_from_id(event.state_key)
+ return invitee_domain not in self.domains_forbidden_when_restricted
+
+ def _on_membership_or_invite_unrestricted(self):
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ "unrestricted" currently means that every event is allowed.
+
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return True
+
+ def _on_membership_or_invite_direct(self, event, state_events):
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ "direct" currently means that no member is allowed apart from the two initial
+ members the room was created for (i.e. the room's creator and their first
+ invitee).
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Get the room memberships and 3PID invite tokens from the room's state.
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ # There should never be more than one 3PID invite in the room state: if the second
+ # original user left and we re-invite them using their email address, then, since
+ # we know they have a Matrix account bound to that address (they were able to join
+ # the first time), Synapse will successfully look it up before attempting to store
+ # an invite on the IS.
+ if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite:
+ # If we already have a 3PID invite in flight, don't accept another one, unless
+ # the new one has the same invite token as its state key. This is because 3PID
+ # invite revocations must be allowed, and a revocation is basically a new 3PID
+ # invite event with an empty content and the same token as the invite it
+ # revokes.
+ return event.state_key in threepid_tokens
+
+ if len(existing_members) == 2:
+ # If the user is one of the two initial users of the room, Synapse will have
+ # looked them up successfully and thus sent an m.room.member event here
+ # instead of an m.room.third_party_invite.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return False
+
+ # We can only have m.room.member events here. The rule in this case is to only
+ # allow the event if its target is one of the initial two members in the room,
+ # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members
+
+ # We're alone in the room (and always have been) and there's one 3PID invite in
+ # flight.
+ if len(existing_members) == 1 and len(threepid_tokens) == 1:
+ # We can only have m.room.member events here. In this case, we can only allow
+ # the event if it's either an m.room.member event from the joined user (we can
+ # assume that the only m.room.member event is a join, otherwise we wouldn't be
+ # able to send an event to the room) or an invite event whose target is the
+ # invited user.
+ target = event.state_key
+ is_from_threepid_invite = self._is_invite_from_threepid(
+ event, threepid_tokens[0]
+ )
+ if is_from_threepid_invite or target == existing_members[0]:
+ return True
+
+ return False
+
+ return True
+
+ def _is_power_level_content_allowed(self, content, access_rule):
+ """Check if a given power levels event is permitted under the given access rule.
+
+ It shouldn't be allowed if it either changes the default PL to a non-0 value or
+ gives a non-0 PL to a user that would have been forbidden from joining the room
+ under a more restrictive access rule.
+
+ Args:
+ content (dict): The content of the m.room.power_levels event to check.
+ access_rule (str): The access rule in place in this room.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Check if we need to apply the restrictions with the current rule.
+ if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
+ return True
+
+ # If users_default is explicitly set to a non-0 value, deny the event.
+ users_default = content.get("users_default", 0)
+ if users_default:
+ return False
+
+ users = content.get("users", {})
+ for user_id, power_level in users.items():
+ server_name = get_domain_from_id(user_id)
+ # Check the domain against the blacklist. If found, and the PL isn't 0, deny
+ # the event.
+ if (
+ server_name in self.domains_forbidden_when_restricted
+ and power_level != 0
+ ):
+ return False
+
+ return True
+
+ def _on_join_rule_change(self, event, rule):
+ """Check whether a join rule change is allowed. A join rule change is always
+ allowed unless the new join rule is "public" and the current access rule isn't
+ "restricted".
+ The rationale is that external users (those whose server would be denied access
+ to rooms enforcing the "restricted" access rule) should always rely on non-
+ external users for access to rooms, therefore they shouldn't be able to access
+ rooms that don't require an invite to be joined.
+
+ Note that we currently rely on the default access rule being "restricted": during
+ room creation, the m.room.join_rules event will be sent *before* the
+ im.vector.room.access_rules one, so the access rule that will be considered here
+ in this case will be the default "restricted" one. This is fine since the
+ "restricted" access rule allows any value for the join rule, but we should keep
+ that in mind if we need to change the default access rule in the future.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
+ return rule == ACCESS_RULE_RESTRICTED
+
+ return True
+
+ def _on_room_avatar_change(self, event, rule):
+ """Check whether a change of room avatar is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_name_change(self, event, rule):
+ """Check whether a change of room name is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_topic_change(self, event, rule):
+ """Check whether a change of room topic is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ @staticmethod
+ def _get_rule_from_state(state_events):
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the rule (either "direct", "restricted" or "unrestricted")
+ """
+ access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
+ if access_rules is None:
+ rule = ACCESS_RULE_RESTRICTED
+ else:
+ rule = access_rules.content.get("rule")
+ return rule
+
+ @staticmethod
+ def _get_join_rule_from_state(state_events):
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str|None, the name of the join rule ("public" or "invite"), or None if the
+ room has no join rules event
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return None
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
+ def _get_members_and_tokens_from_state(state_events):
+ """Retrieves from a list of state events the list of users that have a
+ m.room.member event in the room, and the tokens of 3PID invites in the room.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ existing_members (list[str]): List of targets of the m.room.member events in
+ the state.
+ threepid_invite_tokens (list[str]): List of tokens of the 3PID invites in the
+ state.
+ """
+ existing_members = []
+ threepid_invite_tokens = []
+ for key, state_event in state_events.items():
+ if key[0] == EventTypes.Member and state_event.content:
+ existing_members.append(state_event.state_key)
+ if key[0] == EventTypes.ThirdPartyInvite and state_event.content:
+ # Don't include revoked invites.
+ threepid_invite_tokens.append(state_event.state_key)
+
+ return existing_members, threepid_invite_tokens
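For reference, this is the shape of the state map the helper consumes, using a stand-in event type with the two attributes it reads:

```python
from collections import namedtuple

# Stand-in for EventBase with the two attributes the helper reads.
FakeEvent = namedtuple("FakeEvent", ["state_key", "content"])

state = {
    ("m.room.member", "@alice:example.com"): FakeEvent(
        "@alice:example.com", {"membership": "join"}
    ),
    ("m.room.third_party_invite", "sometoken"): FakeEvent(
        "sometoken", {"display_name": "bob"}
    ),
    # A revoked 3PID invite has empty content and is skipped by the helper.
    ("m.room.third_party_invite", "revokedtoken"): FakeEvent("revokedtoken", {}),
}
# The helper would return (["@alice:example.com"], ["sometoken"]) for this state.
```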
+
+ @staticmethod
+ def _is_invite_from_threepid(invite, threepid_invite_token):
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite (EventBase): The m.room.member event with "invite" membership.
+ threepid_invite_token (str): The state key from the 3PID invite.
+ """
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
+
+ return token == threepid_invite_token
diff --git a/synapse/types.py b/synapse/types.py
index 51eadb6ad4..94c01b0a18 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -16,6 +16,8 @@ import re
import string
from collections import namedtuple
+from six.moves import filter
+
import attr
from synapse.api.errors import SynapseError
@@ -235,6 +237,19 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+def strip_invalid_mxid_characters(localpart):
+ """Removes any invalid characters from an mxid
+
+ Args:
+ localpart (basestring): the localpart to be stripped
+
+ Returns:
+ localpart (basestring): the localpart having been stripped
+ """
+ filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
+ return "".join(filtered)
+
+
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index f506b2a695..7856353002 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -49,7 +49,7 @@ class Clock(object):
with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
- defer.returnValue(res)
+ return res
def time(self):
"""Returns the current system time in seconds since epoch."""
@@ -59,7 +59,7 @@ class Clock(object):
"""Returns the current system time in miliseconds since epoch."""
return int(self.time() * 1000)
- def looping_call(self, f, msec):
+ def looping_call(self, f, msec, *args, **kwargs):
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
@@ -70,8 +70,10 @@ class Clock(object):
Args:
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
+ *args: Positional arguments to pass to function.
+ **kwargs: Keyword arguments to pass to function.
"""
- call = task.LoopingCall(f)
+ call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=False)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
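With the change above, extra arguments are forwarded to the looped function. A minimal stand-in demonstrating the calling convention (the real implementation delegates to `twisted.internet.task.LoopingCall(f, *args, **kwargs)`):

```python
def looping_call(f, msec, *args, **kwargs):
    # Stand-in: Synapse schedules f repeatedly; here we invoke it once to
    # show the argument forwarding.
    return f(*args, **kwargs)

assert looping_call(lambda a, b=0: a + b, 1000, 1, b=2) == 3
```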
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 58a6b8764f..f1c46836b1 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -366,7 +366,7 @@ class ReadWriteLock(object):
new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
@defer.inlineCallbacks
def write(self, key):
@@ -396,7 +396,7 @@ class ReadWriteLock(object):
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
def _cancelled_to_timed_out_error(value, timeout):
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8271229015..b50e3503f0 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -51,7 +52,19 @@ response_cache_evicted = Gauge(
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
-def register_cache(cache_type, cache_name, cache):
+def register_cache(cache_type, cache_name, cache, collect_callback=None):
+ """Register a cache object for metric collection.
+
+ Args:
+ cache_type (str):
+ cache_name (str): name of the cache
+ cache (object): cache itself
+ collect_callback (callable|None): if not None, a function which is called during
+ metric collection to update additional metrics.
+
+ Returns:
+ CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+ """
# Check if the metric is already registered. Unregister it, if so.
# This usually happens during tests, as at runtime these caches are
@@ -90,6 +103,8 @@ def register_cache(cache_type, cache_name, cache):
cache_hits.labels(cache_name).set(self.hits)
cache_evicted.labels(cache_name).set(self.evicted_size)
cache_total.labels(cache_name).set(self.hits + self.misses)
+ if collect_callback:
+ collect_callback()
except Exception as e:
logger.warn("Error calculating metrics for %s: %s", cache_name, e)
raise
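A hedged sketch of how the new `collect_callback` hook can be used: push derived values into an extra gauge at scrape time (the `Cache` class in descriptors.py below does exactly this for its pending-lookup count):

```python
# Minimal stand-in gauge; in Synapse this would be a prometheus_client Gauge.
class FakeGauge:
    def __init__(self):
        self.value = None

    def set(self, value):
        self.value = value

pending_gauge = FakeGauge()
pending_deferreds = {}

def collect_pending():
    # Called by the metric collector after the standard hit/miss gauges are set.
    pending_gauge.set(len(pending_deferreds))

# register_cache("cache", "my_cache", cache, collect_callback=collect_pending)
collect_pending()
assert pending_gauge.value == 0
```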
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 675db2f448..43f66ec4be 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,8 +19,9 @@ import logging
import threading
from collections import namedtuple
-import six
-from six import itervalues, string_types
+from six import itervalues
+
+from prometheus_client import Gauge
from twisted.internet import defer
@@ -30,13 +31,18 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
from . import register_cache
logger = logging.getLogger(__name__)
+cache_pending_metric = Gauge(
+ "synapse_util_caches_cache_pending",
+ "Number of lookups currently pending for this cache",
+ ["name"],
+)
+
_CacheSentinel = object()
@@ -82,11 +88,19 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
- self.metrics = register_cache("cache", name, self.cache)
+ self.metrics = register_cache(
+ "cache",
+ name,
+ self.cache,
+ collect_callback=self._metrics_collection_callback,
+ )
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
+ def _metrics_collection_callback(self):
+ cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
+
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -108,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either a Deferred or the raw result
+ Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -132,9 +146,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
+ if not isinstance(value, defer.Deferred):
+ raise TypeError("not a Deferred")
+
callbacks = [callback] if callback else []
self.check_thread()
- entry = CacheEntry(deferred=value, callbacks=callbacks)
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = defer.maybeDeferred(observable.observe)
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@@ -142,20 +161,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
- def shuffle(result):
+ def compare_and_pop():
+ """Check if our entry is still the one in _pending_deferred_cache, and
+ if so, pop it.
+
+ Returns true if the entries matched.
+ """
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
+ return True
+
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ return False
+
+ def cb(result):
+ if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
@@ -163,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
- return result
- entry.deferred.addCallback(shuffle)
+ def eb(_fail):
+ compare_and_pop()
+ entry.invalidate()
+
+ # once the deferred completes, we can move the entry from the
+ # _pending_deferred_cache to the real cache.
+ #
+ observer.addCallbacks(cb, eb)
+ return observable
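The reason `set` now wraps the value in an `ObservableDeferred` is that a single pending result can then be handed out to multiple waiters without them consuming each other's callbacks. A sketch, assuming a Synapse checkout on the import path:

```python
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred

d = defer.Deferred()
observable = ObservableDeferred(d, consumeErrors=True)

# Each call to observe() returns an independent Deferred for one waiter.
first = observable.observe()
second = observable.observe()

d.callback("result")
# Both observers now fire with "result", independently of one another.
```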
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
@@ -289,7 +326,7 @@ class CacheDescriptor(_CacheDescriptorBase):
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
- defer.returnValue(r1 + r2)
+ return r1 + r2
Args:
num_args (int): number of positional arguments (excluding ``self`` and
@@ -398,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- # If our cache_key is a string on py2, try to convert to ascii
- # to save a bit of space in large caches. Py3 does this
- # internally automatically.
- if six.PY2 and isinstance(cache_key, string_types):
- cache_key = to_ascii(cache_key)
-
- result_d = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, result_d, callback=invalidate_callback)
+ result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
- if isinstance(observer, defer.Deferred):
- return make_deferred_yieldable(observer)
- else:
- return observer
+ return make_deferred_yieldable(observer)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -527,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
- # we need an observable deferred for each entry in the list,
+ # we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
@@ -535,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
- observable = ObservableDeferred(deferred)
- cache.set(key, observable, callback=invalidate_callback)
+ cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index d6908e169d..82d3eefe0e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -121,7 +121,7 @@ class ResponseCache(object):
@defer.inlineCallbacks
def handle_request(request):
# etc
- defer.returnValue(result)
+ return result
result = yield response_cache.wrap(
key,
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c30b6de19c..0910930c21 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -67,7 +67,7 @@ def measure_func(name):
def measured_func(self, *args, **kwargs):
with Measure(self.clock, name):
r = yield func(self, *args, **kwargs)
- defer.returnValue(r)
+ return r
return measured_func
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index d8d0ceae51..0862b5ca5a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -95,15 +95,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# maximum backoff even though it might only have been down briefly
backoff_on_failure = not ignore_backoff
- defer.returnValue(
- RetryDestinationLimiter(
- destination,
- clock,
- store,
- retry_interval,
- backoff_on_failure=backoff_on_failure,
- **kwargs
- )
+ return RetryDestinationLimiter(
+ destination,
+ clock,
+ store,
+ retry_interval,
+ backoff_on_failure=backoff_on_failure,
+ **kwargs
)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 982c6d81ca..6a2464cab3 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +15,15 @@
# limitations under the License.
import random
+import re
import string
import six
from six import PY2, PY3
from six.moves import range
+from synapse.api.errors import Codes, SynapseError
+
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# random_string and random_string_with_symbols are used for a range of things,
@@ -27,6 +31,8 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# we get cryptographically-secure randoms.
rand = random.SystemRandom()
+client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$")
+
def random_string(length):
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
@@ -109,3 +115,11 @@ def exception_to_unicode(e):
return msg.decode("utf-8", errors="replace")
else:
return msg
+
+
+def assert_valid_client_secret(client_secret):
+ """Validate that a given string matches the client_secret regex defined by the spec"""
+ if client_secret_regex.match(client_secret) is None:
+ raise SynapseError(
+ 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
+ )
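Usage sketch for the new helper:

```python
from synapse.util.stringutils import assert_valid_client_secret

assert_valid_client_secret("this_is-a.valid=secret")  # passes, returns None

# Anything outside [0-9a-zA-Z.=_-] raises a 400 SynapseError with
# errcode M_INVALID_PARAM:
#   assert_valid_client_secret("not valid!")
```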
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ Deferred[bool]: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,8 +73,8 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
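Since `check_3pid_allowed` is now an `inlineCallbacks` function, every call site must yield the result. A hypothetical caller, for illustration (the function name and error message are assumptions, not code from this patch):

```python
from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.util.threepids import check_3pid_allowed

@defer.inlineCallbacks
def _validate_email_for_registration(hs, address):
    # Hypothetical call site: the helper now returns a Deferred.
    allowed = yield check_3pid_allowed(hs, "email", address)
    if not allowed:
        raise SynapseError(403, "Third party identifiers not allowed on this server")
```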
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index a4d9a462f7..fa404b9d75 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -22,6 +22,23 @@ logger = logging.getLogger(__name__)
def get_version_string(module):
+ """Given a module calculate a git-aware version string for it.
+
+ If called on a module not in a git checkout will return `__verison__`.
+
+ Args:
+ module (module)
+
+ Returns:
+ str
+ """
+
+ cached_version = getattr(module, "_synapse_version_string_cache", None)
+ if cached_version:
+ return cached_version
+
+ version_string = module.__version__
+
try:
null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
@@ -80,8 +97,10 @@ def get_version_string(module):
s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
- return "%s (%s)" % (module.__version__, git_version)
+ version_string = "%s (%s)" % (module.__version__, git_version)
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return module.__version__
+ module._synapse_version_string_cache = version_string
+
+ return version_string
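With the cache added above, repeated calls are cheap and return the identical string object:

```python
import synapse
from synapse.util.versionstring import get_version_string

v1 = get_version_string(synapse)
v2 = get_version_string(synapse)
assert v1 is v2  # second call is served from _synapse_version_string_cache
```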
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 2a11c83596..a19011b793 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -43,7 +43,12 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
def filter_events_for_client(
- store, user_id, events, is_peeking=False, always_include_ids=frozenset()
+ store,
+ user_id,
+ events,
+ is_peeking=False,
+ always_include_ids=frozenset(),
+ apply_retention_policies=True,
):
"""
Check which events a user is allowed to see
@@ -59,6 +64,10 @@ def filter_events_for_client(
events
always_include_ids (set(event_id)): set of event ids to specifically
include (unless sender is ignored)
+ apply_retention_policies (bool): Whether to filter out events that are older
+ than allowed by the room's retention policy. Set this to False when the
+ function is called to e.g. check whether a user should be allowed to see
+ the state at a given event, rather than to decide whether to send an event
+ to a user's client(s).
Returns:
Deferred[list[synapse.events.EventBase]]
@@ -86,6 +95,15 @@ def filter_events_for_client(
erased_senders = yield store.are_users_erased((e.sender for e in events))
+ if apply_retention_policies:
+ room_ids = set(e.room_id for e in events)
+ retention_policies = {}
+
+ for room_id in room_ids:
+ retention_policies[room_id] = (
+ yield store.get_retention_policy_for_room(room_id)
+ )
+
def allowed(event):
"""
Args:
@@ -103,6 +121,18 @@ def filter_events_for_client(
if not event.is_state() and event.sender in ignore_list:
return None
+ # Don't try to apply the room's retention policy if the event is a state event, as
+ # MSC1763 states that retention is only considered for non-state events.
+ if apply_retention_policies and not event.is_state():
+ retention_policy = retention_policies[event.room_id]
+ max_lifetime = retention_policy.get("max_lifetime")
+
+ if max_lifetime is not None:
+ oldest_allowed_ts = store.clock.time_msec() - max_lifetime
+
+ if event.origin_server_ts < oldest_allowed_ts:
+ return None
+
if event.event_id in always_include_ids:
return event
@@ -208,7 +238,7 @@ def filter_events_for_client(
filtered_events = filter(operator.truth, filtered_events)
# we turn it into a list before returning it.
- defer.returnValue(list(filtered_events))
+ return list(filtered_events)
@defer.inlineCallbacks
@@ -317,11 +347,11 @@ def filter_events_for_server(
elif redact:
to_return.append(prune_event(e))
- defer.returnValue(to_return)
+ return to_return
# If there are no erased users then we can just return the given list
# of events without having to copy it.
- defer.returnValue(events)
+ return events
# Ok, so we're dealing with events that have non-trivial visibility
# rules, so we need to also get the memberships of the room.
@@ -384,4 +414,4 @@ def filter_events_for_server(
elif redact:
to_return.append(prune_event(e))
- defer.returnValue(to_return)
+ return to_return
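The retention cutoff added above reduces to simple timestamp arithmetic; a stand-alone sketch:

```python
def event_expired(origin_server_ts, max_lifetime_ms, now_ms):
    # An event is hidden once it is older than the room's max_lifetime.
    return origin_server_ts < now_ms - max_lifetime_ms

week_ms = 7 * 24 * 3600 * 1000
assert event_expired(0, week_ms, week_ms + 1)      # just past the policy window
assert not event_expired(0, week_ms, week_ms - 1)  # still within it
```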
diff --git a/sytest-blacklist b/sytest-blacklist
index 11785fd43f..7d6b1d0a2f 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -1,6 +1,6 @@
# This file serves as a blacklist for SyTest tests that we expect will fail in
# Synapse.
-#
+#
# Each line of this file is scanned by sytest during a run and if the line
# exactly matches the name of a test, it will be marked as "expected fail",
# meaning the test will still run, but failure will not mark the entire test
@@ -29,3 +29,24 @@ Enabling an unknown default rule fails with 404
# Blacklisted due to https://github.com/matrix-org/synapse/issues/1663
New federated private chats get full presence information (SYN-115)
+
+# flaky test
+If remote user leaves room we no longer receive device updates
+
+# flaky test
+Can re-join room if re-invited
+
+# flaky test
+Forgotten room messages cannot be paginated
+
+# flaky test
+Local device key changes get to remote servers
+
+# flaky test
+Old leaves are present in gapped incremental syncs
+
+# flaky test on workers
+Old members are included in gappy incr LL sync if they start speaking
+
+# flaky test on workers
+Presence changes to UNAVAILABLE are reported to remote room members
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 795703967d..c4f0bbd3dd 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -86,35 +86,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
getattr(LoggingContext.current_context(), "request", None), expected
)
- def test_wait_for_previous_lookups(self):
- kr = keyring.Keyring(self.hs)
-
- lookup_1_deferred = defer.Deferred()
- lookup_2_deferred = defer.Deferred()
-
- # we run the lookup in a logcontext so that the patched inlineCallbacks can check
- # it is doing the right thing with logcontexts.
- wait_1_deferred = run_in_context(
- kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
- )
-
- # there were no previous lookups, so the deferred should be ready
- self.successResultOf(wait_1_deferred)
-
- # set off another wait. It should block because the first lookup
- # hasn't yet completed.
- wait_2_deferred = run_in_context(
- kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
- )
-
- self.assertFalse(wait_2_deferred.called)
-
- # let the first lookup complete (in the sentinel context)
- lookup_1_deferred.callback(None)
-
- # now the second wait should complete.
- self.successResultOf(wait_2_deferred)
-
def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)
@@ -136,7 +107,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertEquals(LoggingContext.current_context().request, "11")
with PreserveLoggingContext():
yield persp_deferred
- defer.returnValue(persp_resp)
+ return persp_resp
self.http_client.post_json.side_effect = get_perspectives
@@ -583,7 +554,7 @@ def run_in_context(f, *args, **kwargs):
# logs.
ctx.request = "testctx"
rv = yield f(*args, **kwargs)
- defer.returnValue(rv)
+ return rv
def _verify_json_for_server(kr, *args):
@@ -594,6 +565,6 @@ def _verify_json_for_server(kr, *args):
@defer.inlineCallbacks
def v():
rv1 = yield kr.verify_json_for_server(*args)
- defer.returnValue(rv1)
+ return rv1
return run_in_context(v)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
new file mode 100644
index 0000000000..b1ae15a2bd
--- /dev/null
+++ b/tests/handlers/test_federation.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError, Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class FederationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(http_client=None)
+ self.handler = hs.get_handlers().federation_handler
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_exchange_revoked_invite(self):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # Send a 3PID invite event with an empty body so that it's considered revoked.
+ invite_token = "sometoken"
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=tok,
+ )
+
+ d = self.handler.on_exchange_third_party_invite_request(
+ origin="example.com",
+ room_id=room_id,
+ event_dict={
+ "type": EventTypes.Member,
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": "@someone:example.org",
+ "content": {
+ "membership": "invite",
+ "third_party_invite": {
+ "display_name": "alice",
+ "signed": {
+ "mxid": "@alice:localhost",
+ "token": invite_token,
+ "signatures": {
+ "magic.forest": {
+ "ed25519:3": (
+ "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIs"
+ "NffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg"
+ )
+ }
+ },
+ },
+ },
+ },
+ },
+ )
+
+ failure = self.get_failure(d, AuthError).value
+
+ self.assertEqual(failure.code, 403, failure)
+ self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
+ self.assertEqual(failure.msg, "You are not invited to this room.")
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..32c31b2f66
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.address = "test@test"
+ self.is_server_name = "testis"
+ self.rewritten_is_url = "int.testis"
+
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
+ config["rewrite_identity_server_urls"] = {
+ self.is_server_name: self.rewritten_is_url
+ }
+
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+
+ def test_rewritten_id_server(self):
+ """
+ Tests that, when validating a 3PID association while rewriting the IS's server
+ name:
+ * the bind request is done against the rewritten hostname
+ * the original, non-rewritten, server name is stored in the database
+ """
+ handler = self.hs.get_handlers().identity_handler
+ post_json_get_json = self.hs.get_simple_http_client().post_json_get_json
+ store = self.hs.get_datastore()
+
+ creds = {"sid": "123", "client_secret": "some_secret"}
+
+ # Make sure processing the mocked response goes through.
+ data = self.get_success(
+ handler.bind_threepid(
+ {
+ "id_server": self.is_server_name,
+ "client_secret": creds["client_secret"],
+ "sid": creds["sid"],
+ },
+ self.user_id,
+ )
+ )
+ self.assertEqual(data.get("address"), self.address)
+
+ # Check that the request was done against the rewritten server name.
+ post_json_get_json.assert_called_once_with(
+ "https://%s/_matrix/identity/api/v1/3pid/bind" % self.rewritten_is_url,
+ {
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ )
+
+ # Check that the original server name is saved in the database instead of the
+ # rewritten one.
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
+ self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d60c124eec..2311040201 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -67,13 +67,11 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
-
self.handler = hs.get_profile_handler()
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
displayname = yield self.handler.get_displayname(self.frank)
@@ -116,8 +114,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield self.store.set_profile_displayname("caroline", "Caroline", 1)
response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
@@ -128,7 +125,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
avatar_url = yield self.handler.get_avatar_url(self.frank)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 90d0129374..408f8583f1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.rest.client.v2_alpha.register import _map_email_to_displayname
from synapse.types import RoomAlias, UserID, create_requester
from .. import unittest
@@ -231,6 +232,26 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_email_to_displayname_mapping(self):
+ """Test that custom emails are mapped to new user displaynames correctly"""
+ self._check_mapping(
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+ )
+
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+ # Multibyte unicode characters
+ self._check_mapping(
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
+ )
+
+ def _check_mapping(self, i, expected):
+ result = _map_email_to_displayname(i)
+ self.assertEqual(result, expected)
+
@defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
"""Creates a new user if the user does not exist,
@@ -283,4 +304,4 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user, requester, displayname, by_admin=True
)
- defer.returnValue((user_id, token))
+ return (user_id, token)
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba6464..2096ba3c91 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
+
+
+def get_test_https_policy():
+ """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+ Returns:
+ IPolicyForHTTPS
+ """
+ ca_file = get_test_ca_cert_file()
+ with open(ca_file) as stream:
+ content = stream.read()
+ cert = Certificate.loadPEM(content)
+ trust_root = trustRootFromCertificates([cert])
+ return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index a49f9b3224..cf3f52fd9d 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -115,19 +115,24 @@ class MatrixFederationAgentTests(TestCase):
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
+ # grab a hold of the TLS connection, in case it gets torn down
+ server_tls_connection = server_tls_protocol._tlsConnection
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_protocol = server_tls_protocol.wrappedProtocol
+
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
- server_name = server_tls_protocol._tlsConnection.get_servername()
+ server_name = server_tls_connection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
- # fish the test server back out of the server-side TLS protocol.
- return server_tls_protocol.wrappedProtocol
+ return http_protocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
@@ -145,7 +150,7 @@ class MatrixFederationAgentTests(TestCase):
try:
fetch_res = yield fetch_d
- defer.returnValue(fetch_res)
+ return fetch_res
except Exception as e:
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
raise
@@ -936,7 +941,7 @@ class MatrixFederationAgentTests(TestCase):
except Exception as e:
logger.warning("Error fetching well-known: %s", e)
raise
- defer.returnValue(result)
+ return result
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 65b51dc981..3b885ef64b 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -61,7 +61,7 @@ class SrvResolverTestCase(unittest.TestCase):
# should have restored our context
self.assertIs(LoggingContext.current_context(), ctx)
- defer.returnValue(result)
+ return result
test_d = do_lookup()
self.assertNoResult(test_d)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b9d6d7ad1c..2b01f40a42 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -68,7 +68,7 @@ class FederationClientTests(HomeserverTestCase):
try:
fetch_res = yield fetch_d
- defer.returnValue(fetch_res)
+ return fetch_res
finally:
check_logcontext(context)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 0000000000..22abf76515
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def _make_connection(
+ self, client_factory, server_factory, ssl=False, expected_sni=None
+ ):
+ """Builds a test server, and completes the outgoing client connection
+
+ Args:
+ client_factory (interfaces.IProtocolFactory): the factory that the
+ application is trying to use to make the outbound connection. We will
+ invoke it to build the client Protocol
+
+ server_factory (interfaces.IProtocolFactory): a factory to build the
+ server-side protocol
+
+ ssl (bool): If true, we will expect an ssl connection and wrap
+ server_factory with a TLSMemoryBIOFactory
+
+ expected_sni (bytes|None): the expected SNI value
+
+ Returns:
+ IProtocol: the server Protocol returned by server_factory
+ """
+ if ssl:
+ server_factory = _wrap_server_factory_for_tls(server_factory)
+
+ server_protocol = server_factory.buildProtocol(None)
+
+ # now, tell the client protocol factory to build the client protocol,
+ # and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server protocol to send its stuff back to the client, too
+ server_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_protocol)
+ )
+
+ if ssl:
+ http_protocol = server_protocol.wrappedProtocol
+ tls_connection = server_protocol._tlsConnection
+ else:
+ http_protocol = server_protocol
+ tls_connection = None
+
+ # give the reactor a pump to get the TLS juices flowing (if needed)
+ self.reactor.advance(0)
+
+ if expected_sni is not None:
+ server_name = tls_connection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ return http_protocol
+
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request(self):
+ agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=True,
+ expected_sni=b"test.com",
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_http_request_via_proxy(self):
+ agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
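+        # the connection should be made to the proxy, not to test.com itself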
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
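+        # when requesting via a proxy, the client should use the absolute-form
+        # request target (the full URI) rather than just the path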
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
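+        # the connection should go to the proxy; since https_proxy above gives
+        # no port, the agent is expected to fall back to a default CONNECT
+        # port (1080, asserted below)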
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
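+        # finishing the CONNECT request sends a 200 response (Request's
+        # default status), which tells the client the tunnel is established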
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
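+        # from here on, bytes written by the client are delivered to the
+        # TLS-wrapped HTTP channel rather than to the proxy's plain HTTP channel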
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
+
+ Args:
+ factory (interfaces.IProtocolFactory): protocol factory to wrap
+ sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ if sanlist is None:
+ sanlist = [b"DNS:test.com"]
+
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+
+
+def _get_test_protocol_factory():
+ """Get a protocol Factory which will build an HTTPChannel
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ return server_factory
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62da..af2327fb66 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
config = self.default_config()
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+ hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
return hs
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..f81f81602e 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
import json
+from mock import Mock
+
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,19 +39,93 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ request, channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ params = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_data = json.dumps(params)
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ request, channel = self.make_request(
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed((200, "{}"))
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_enabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -58,7 +139,44 @@ class IdentityTestCase(unittest.HomeserverTestCase):
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ b"POST", request_url, request_data, access_token=self.tok
)
self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
new file mode 100644
index 0000000000..4303f95206
--- /dev/null
+++ b/tests/rest/client/test_retention.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.api.constants import EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.visibility import filter_events_for_client
+
+from tests import unittest
+
+one_hour_ms = 3600000
+one_day_ms = one_hour_ms * 24
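+# Note: these values are in milliseconds, while reactor.advance() takes
+# seconds, hence the "/ 1000" conversions in the tests below.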
+
+
+class RetentionTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["default_room_version"] = "1"
+ config["retention"] = {
+ "enabled": True,
+ "default_policy": {
+ "min_lifetime": one_day_ms,
+ "max_lifetime": one_day_ms * 3,
+ },
+ "allowed_lifetime_min": one_day_ms,
+ "allowed_lifetime_max": one_day_ms * 3,
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_retention_state_event(self):
+ """Tests that the server configuration can limit the values a user can set to the
+ room's retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 4},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_hour_ms},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": lifetime},
+ tok=self.token,
+ )
+
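+        # The helper advances the clock by the increment twice, so the first
+        # event ends up 3 days old (beyond the 2-day lifetime) while the
+        # second is only 1.5 days old and must survive.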
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_without_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+        is defined by the server's default retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
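+        # The default policy (3-day max_lifetime) applies here: after two
+        # 2-day increments the first event (4 days old) is purged while the
+        # second (2 days old) is kept.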
+ self._test_retention_event_purged(room_id, one_day_ms * 2)
+
+ def test_visibility(self):
+ """Tests that synapse.visibility.filter_events_for_client correctly filters out
+ outdated events
+ """
+ store = self.hs.get_datastore()
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ events = []
+
+ # Send a first event, which should be filtered out at the end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ # Get the event from the store so that we end up with a FrozenEvent that we can
+ # give to filter_events_for_client. We need to do this now because the event won't
+ # be in the database anymore after it has expired.
+ events.append(self.get_success(store.get_event(resp.get("event_id"))))
+
+ # Advance the time by 2 days. We're using the default retention policy, therefore
+ # after this the first event will still be valid.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Send another event, which shouldn't get filtered out.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ events.append(self.get_success(store.get_event(valid_event_id)))
+
+        # Advance the time by another 2 days. After this, the first event should be
+ # outdated but not the second one.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Run filter_events_for_client with our list of FrozenEvents.
+ filtered_events = self.get_success(
+ filter_events_for_client(store, self.user_id, events)
+ )
+
+ # We should only get one event back.
+ self.assertEqual(len(filtered_events), 1, filtered_events)
+ # That event should be the second, not outdated event.
+ self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
+
+ def _test_retention_event_purged(self, room_id, increment):
+        # Get the create event so that we can check later that we can still access it.
+ message_handler = self.hs.get_message_handler()
+ create_event = self.get_success(
+ message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ )
+
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ expired_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, expired_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time.
+ self.reactor.advance(increment / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ # Advance the time again. Now our first event should have expired but our second
+ # one should still be kept.
+ self.reactor.advance(increment / 1000)
+
+ # Check that the event has been purged from the database.
+ self.get_event(room_id, expired_event_id, expected_code=404)
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ valid_event = self.get_event(room_id, valid_event_id)
+ self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
+
+ # Check that we can still access state events that were sent before the event that
+ # has been purged.
+ self.get_event(room_id, create_event.event_id)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
+
+
+class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["default_room_version"] = "1"
+ config["retention"] = {"enabled": True}
+
+ mock_federation_client = Mock(spec=["backfill"])
+
+ self.hs = self.setup_test_homeserver(
+ config=config, federation_client=mock_federation_client
+ )
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_no_default_policy(self):
+ """Tests that an event doesn't get expired if there is neither a default retention
+ policy nor a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention(room_id)
+
+ def test_state_policy(self):
+ """Tests that an event gets correctly expired if there is no default retention
+ policy but there's a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the maximum lifetime to 35 days so that the first event gets expired but not
+ # the second one.
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 35},
+ tok=self.token,
+ )
+
+ self._test_retention(room_id, expected_code_for_first_event=404)
+
+ def _test_retention(self, room_id, expected_code_for_first_event=200):
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ first_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, first_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time by a month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ second_event_id = resp.get("event_id")
+
+ # Advance the time by another month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+        # Check whether the first event has been purged from the database.
+ first_event = self.get_event(
+ room_id, first_event_id, expected_code=expected_code_for_first_event
+ )
+
+ if expected_code_for_first_event == 200:
+ self.assertEqual(
+ first_event.get("content", {}).get("body"), "1", first_event
+ )
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ second_event = self.get_event(room_id, second_event_id)
+ self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..d44f5c2c8c
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,721 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import random
+import string
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+ ACCESS_RULES_TYPE,
+)
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
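+        # The mocked identity server reports each 3PID as living on the domain
+        # of its email address; the access rules module checks that domain
+        # against domains_forbidden_when_restricted.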
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+        mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
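+        # direct_rooms[0] now has a pending invite; direct_rooms[1] and [2]
+        # still only contain their creator.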
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default value."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default value."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right value."""
+ room_id = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=ACCESS_RULE_DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=ACCESS_RULE_RESTRICTED, expected_code=400)
+
+ def test_public_room(self):
+ """Tests that it's not possible to have a room with the public join rule and an
+ access rule that's not restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED
+ )
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, ACCESS_RULE_DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403
+ )
+
+ # Changing join rule to public in an unrestricted room should fail.
+ self.change_join_rule_in_room(
+ self.unrestricted_room, JoinRules.PUBLIC, expected_code=403
+ )
+        # Changing join rule to public in a direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule that isn't
+ # restricted should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
+ )
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
+ )
+
+ # Creating a room with the public join rule in its initial state and an access
+ # rule that isn't restricted should fail.
+ self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
+ )
+ self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+        # We can invite a user whose HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+        Also tests that a user from an HS that's in the list of forbidden domains (to use
+ in restricted mode) can be invited.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+        # Relax the 3pid invite ratelimiter so it doesn't interfere with this test
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+        # We can't send another 3PID invite to a room that already has a 3PID invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+ """
+        # We can invite users from any HS, including forbidden ones.
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+        # We can't send a power level event that redefines the default PL, even if it
+        # doesn't set a non-default PL for a user that would be forbidden in restricted
+        # mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+        # We can't send a power level event that doesn't redefine the default PL but sets
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can change the rule from restricted to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=ACCESS_RULE_DIRECT, expected_code=403
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=403,
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+        confuse the revocation with a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ expected_code=200,
+ ):
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ json.dumps(content),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+ data = {"rule": new_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
+ data = {"join_rule": new_join_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index a8adc9a61d..a3d7e3c046 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
yield Clock(reactor).sleep(0)
- defer.returnValue("yay")
+ return "yay"
@defer.inlineCallbacks
def test():
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 140d8b3772..02b4b8f5eb 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -229,6 +229,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
config = self.default_config()
config["require_auth_for_profile_requests"] = True
+ config["limit_profile_requests_to_known_users"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 9915367144..cdded88b7f 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -128,8 +128,12 @@ class RestHelper(object):
return channel.json_body
- def send_state(self, room_id, event_type, body, tok, expect_code=200):
- path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type)
+ def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
if tok:
path = path + "?access_token=%s" % tok
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 920de41de4..9fed900f4a 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -23,8 +23,8 @@ from email.parser import Parser
import pkg_resources
import synapse.rest.admin
-from synapse.api.constants import LoginType
-from synapse.rest.client.v1 import login
+from synapse.api.constants import LoginType, Membership
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from tests import unittest
@@ -244,6 +244,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
account.register_servlets,
+ room.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -279,3 +280,56 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request("GET", "account/whoami")
self.render(request)
self.assertEqual(request.code, 401)
+
+ @unittest.INFO
+ def test_pending_invites(self):
+ """Tests that deactivating a user rejects every pending invite for them."""
+ store = self.hs.get_datastore()
+
+ inviter_id = self.register_user("inviter", "test")
+ inviter_tok = self.login("inviter", "test")
+
+ invitee_id = self.register_user("invitee", "test")
+ invitee_tok = self.login("invitee", "test")
+
+ # Make @inviter:test invite @invitee:test in a new room.
+ room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok)
+ self.helper.invite(
+ room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok
+ )
+
+ # Make sure the invite is here.
+ pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ self.assertEqual(len(pending_invites), 1, pending_invites)
+ self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
+
+ # Deactivate @invitee:test.
+ self.deactivate(invitee_id, invitee_tok)
+
+ # Check that the invite isn't there anymore.
+ pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ self.assertEqual(len(pending_invites), 0, pending_invites)
+
+ # Check that the membership of @invitee:test in the room is now "leave".
+ memberships = self.get_success(
+ store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
+ )
+ self.assertEqual(len(memberships), 1, memberships)
+ self.assertEqual(memberships[0].room_id, room_id, memberships)
+
+ def deactivate(self, user_id, tok):
+ request_data = json.dumps(
+ {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "test",
+ },
+ "erase": False,
+ }
+ )
+ request, channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
new file mode 100644
index 0000000000..37f970c6b0
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, password_policy, register
+
+from tests import unittest
+
+
+class PasswordPolicyTestCase(unittest.HomeserverTestCase):
+ """Tests the password policy feature and its compliance with MSC2000.
+
+ When validating a password, Synapse does the necessary checks in this order:
+
+ 1. Password is long enough
+ 2. Password contains digit(s)
+ 3. Password contains symbol(s)
+ 4. Password contains uppercase letter(s)
+ 5. Password contains lowercase letter(s)
+
+ Therefore, each test in this test case that tests whether a password triggers the
+ right error code to be returned provides a password good enough to pass the previous
+ steps but not the one it's testing (nor any step that comes after).
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ password_policy.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.register_url = "/_matrix/client/r0/register"
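+        # The policy below is surfaced, with "m."-prefixed keys, by the
+        # /password_policy endpoint exercised in test_get_policy.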
+ self.policy = {
+ "enabled": True,
+ "minimum_length": 10,
+ "require_digit": True,
+ "require_symbol": True,
+ "require_lowercase": True,
+ "require_uppercase": True,
+ }
+
+ config = self.default_config()
+ config["password_config"] = {"policy": self.policy}
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_get_policy(self):
+ """Tests if the /password_policy endpoint returns the configured policy."""
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/password_policy"
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.minimum_length": 10,
+ "m.require_digit": True,
+ "m.require_symbol": True,
+ "m.require_lowercase": True,
+ "m.require_uppercase": True,
+ },
+ channel.result,
+ )
+
+ def test_password_too_short(self):
+ request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result
+ )
+
+ def test_password_no_digit(self):
+ request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result
+ )
+
+ def test_password_no_symbol(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result
+ )
+
+ def test_password_no_uppercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result
+ )
+
+ def test_password_no_lowercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result
+ )
+
+ def test_password_compliant(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ # Getting a 401 here means the password has passed validation and the server has
+ # responded with a list of registration flows.
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_password_change(self):
+ """This doesn't test every possible use case, only that hitting /account/password
+ triggers the password validation code.
+ """
+ compliant_password = "C0mpl!antpassword"
+ not_compliant_password = "notcompliantpassword"
+
+ user_id = self.register_user("kermit", compliant_password)
+ tok = self.login("kermit", compliant_password)
+
+ request_data = json.dumps(
+ {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
+ )
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/account/password",
+ request_data,
+ access_token=tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 89a3f95c0a..9b7eb1e856 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime
import json
import os
+from mock import Mock
+
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -200,6 +204,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+        self.assertEqual(split_uri[-1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -208,6 +253,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
login.register_servlets,
sync.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -300,6 +346,138 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Create a user to expire
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+
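+        # Advance the reactor so that any work kicked off by registration
+        # (e.g. the initial profile replication) completes before we expire
+        # the user below.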
+ self.pump(1000)
+ self.reactor.advance(1000)
+ self.pump()
+
+ # Expire the user
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Wait for the background job that hides expired users from the user
+        # directory to run
+ self.pump(60 * 60 * 1000)
+
+        # Get hold of the mocked HTTP client set up in make_homeserver
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+        # Check that the homeserver has replicated the user's profile to the
+        # identity server
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+        # There was replicated information about our user. Check that it's
+        # None, signifying that the user should be removed from the user
+        # directory because they have expired.
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+        # Check that the homeserver has replicated the user's profile to the
+        # identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+        # There was replicated information about our user. Check that it's
+        # not None, signifying that the user is back in the user directory.
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -323,6 +501,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"renew_at": 172800000, # Time in ms for 2 days
"renew_by_email_enabled": True,
"renew_email_subject": "Renew your account",
+ "account_renewed_html_path": "account_renewed.html",
+ "invalid_token_html_path": "invalid_token.html",
}
# Email config.
@@ -373,6 +553,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+ # Check that we're getting HTML back.
+ content_type = None
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type = header[1]
+ self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+
+ # Check that the HTML we're getting is the one we expect on a successful renewal.
+ expected_html = self.hs.config.account_validity.account_renewed_html_content
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
# Move 3 days forward. If the renewal failed, every authed request with
# our access token should be denied from now, otherwise they should
# succeed.
@@ -381,6 +574,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+ def test_renewal_invalid_token(self):
+ # Hit the renewal endpoint with an invalid token and check that it behaves as
+ # expected, i.e. that it responds with 404 Not Found and the correct HTML.
+ url = "/_matrix/client/unstable/account_validity/renew?token=123"
+ request, channel = self.make_request(b"GET", url)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"404", channel.result)
+
+ # Check that we're getting HTML back.
+ content_type = None
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type = header[1]
+ self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+
+ # Check that the HTML we're getting is the one we expect when using an
+ # invalid/unknown token.
+ expected_html = self.hs.config.account_validity.invalid_token_html_content
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
def test_manual_email_send(self):
self.email_attempts = []
@@ -414,7 +629,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(request.code, 200, channel.result)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..1accc70dc9
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+from tests.server import make_request, render
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+ def test_allowed(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+ }
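+        # The positional arguments passed below are assumed to follow
+        # user_may_invite(inviter_userid, invitee_userid, third_party_invite,
+        # room_id, new_room, published_room=False).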
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_one", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_two", "test:target_two", None, "room", False
+ )
+ )
+
+ # User can invite internal user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test1:target_one", None, "room", False, True
+ )
+ )
+
+ # User can invite external user to a non-published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, False
+ )
+ )
+
+ def test_disallowed(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ "source_four": [],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_one", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_one", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_four", "test:target_one", None, "room", False
+ )
+ )
+
+        # With no published-room blacklist configured, the domain mapping
+        # still allows this invite even though the room is published
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, True
+ )
+ )
+
+ def test_default_allow(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_default_deny(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_config_parse(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ self.assertEquals(config, DomainRuleChecker.parse_config(config))
+
+ def test_config_parse_failure(self):
+ config = {
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ }
+ }
+ self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["localhost"]
+
+ config["spam_checker"] = {
+ "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+ "config": {
+ "default": True,
+ "domain_mapping": {},
+ "can_only_join_rooms_with_invite": True,
+ "can_only_create_one_to_one_rooms": True,
+ "can_only_invite_during_room_creation": True,
+ "can_invite_by_third_party_id": False,
+ },
+ }
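+        # Taken together, an empty domain_mapping plus these flags should
+        # lock the server down to invite-only, one-to-one rooms, which is
+        # what the tests below exercise.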
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+ self.admin_access_token = self.login("admin_user", "pass")
+
+ self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+ self.normal_access_token = self.login("normal_user", "pass")
+
+ self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+ def test_admin_can_create_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_normal_user_cannot_create_empty_room(self):
+ channel = self._create_room(self.normal_access_token)
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_cannot_create_room_with_multiple_invites(self):
+ channel = self._create_room(
+ self.normal_access_token,
+ content={"invite": [self.other_user_id, self.admin_user_id]},
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly counts both normal and third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [self.other_user_id],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly rejects third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+    def test_normal_user_can_create_room_with_single_invites(self):
+ channel = self._create_room(
+ self.normal_access_token, content={"invite": [self.other_user_id]}
+ )
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_cannot_join_public_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+ )
+
+ def test_can_join_invited_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ def test_cannot_invite(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ def test_cannot_3pid_invite(self):
+ """Test that unbound 3pid invites get rejected.
+ """
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/invite" % (room_id),
+ {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+ access_token=self.normal_access_token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result["body"])
+
+    def _create_room(self, token, content=None):
+ path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+ request, channel = make_request(
+ self.hs.get_reactor(),
+ "POST",
+ path,
+            content=json.dumps(content or {}).encode("utf8"),
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ return channel
diff --git a/tests/server.py b/tests/server.py
index e573c4e4c5..d10c0603e9 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -387,11 +387,24 @@ class FakeTransport(object):
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
- self.disconnected = True
+
+ # if we still have data to write, delay until that is done
+ if self.buffer:
+ logger.info(
+ "FakeTransport: Delaying disconnect until buffer is flushed"
+ )
+ else:
+ self.disconnected = True
def abortConnection(self):
logger.info("FakeTransport: abortConnection()")
- self.loseConnection()
+
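+        # Unlike loseConnection(), aborting does not wait for buffered data
+        # to be flushed; the disconnect completes immediately.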
+ if not self.disconnecting:
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(None)
+
+ self.disconnected = True
def pauseProducing(self):
if not self.producer:
@@ -422,6 +435,9 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _produce)
def write(self, byt):
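+        # A disconnect has already been requested; allowing further writes
+        # would race the delayed-disconnect handling in flush(), so fail
+        # loudly instead.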
+ if self.disconnecting:
+ raise Exception("Writing to disconnecting FakeTransport")
+
self.buffer = self.buffer + byt
# always actually do the write asynchronously. Some protocols (notably the
@@ -465,3 +481,7 @@ class FakeTransport(object):
self.buffer = self.buffer[len(to_write) :]
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
+
+ if not self.buffer and self.disconnecting:
+ logger.info("FakeTransport: Buffer now empty, completing disconnect")
+ self.disconnected = True
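+
+    # A rough sketch of the intended shutdown handshake, assuming autoflush:
+    #
+    #   transport.write(b"...")     # data is buffered
+    #   transport.loseConnection()  # disconnecting=True; buffer not yet
+    #                               # empty, so disconnected stays False
+    #   ...reactor turns, flush() drains the buffer...
+    #   ...flush() then sets disconnected=True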
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index fbb9302694..9fabe3fbc0 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -43,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
"test_update",
progress,
)
- defer.returnValue(count)
+ return count
self.update_handler.side_effect = update
@@ -60,7 +60,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
- defer.returnValue(count)
+ return count
self.update_handler.side_effect = update
self.update_handler.reset_mock()
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index e07ff01201..95f309fbbc 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -14,6 +14,7 @@
# limitations under the License.
import signedjson.key
+import unpaddedbase64
from twisted.internet.defer import Deferred
@@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult
import tests.unittest
-KEY_1 = signedjson.key.decode_verify_key_base64(
- "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
+
+def decode_verify_key_base64(key_id: str, key_base64: str):
+ key_bytes = unpaddedbase64.decode_base64(key_base64)
+ return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
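+
+# Note: unlike signedjson's own decode_verify_key_base64 (which takes the
+# algorithm and version as separate arguments, as in the removed lines
+# below), the helper above accepts a full "algorithm:version" key ID,
+# presumably so the test keys can be built in a single call.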
+
+
+KEY_1 = decode_verify_key_base64(
+ "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
)
-KEY_2 = signedjson.key.decode_verify_key_base64(
- "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+KEY_2 = decode_verify_key_base64(
+ "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 45824bd3b2..13e9f8ec09 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -34,9 +34,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
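+        # The trailing 1 is assumed to be the new batch-number argument used
+        # by profile replication.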
+ yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
@@ -44,10 +42,8 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.localpart, "http://my.site/here", 1
)
self.assertEquals(
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 732a778fab..8488b6edc8 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,23 +17,21 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import create_room, setup_test_homeserver
+from tests.utils import create_room
-class RedactionTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(
- self.addCleanup, resource_for_federation=Mock(), http_client=None
+class RedactionTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ resource_for_federation=Mock(), http_client=None
)
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -42,11 +41,12 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test")
- yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+ self.get_success(
+ create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+ )
self.depth = 1
- @defer.inlineCallbacks
def inject_room_member(
self, room, user, membership, replaces_state=None, extra_content={}
):
@@ -63,15 +63,14 @@ class RedactionTestCase(unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.store.persist_event(event, context)
+ self.get_success(self.store.persist_event(event, context))
- defer.returnValue(event)
+ return event
- @defer.inlineCallbacks
def inject_message(self, room, user, body):
self.depth += 1
@@ -86,15 +85,14 @@ class RedactionTestCase(unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.store.persist_event(event, context)
+ self.get_success(self.store.persist_event(event, context))
- defer.returnValue(event)
+ return event
- @defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -108,20 +106,21 @@ class RedactionTestCase(unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.store.persist_event(event, context)
+ self.get_success(self.store.persist_event(event, context))
- @defer.inlineCallbacks
def test_redact(self):
- yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ self.get_success(
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ )
- msg_event = yield self.inject_message(self.room1, self.u_alice, "t")
+ msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
# Check event has not been redacted:
- event = yield self.store.get_event(msg_event.event_id)
+ event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertObjectHasAttributes(
{
@@ -136,11 +135,11 @@ class RedactionTestCase(unittest.TestCase):
# Redact event
reason = "Because I said so"
- yield self.inject_redaction(
- self.room1, msg_event.event_id, self.u_alice, reason
+ self.get_success(
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
- event = yield self.store.get_event(msg_event.event_id)
+ event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertEqual(msg_event.event_id, event.event_id)
@@ -164,15 +163,18 @@ class RedactionTestCase(unittest.TestCase):
event.unsigned["redacted_because"],
)
- @defer.inlineCallbacks
def test_redact_join(self):
- yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ self.get_success(
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ )
- msg_event = yield self.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
+ msg_event = self.get_success(
+ self.inject_room_member(
+ self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
+ )
)
- event = yield self.store.get_event(msg_event.event_id)
+ event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertObjectHasAttributes(
{
@@ -187,13 +189,13 @@ class RedactionTestCase(unittest.TestCase):
# Redact event
reason = "Because I said so"
- yield self.inject_redaction(
- self.room1, msg_event.event_id, self.u_alice, reason
+ self.get_success(
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
# Check redaction
- event = yield self.store.get_event(msg_event.event_id)
+ event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertTrue("redacted_because" in event.unsigned)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 73ed943f5a..c6e8196b91 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -67,7 +67,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
yield self.store.persist_event(event, context)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def test_one_member(self):
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 212a7ae765..5c2cf3c2db 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -65,7 +65,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
yield self.store.persist_event(event, context)
- defer.returnValue(event)
+ return event
def assertStateMapEqual(self, s1, s2):
for t in s1:
diff --git a/tests/test_types.py b/tests/test_types.py
index 9ab5f829b0..7cb1f8acb4 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ GroupID,
+ RoomAlias,
+ UserID,
+ map_username_to_mxid_localpart,
+ strip_invalid_mxid_characters,
+)
from tests import unittest
from tests.utils import TestHomeServer
@@ -106,3 +113,16 @@ class MapUsernameTestCase(unittest.TestCase):
self.assertEqual(
map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
)
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+ def test_return_type(self):
+ unstripped = strip_invalid_mxid_characters("test")
+ stripped = strip_invalid_mxid_characters("test@")
+
+ self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+ self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+ def test_strip(self):
+ stripped = strip_invalid_mxid_characters("test@")
+ self.assertEqual(stripped, "test", stripped)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 118c3bd238..e0605dac2f 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -139,7 +139,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def inject_room_member(self, user_id, membership="join", extra_content={}):
@@ -161,7 +161,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
)
yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def inject_message(self, user_id, content=None):
@@ -182,7 +182,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
)
yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def test_large_room(self):
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 7807328e2f..5713870f48 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -27,6 +27,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
+from synapse.util.caches.descriptors import cached
from tests import unittest
@@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
- # lookup should return the deferreds
- self.assertIs(cache.get("key1"), d1)
- self.assertIs(cache.get("key2"), d2)
+ # lookup should return observable deferreds
+ self.assertFalse(cache.get("key1").has_called())
+ self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
+
+        # for now at least, the cache will return real results rather than an
+        # ObservableDeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
@@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
+ def test_cache_with_sync_exception(self):
+ """If the wrapped function throws synchronously, things should continue to work
+ """
+
+ class Cls(object):
+ @cached()
+ def fn(self, arg1):
+ raise SynapseError(100, "mai spoon iz too big!!1")
+
+ obj = Cls()
+
+ # this should fail immediately
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should result in a second exception
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@@ -159,7 +185,7 @@ class DescriptorTestCase(unittest.TestCase):
def inner_fn():
with PreserveLoggingContext():
yield complete_lookup
- defer.returnValue(1)
+ return 1
return inner_fn()
@@ -169,7 +195,7 @@ class DescriptorTestCase(unittest.TestCase):
c1.name = "c1"
r = yield obj.fn(1)
self.assertEqual(LoggingContext.current_context(), c1)
- defer.returnValue(r)
+ return r
def check_result(r):
self.assertEqual(r, 1)
@@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(LoggingContext.current_context(), c1)
+ # the cache should now be empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
obj = Cls()
# set off a deferred which will do a cache lookup
@@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
+ def test_cache_iterable(self):
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached(iterable=True)
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
+
+ obj = Cls()
+
+ obj.mock.return_value = ["spam", "eggs"]
+ r = obj.fn(1, 2)
+ self.assertEqual(r, ["spam", "eggs"])
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = ["chips"]
+ r = obj.fn(1, 3)
+ self.assertEqual(r, ["chips"])
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
+
+        # both results should now be cached; with iterable=True the cache is
+        # sized by the number of items in each entry rather than by entry,
+        # so the total size is len(["spam", "eggs"]) + len(["chips"]) == 3
+        self.assertEqual(len(obj.fn.cache.cache), 3)
+
+ r = obj.fn(1, 2)
+ self.assertEqual(r, ["spam", "eggs"])
+ r = obj.fn(1, 3)
+ self.assertEqual(r, ["chips"])
+ obj.mock.assert_not_called()
+
+ def test_cache_iterable_with_sync_exception(self):
+ """If the wrapped function throws synchronously, things should continue to work
+ """
+
+ class Cls(object):
+ @descriptors.cached(iterable=True)
+ def fn(self, arg1):
+ raise SynapseError(100, "mai spoon iz too big!!1")
+
+ obj = Cls()
+
+ # this should fail immediately
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should result in a second exception
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
@@ -286,7 +370,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
# we want this to behave like an asynchronous function
yield run_on_reactor()
assert LoggingContext.current_context().request == "c1"
- defer.returnValue(self.mock(args1, arg2))
+ return self.mock(args1, arg2)
with LoggingContext() as c1:
c1.request = "c1"
@@ -334,7 +418,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
yield run_on_reactor()
- defer.returnValue(self.mock(args1, arg2))
+ return self.mock(args1, arg2)
obj = Cls()
invalidate0 = mock.Mock()
diff --git a/tests/utils.py b/tests/utils.py
index 99a3deae21..6350646263 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -361,7 +361,7 @@ def setup_test_homeserver(
if fed:
register_federation_servlets(hs, fed)
- defer.returnValue(hs)
+ return hs
def register_federation_servlets(hs, resource):
@@ -465,9 +465,9 @@ class MockHttpResource(HttpServer):
args = [urlparse.unquote(u) for u in matcher.groups()]
(code, response) = yield func(mock_request, *args)
- defer.returnValue((code, response))
+ return (code, response)
except CodeMessageException as e:
- defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
+ return (e.code, cs_error(e.msg, code=e.errcode))
raise KeyError("No event can handle %s" % path)
diff --git a/tox.ini b/tox.ini
index 09b4b8fc3c..689cb43db1 100644
--- a/tox.ini
+++ b/tox.ini
@@ -117,7 +117,7 @@ skip_install = True
basepython = python3.6
deps =
flake8
- black
+ black==19.3b0
commands =
python -m black --check --diff .
/bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}"
@@ -131,7 +131,7 @@ commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests"
skip_install = True
deps = towncrier>=18.6.0rc1
commands =
- python -m towncrier.check --compare-with=origin/develop
+ python -m towncrier.check --compare-with=origin/dinsic
basepython = python3.6
[testenv:check-sampleconfig]
|