summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-x.buildkite/scripts/test_old_deps.sh2
-rwxr-xr-x.buildkite/scripts/test_synapse_port_db.sh2
-rw-r--r--CHANGES.md201
-rw-r--r--INSTALL.md42
-rw-r--r--README.rst18
-rw-r--r--UPGRADE.rst9
-rw-r--r--changelog.d/8553.docker1
-rw-r--r--changelog.d/8926.bugfix1
-rw-r--r--changelog.d/9458.misc1
-rw-r--r--changelog.d/9473.bugfix1
-rw-r--r--changelog.d/9491.feature1
-rw-r--r--changelog.d/9508.doc1
-rw-r--r--changelog.d/9510.feature1
-rw-r--r--changelog.d/9511.feature1
-rw-r--r--changelog.d/9520.misc1
-rw-r--r--changelog.d/9523.misc1
-rw-r--r--changelog.d/9528.misc1
-rw-r--r--changelog.d/9540.feature1
-rw-r--r--changelog.d/9540.removal1
-rw-r--r--changelog.d/9541.misc1
-rw-r--r--changelog.d/9542.bugfix1
-rw-r--r--changelog.d/9543.misc1
-rw-r--r--changelog.d/9549.feature1
-rw-r--r--changelog.d/9550.doc1
-rw-r--r--changelog.d/9559.removal1
-rw-r--r--changelog.d/9560.misc1
-rw-r--r--changelog.d/9561.misc1
-rw-r--r--changelog.d/9562.misc1
-rw-r--r--changelog.d/9563.misc1
-rw-r--r--changelog.d/9567.bugfix1
-rw-r--r--changelog.d/9568.misc1
-rw-r--r--changelog.d/9571.doc1
-rw-r--r--changelog.d/9573.feature1
-rw-r--r--changelog.d/9576.misc1
-rw-r--r--changelog.d/9580.doc1
-rw-r--r--changelog.d/9583.bugfix1
-rw-r--r--changelog.d/9586.misc1
-rw-r--r--changelog.d/9587.bugfix1
-rw-r--r--changelog.d/9590.misc1
-rw-r--r--changelog.d/9591.misc1
-rw-r--r--changelog.d/9596.misc1
-rw-r--r--changelog.d/9597.bugfix1
-rw-r--r--changelog.d/9601.feature1
-rw-r--r--changelog.d/9604.doc1
-rw-r--r--changelog.d/9604.misc1
-rw-r--r--changelog.d/9608.misc1
-rw-r--r--changelog.d/9617.feature1
-rw-r--r--changelog.d/9618.misc1
-rw-r--r--changelog.d/9619.misc1
-rw-r--r--changelog.d/9620.bugfix1
-rw-r--r--changelog.d/9623.bugfix1
-rw-r--r--changelog.d/9626.feature1
-rw-r--r--changelog.d/9654.feature1
-rw-r--r--changelog.d/9685.misc1
-rw-r--r--changelog.d/9686.misc1
-rw-r--r--changelog.d/9691.feature1
-rw-r--r--changelog.d/9700.feature1
-rw-r--r--changelog.d/9710.feature1
-rw-r--r--changelog.d/9711.bugfix1
-rw-r--r--changelog.d/9717.feature1
-rw-r--r--changelog.d/9718.removal1
-rw-r--r--changelog.d/9719.doc1
-rw-r--r--changelog.d/9725.bugfix1
-rw-r--r--changelog.d/9730.misc1
-rw-r--r--changelog.d/9736.misc1
-rw-r--r--changelog.d/9742.misc1
-rw-r--r--changelog.d/9743.misc1
-rw-r--r--changelog.d/9753.misc1
-rw-r--r--changelog.d/9770.bugfix1
-rw-r--r--contrib/purge_api/purge_history.sh2
-rw-r--r--contrib/purge_api/purge_remote_media.sh2
-rw-r--r--debian/changelog18
-rwxr-xr-xdemo/clean.sh2
-rwxr-xr-xdemo/start.sh2
-rwxr-xr-xdemo/stop.sh2
-rw-r--r--docker/Dockerfile50
-rw-r--r--docker/build_debian.sh2
-rw-r--r--docker/conf/homeserver.yaml8
-rwxr-xr-xdocker/run_pg_tests.sh2
-rw-r--r--docs/admin_api/user_admin_api.rst85
-rw-r--r--docs/code_style.md3
-rw-r--r--docs/deprecation_policy.md33
-rw-r--r--docs/presence_router_module.md235
-rw-r--r--docs/reverse_proxy.md7
-rw-r--r--docs/sample_config.yaml114
-rw-r--r--docs/workers.md3
-rw-r--r--mypy.ini9
-rwxr-xr-xscripts-dev/check-newsfragment2
-rwxr-xr-xscripts-dev/complement.sh49
-rwxr-xr-xscripts-dev/config-lint.sh2
-rwxr-xr-xscripts-dev/federation_client.py75
-rwxr-xr-xscripts-dev/generate_sample_config2
-rwxr-xr-xscripts-dev/lint.sh2
-rwxr-xr-xscripts-dev/make_full_schema.sh2
-rwxr-xr-xscripts-dev/next_github_number.sh4
-rwxr-xr-xscripts/move_remote_media_to_new_store.py2
-rw-r--r--setup.cfg3
-rwxr-xr-xsetup.py3
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py5
-rw-r--r--synapse/api/constants.py9
-rw-r--r--synapse/api/ratelimiting.py100
-rw-r--r--synapse/api/room_versions.py24
-rw-r--r--synapse/app/__init__.py4
-rw-r--r--synapse/app/_base.py21
-rw-r--r--synapse/app/generic_worker.py27
-rw-r--r--synapse/app/homeserver.py9
-rw-r--r--synapse/config/api.py139
-rw-r--r--synapse/config/cache.py6
-rw-r--r--synapse/config/experimental.py9
-rw-r--r--synapse/config/key.py6
-rw-r--r--synapse/config/metrics.py4
-rw-r--r--synapse/config/oidc_config.py72
-rw-r--r--synapse/config/ratelimiting.py8
-rw-r--r--synapse/config/registration.py4
-rw-r--r--synapse/config/repository.py4
-rw-r--r--synapse/config/saml2_config.py4
-rw-r--r--synapse/config/server.py39
-rw-r--r--synapse/config/tracer.py4
-rw-r--r--synapse/crypto/context_factory.py4
-rw-r--r--synapse/crypto/keyring.py2
-rw-r--r--synapse/event_auth.py28
-rw-r--r--synapse/events/__init__.py9
-rw-r--r--synapse/events/presence_router.py104
-rw-r--r--synapse/events/third_party_rules.py15
-rw-r--r--synapse/events/utils.py16
-rw-r--r--synapse/federation/federation_client.py184
-rw-r--r--synapse/federation/federation_server.py56
-rw-r--r--synapse/federation/send_queue.py88
-rw-r--r--synapse/federation/sender/__init__.py141
-rw-r--r--synapse/federation/sender/per_destination_queue.py114
-rw-r--r--synapse/federation/transport/client.py35
-rw-r--r--synapse/federation/transport/server.py71
-rw-r--r--synapse/groups/attestations.py2
-rw-r--r--synapse/groups/groups_server.py2
-rw-r--r--synapse/handlers/_base.py16
-rw-r--r--synapse/handlers/account_data.py2
-rw-r--r--synapse/handlers/account_validity.py9
-rw-r--r--synapse/handlers/acme.py2
-rw-r--r--synapse/handlers/admin.py2
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/auth.py39
-rw-r--r--synapse/handlers/cas_handler.py2
-rw-r--r--synapse/handlers/deactivate_account.py2
-rw-r--r--synapse/handlers/device.py30
-rw-r--r--synapse/handlers/devicemessage.py42
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/handlers/e2e_room_keys.py2
-rw-r--r--synapse/handlers/federation.py163
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/identity.py12
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/handlers/oidc_handler.py16
-rw-r--r--synapse/handlers/password_policy.py2
-rw-r--r--synapse/handlers/presence.py290
-rw-r--r--synapse/handlers/profile.py2
-rw-r--r--synapse/handlers/read_marker.py2
-rw-r--r--synapse/handlers/receipts.py2
-rw-r--r--synapse/handlers/register.py12
-rw-r--r--synapse/handlers/room_list.py2
-rw-r--r--synapse/handlers/room_member.py27
-rw-r--r--synapse/handlers/room_member_worker.py10
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/set_password.py4
-rw-r--r--synapse/handlers/space_summary.py395
-rw-r--r--synapse/handlers/state_deltas.py2
-rw-r--r--synapse/handlers/stats.py2
-rw-r--r--synapse/handlers/sync.py38
-rw-r--r--synapse/handlers/typing.py6
-rw-r--r--synapse/handlers/user_directory.py2
-rw-r--r--synapse/http/client.py4
-rw-r--r--synapse/http/connectproxyclient.py96
-rw-r--r--synapse/http/federation/well_known_resolver.py10
-rw-r--r--synapse/http/proxyagent.py81
-rw-r--r--synapse/http/site.py112
-rw-r--r--synapse/logging/context.py72
-rw-r--r--synapse/logging/opentracing.py12
-rw-r--r--synapse/metrics/__init__.py16
-rw-r--r--synapse/metrics/background_process_metrics.py18
-rw-r--r--synapse/module_api/__init__.py50
-rw-r--r--synapse/notifier.py47
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/action_generator.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py2
-rw-r--r--synapse/push/emailpusher.py2
-rw-r--r--synapse/push/httppusher.py4
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/push/pusher.py2
-rw-r--r--synapse/python_dependencies.py20
-rw-r--r--synapse/replication/http/federation.py3
-rw-r--r--synapse/replication/http/register.py2
-rw-r--r--synapse/replication/http/send_event.py4
-rw-r--r--synapse/replication/slave/storage/pushers.py2
-rw-r--r--synapse/replication/tcp/commands.py6
-rw-r--r--synapse/replication/tcp/protocol.py7
-rw-r--r--synapse/replication/tcp/redis.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py17
-rw-r--r--synapse/replication/tcp/streams/federation.py14
-rw-r--r--synapse/rest/admin/media.py2
-rw-r--r--synapse/rest/admin/rooms.py3
-rw-r--r--synapse/rest/admin/users.py26
-rw-r--r--synapse/rest/client/v1/login.py14
-rw-r--r--synapse/rest/client/v1/room.py84
-rw-r--r--synapse/rest/client/v2_alpha/account.py12
-rw-r--r--synapse/rest/client/v2_alpha/capabilities.py23
-rw-r--r--synapse/rest/client/v2_alpha/groups.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py8
-rw-r--r--synapse/rest/client/v2_alpha/sync.py14
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/rest/media/v1/config_resource.py2
-rw-r--r--synapse/rest/media/v1/download_resource.py2
-rw-r--r--synapse/rest/media/v1/media_repository.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py6
-rw-r--r--synapse/rest/media/v1/storage_provider.py2
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py2
-rw-r--r--synapse/rest/media/v1/upload_resource.py2
-rw-r--r--synapse/rest/synapse/client/pick_username.py3
-rw-r--r--synapse/secrets.py8
-rw-r--r--synapse/server.py34
-rw-r--r--synapse/server_notices/consent_server_notices.py18
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py11
-rw-r--r--synapse/server_notices/server_notices_manager.py2
-rw-r--r--synapse/server_notices/server_notices_sender.py18
-rw-r--r--synapse/server_notices/worker_server_notices_sender.py11
-rw-r--r--synapse/state/__init__.py5
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/background_updates.py2
-rw-r--r--synapse/storage/database.py13
-rw-r--r--synapse/storage/databases/main/__init__.py26
-rw-r--r--synapse/storage/databases/main/appservice.py2
-rw-r--r--synapse/storage/databases/main/deviceinbox.py6
-rw-r--r--synapse/storage/databases/main/event_federation.py2
-rw-r--r--synapse/storage/databases/main/events.py19
-rw-r--r--synapse/storage/databases/main/events_worker.py9
-rw-r--r--synapse/storage/databases/main/group_server.py4
-rw-r--r--synapse/storage/databases/main/media_repository.py21
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py4
-rw-r--r--synapse/storage/databases/main/presence.py60
-rw-r--r--synapse/storage/databases/main/pusher.py2
-rw-r--r--synapse/storage/databases/main/registration.py1
-rw-r--r--synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres22
-rw-r--r--synapse/storage/databases/main/stats.py25
-rw-r--r--synapse/storage/databases/main/transactions.py45
-rw-r--r--synapse/storage/databases/state/store.py9
-rw-r--r--synapse/storage/prepare_database.py11
-rw-r--r--synapse/storage/purge_events.py2
-rw-r--r--synapse/storage/state.py6
-rw-r--r--synapse/util/async_helpers.py2
-rw-r--r--synapse/util/caches/__init__.py6
-rw-r--r--synapse/util/caches/deferred_cache.py4
-rw-r--r--synapse/util/caches/dictionary_cache.py64
-rw-r--r--synapse/util/caches/expiringcache.py83
-rw-r--r--synapse/util/caches/ttlcache.py53
-rw-r--r--synapse/util/frozenutils.py2
-rw-r--r--synapse/visibility.py78
-rwxr-xr-xtest_postgresql.sh2
-rw-r--r--tests/api/test_ratelimiting.py168
-rw-r--r--tests/config/test_load.py5
-rw-r--r--tests/crypto/test_keyring.py23
-rw-r--r--tests/events/test_presence_router.py386
-rw-r--r--tests/federation/test_federation_catch_up.py49
-rw-r--r--tests/handlers/test_oidc.py132
-rw-r--r--tests/handlers/test_presence.py20
-rw-r--r--tests/handlers/test_sync.py21
-rw-r--r--tests/http/test_proxyagent.py40
-rw-r--r--tests/logging/test_terse_json.py70
-rw-r--r--tests/module_api/test_api.py175
-rw-r--r--tests/replication/_base.py2
-rw-r--r--tests/replication/tcp/streams/test_typing.py1
-rw-r--r--tests/replication/test_multi_media_repo.py4
-rw-r--r--tests/rest/admin/test_user.py294
-rw-r--r--tests/rest/client/test_third_party_rules.py62
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py7
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py36
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py62
-rw-r--r--tests/server.py30
-rw-r--r--tests/storage/test_devices.py80
-rw-r--r--tests/storage/test_directory.py44
-rw-r--r--tests/storage/test_end_to_end_keys.py59
-rw-r--r--tests/storage/test_event_push_actions.py133
-rw-r--r--tests/storage/test_profile.py35
-rw-r--r--tests/storage/test_redaction.py12
-rw-r--r--tests/storage/test_registration.py108
-rw-r--r--tests/storage/test_room.py61
-rw-r--r--tests/storage/test_state.py167
-rw-r--r--tests/storage/test_user_directory.py86
-rw-r--r--tests/test_event_auth.py246
-rw-r--r--tests/unittest.py14
-rw-r--r--tests/util/caches/test_descriptors.py7
-rw-r--r--tests/util/test_dict_cache.py4
-rw-r--r--tests/util/test_logcontext.py35
-rw-r--r--tests/utils.py1
293 files changed, 5891 insertions, 1900 deletions
diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh

index 28e6694b5d..9fe5b696b0 100755 --- a/.buildkite/scripts/test_old_deps.sh +++ b/.buildkite/scripts/test_old_deps.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # this script is run by buildkite in a plain `xenial` container; it installs the # minimal requirements for tox and hands over to the py35-old tox environment. diff --git a/.buildkite/scripts/test_synapse_port_db.sh b/.buildkite/scripts/test_synapse_port_db.sh
index 9ed2177635..8914319e38 100755 --- a/.buildkite/scripts/test_synapse_port_db.sh +++ b/.buildkite/scripts/test_synapse_port_db.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Test script for 'synapse_port_db', which creates a virtualenv, installs Synapse along # with additional dependencies needed for the test (such as coverage or the PostgreSQL diff --git a/CHANGES.md b/CHANGES.md
index 3f0c5685ea..27483532d0 100644 --- a/CHANGES.md +++ b/CHANGES.md
@@ -1,8 +1,201 @@ -Removal warning ---------------- +Synapse 1.31.0 (2021-04-06) +=========================== + +**Note:** As announced in v1.25.0, and in line with the deprecation policy for platform dependencies, this is the last release to support Python 3.5 and PostgreSQL 9.5. Future versions of Synapse will require Python 3.6+ and PostgreSQL 9.6+, as per our [deprecation policy](docs/deprecation_policy.md). + +This is also the last release that the Synapse team will be publishing packages for Debian Stretch and Ubuntu Xenial. + + +Improved Documentation +---------------------- + +- Add a document describing the deprecation policy for platform dependencies. ([\#9723](https://github.com/matrix-org/synapse/issues/9723)) + + +Internal Changes +---------------- + +- Revert using `dmypy run` in lint script. ([\#9720](https://github.com/matrix-org/synapse/issues/9720)) +- Pin flake8-bugbear's version. ([\#9734](https://github.com/matrix-org/synapse/issues/9734)) + + +Synapse 1.31.0rc1 (2021-03-30) +============================== + +Features +-------- + +- Add support to OpenID Connect login for requiring attributes on the `userinfo` response. Contributed by Hubbe King. ([\#9609](https://github.com/matrix-org/synapse/issues/9609)) +- Add initial experimental support for a "space summary" API. ([\#9643](https://github.com/matrix-org/synapse/issues/9643), [\#9652](https://github.com/matrix-org/synapse/issues/9652), [\#9653](https://github.com/matrix-org/synapse/issues/9653)) +- Add support for the busy presence state as described in [MSC3026](https://github.com/matrix-org/matrix-doc/pull/3026). ([\#9644](https://github.com/matrix-org/synapse/issues/9644)) +- Add support for credentials for proxy authentication in the `HTTPS_PROXY` environment variable. ([\#9657](https://github.com/matrix-org/synapse/issues/9657)) + + +Bugfixes +-------- + +- Fix a longstanding bug that could cause issues when editing a reply to a message. ([\#9585](https://github.com/matrix-org/synapse/issues/9585)) +- Fix the `/capabilities` endpoint to return `m.change_password` as disabled if the local password database is not used for authentication. Contributed by @dklimpel. ([\#9588](https://github.com/matrix-org/synapse/issues/9588)) +- Check if local passwords are enabled before setting them for the user. ([\#9636](https://github.com/matrix-org/synapse/issues/9636)) +- Fix a bug where federation sending can stall due to `concurrent access` database exceptions when it falls behind. ([\#9639](https://github.com/matrix-org/synapse/issues/9639)) +- Fix a bug introduced in Synapse 1.30.1 which meant the suggested `pip` incantation to install an updated `cryptography` was incorrect. ([\#9699](https://github.com/matrix-org/synapse/issues/9699)) + + +Updates to the Docker image +--------------------------- + +- Speed up Docker builds and make it nicer to test against Complement while developing (install all dependencies before copying the project). ([\#9610](https://github.com/matrix-org/synapse/issues/9610)) +- Include [opencontainers labels](https://github.com/opencontainers/image-spec/blob/master/annotations.md#pre-defined-annotation-keys) in the Docker image. ([\#9612](https://github.com/matrix-org/synapse/issues/9612)) + + +Improved Documentation +---------------------- + +- Clarify that `register_new_matrix_user` is present also when installed via non-pip package. ([\#9074](https://github.com/matrix-org/synapse/issues/9074)) +- Update source install documentation to mention platform prerequisites before the source install steps. ([\#9667](https://github.com/matrix-org/synapse/issues/9667)) +- Improve worker documentation for fallback/web auth endpoints. ([\#9679](https://github.com/matrix-org/synapse/issues/9679)) +- Update the sample configuration for OIDC authentication. ([\#9695](https://github.com/matrix-org/synapse/issues/9695)) + + +Internal Changes +---------------- + +- Preparatory steps for removing redundant `outlier` data from `event_json.internal_metadata` column. ([\#9411](https://github.com/matrix-org/synapse/issues/9411)) +- Add type hints to the caching module. ([\#9442](https://github.com/matrix-org/synapse/issues/9442)) +- Introduce flake8-bugbear to the test suite and fix some of its lint violations. ([\#9499](https://github.com/matrix-org/synapse/issues/9499), [\#9659](https://github.com/matrix-org/synapse/issues/9659)) +- Add additional type hints to the Homeserver object. ([\#9631](https://github.com/matrix-org/synapse/issues/9631), [\#9638](https://github.com/matrix-org/synapse/issues/9638), [\#9675](https://github.com/matrix-org/synapse/issues/9675), [\#9681](https://github.com/matrix-org/synapse/issues/9681)) +- Only save remote cross-signing and device keys if they're different from the current ones. ([\#9634](https://github.com/matrix-org/synapse/issues/9634)) +- Rename storage function to fix spelling and not conflict with another function's name. ([\#9637](https://github.com/matrix-org/synapse/issues/9637)) +- Improve performance of federation catch up by sending the latest events in the room to the remote, rather than just the last event sent by the local server. ([\#9640](https://github.com/matrix-org/synapse/issues/9640), [\#9664](https://github.com/matrix-org/synapse/issues/9664)) +- In the `federation_client` commandline client, stop automatically adding the URL prefix, so that servlets on other prefixes can be tested. ([\#9645](https://github.com/matrix-org/synapse/issues/9645)) +- In the `federation_client` commandline client, handle inline `signing_key`s in `homeserver.yaml`. ([\#9647](https://github.com/matrix-org/synapse/issues/9647)) +- Fixed some antipattern issues to improve code quality. ([\#9649](https://github.com/matrix-org/synapse/issues/9649)) +- Add a storage method for pulling all current user presence state from the database. ([\#9650](https://github.com/matrix-org/synapse/issues/9650)) +- Import `HomeServer` from the proper module. ([\#9665](https://github.com/matrix-org/synapse/issues/9665)) +- Increase default join ratelimiting burst rate. ([\#9674](https://github.com/matrix-org/synapse/issues/9674)) +- Add type hints to third party event rules and visibility modules. ([\#9676](https://github.com/matrix-org/synapse/issues/9676)) +- Bump mypy-zope to 0.2.13 to fix "Cannot determine consistent method resolution order (MRO)" errors when running mypy a second time. ([\#9678](https://github.com/matrix-org/synapse/issues/9678)) +- Use interpreter from `$PATH` via `/usr/bin/env` instead of absolute paths in various scripts. ([\#9689](https://github.com/matrix-org/synapse/issues/9689)) +- Make it possible to use `dmypy`. ([\#9692](https://github.com/matrix-org/synapse/issues/9692)) +- Suppress "CryptographyDeprecationWarning: int_from_bytes is deprecated". ([\#9698](https://github.com/matrix-org/synapse/issues/9698)) +- Use `dmypy run` in lint script for improved performance in type-checking while developing. ([\#9701](https://github.com/matrix-org/synapse/issues/9701)) +- Fix undetected mypy error when using Python 3.6. ([\#9703](https://github.com/matrix-org/synapse/issues/9703)) +- Fix type-checking CI on develop. ([\#9709](https://github.com/matrix-org/synapse/issues/9709)) + + +Synapse 1.30.1 (2021-03-26) +=========================== + +This release is identical to Synapse 1.30.0, with the exception of explicitly +setting a minimum version of Python's Cryptography library to ensure that users +of Synapse are protected from the recent [OpenSSL security advisories](https://mta.openssl.org/pipermail/openssl-announce/2021-March/000198.html), +especially CVE-2021-3449. + +Note that Cryptography defaults to bundling its own statically linked copy of +OpenSSL, which means that you may not be protected by your operating system's +security updates. + +It's also worth noting that Cryptography no longer supports Python 3.5, so +admins deploying to older environments may not be protected against this or +future vulnerabilities. Synapse will be dropping support for Python 3.5 at the +end of March. + + +Updates to the Docker image +--------------------------- + +- Ensure that the docker container has up to date versions of openssl. ([\#9697](https://github.com/matrix-org/synapse/issues/9697)) + + +Internal Changes +---------------- + +- Enforce that `cryptography` dependency is up to date to ensure it has the most recent openssl patches. ([\#9697](https://github.com/matrix-org/synapse/issues/9697)) + + +Synapse 1.30.0 (2021-03-22) +=========================== + +Note that this release deprecates the ability for appservices to +call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice +developers should use a `type` value of `m.login.application_service` as +per [the spec](https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions). +In future releases, calling this endpoint with an access token - but without a `m.login.application_service` +type - will fail. + + +No significant changes. + + +Synapse 1.30.0rc1 (2021-03-16) +============================== + +Features +-------- + +- Add prometheus metrics for number of users successfully registering and logging in. ([\#9510](https://github.com/matrix-org/synapse/issues/9510), [\#9511](https://github.com/matrix-org/synapse/issues/9511), [\#9573](https://github.com/matrix-org/synapse/issues/9573)) +- Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers. ([\#9540](https://github.com/matrix-org/synapse/issues/9540)) +- Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. ([\#9549](https://github.com/matrix-org/synapse/issues/9549)) +- Optimise handling of incomplete room history for incoming federation. ([\#9601](https://github.com/matrix-org/synapse/issues/9601)) +- Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9617](https://github.com/matrix-org/synapse/issues/9617)) +- Tell spam checker modules about the SSO IdP a user registered through if one was used. ([\#9626](https://github.com/matrix-org/synapse/issues/9626)) + + +Bugfixes +-------- + +- Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. ([\#9473](https://github.com/matrix-org/synapse/issues/9473)) +- Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. ([\#9542](https://github.com/matrix-org/synapse/issues/9542), [\#9583](https://github.com/matrix-org/synapse/issues/9583)) +- Fix bug where federation requests were not correctly retried on 5xx responses. ([\#9567](https://github.com/matrix-org/synapse/issues/9567)) +- Fix re-activating an account via the admin API when local passwords are disabled. ([\#9587](https://github.com/matrix-org/synapse/issues/9587)) +- Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. ([\#9597](https://github.com/matrix-org/synapse/issues/9597)) +- Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. ([\#9620](https://github.com/matrix-org/synapse/issues/9620)) +- Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. ([\#9623](https://github.com/matrix-org/synapse/issues/9623)) + + +Updates to the Docker image +--------------------------- + +- Make use of an improved malloc implementation (`jemalloc`) in the docker image. ([\#8553](https://github.com/matrix-org/synapse/issues/8553)) + + +Improved Documentation +---------------------- + +- Add relayd entry to reverse proxy example configurations. ([\#9508](https://github.com/matrix-org/synapse/issues/9508)) +- Improve the SAML2 upgrade notes for 1.27.0. ([\#9550](https://github.com/matrix-org/synapse/issues/9550)) +- Link to the "List user's media" admin API from the media admin API docs. ([\#9571](https://github.com/matrix-org/synapse/issues/9571)) +- Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. ([\#9580](https://github.com/matrix-org/synapse/issues/9580)) +- Clarify the sample configuration for `stats` settings. ([\#9604](https://github.com/matrix-org/synapse/issues/9604)) + + +Deprecations and Removals +------------------------- + +- The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. ([\#9540](https://github.com/matrix-org/synapse/issues/9540)) +- Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. ([\#9559](https://github.com/matrix-org/synapse/issues/9559)) + + +Internal Changes +---------------- + +- Add tests to ResponseCache. ([\#9458](https://github.com/matrix-org/synapse/issues/9458)) +- Add type hints to purge room and server notice admin API. ([\#9520](https://github.com/matrix-org/synapse/issues/9520)) +- Add extra logging to ObservableDeferred when callbacks throw exceptions. ([\#9523](https://github.com/matrix-org/synapse/issues/9523)) +- Fix incorrect type hints. ([\#9528](https://github.com/matrix-org/synapse/issues/9528), [\#9543](https://github.com/matrix-org/synapse/issues/9543), [\#9591](https://github.com/matrix-org/synapse/issues/9591), [\#9608](https://github.com/matrix-org/synapse/issues/9608), [\#9618](https://github.com/matrix-org/synapse/issues/9618)) +- Add an additional test for purging a room. ([\#9541](https://github.com/matrix-org/synapse/issues/9541)) +- Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. ([\#9560](https://github.com/matrix-org/synapse/issues/9560)) +- Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. ([\#9561](https://github.com/matrix-org/synapse/issues/9561)) +- Fix spurious errors reported by the `config-lint.sh` script. ([\#9562](https://github.com/matrix-org/synapse/issues/9562)) +- Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. ([\#9563](https://github.com/matrix-org/synapse/issues/9563)) +- Do not have mypy ignore type hints from unpaddedbase64. ([\#9568](https://github.com/matrix-org/synapse/issues/9568)) +- Improve efficiency of calculating the auth chain in large rooms. ([\#9576](https://github.com/matrix-org/synapse/issues/9576)) +- Convert `synapse.types.Requester` to an `attrs` class. ([\#9586](https://github.com/matrix-org/synapse/issues/9586)) +- Add logging for redis connection setup. ([\#9590](https://github.com/matrix-org/synapse/issues/9590)) +- Improve logging when processing incoming transactions. ([\#9596](https://github.com/matrix-org/synapse/issues/9596)) +- Remove unused `stats.retention` setting, and emit a warning if stats are disabled. ([\#9604](https://github.com/matrix-org/synapse/issues/9604)) +- Prevent attempting to bundle aggregations for state events in /context APIs. ([\#9619](https://github.com/matrix-org/synapse/issues/9619)) -Note that this release deprecates the ability for appservices to call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice developers should use a `type` value of `m.login.application_service` as per the spec. In future releases, calling this endpoint with an access token but -without a valid type will fail. Synapse 1.29.0 (2021-03-08) =========================== diff --git a/INSTALL.md b/INSTALL.md
index b9e3f613d1..7b40689234 100644 --- a/INSTALL.md +++ b/INSTALL.md
@@ -6,7 +6,7 @@ There are 3 steps to follow under **Installation Instructions**. - [Choosing your server name](#choosing-your-server-name) - [Installing Synapse](#installing-synapse) - [Installing from source](#installing-from-source) - - [Platform-Specific Instructions](#platform-specific-instructions) + - [Platform-specific prerequisites](#platform-specific-prerequisites) - [Debian/Ubuntu/Raspbian](#debianubunturaspbian) - [ArchLinux](#archlinux) - [CentOS/Fedora](#centosfedora) @@ -38,6 +38,7 @@ There are 3 steps to follow under **Installation Instructions**. - [URL previews](#url-previews) - [Troubleshooting Installation](#troubleshooting-installation) + ## Choosing your server name It is important to choose the name for your server before you install Synapse, @@ -60,17 +61,14 @@ that your email address is probably `user@example.com` rather than (Prebuilt packages are available for some platforms - see [Prebuilt packages](#prebuilt-packages).) +When installing from source please make sure that the [Platform-specific prerequisites](#platform-specific-prerequisites) are already installed. + System requirements: - POSIX-compliant system (tested on Linux & OS X) - Python 3.5.2 or later, up to Python 3.9. - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org -Synapse is written in Python but some of the libraries it uses are written in -C. So before we can install Synapse itself we need a working C compiler and the -header files for Python C extensions. See [Platform-Specific -Instructions](#platform-specific-instructions) for information on installing -these on various platforms. To install the Synapse homeserver run: @@ -128,7 +126,11 @@ source env/bin/activate synctl start ``` -#### Platform-Specific Instructions +#### Platform-specific prerequisites + +Synapse is written in Python but some of the libraries it uses are written in +C. So before we can install Synapse itself we need a working C compiler and the +header files for Python C extensions. ##### Debian/Ubuntu/Raspbian @@ -526,14 +528,24 @@ email will be disabled. The easiest way to create a new user is to do so from a client like [Element](https://element.io/). -Alternatively you can do so from the command line if you have installed via pip. - -This can be done as follows: - -```sh -$ source ~/synapse/env/bin/activate -$ synctl start # if not already running -$ register_new_matrix_user -c homeserver.yaml http://localhost:8008 +Alternatively, you can do so from the command line. This can be done as follows: + + 1. If synapse was installed via pip, activate the virtualenv as follows (if Synapse was + installed via a prebuilt package, `register_new_matrix_user` should already be + on the search path): + ```sh + cd ~/synapse + source env/bin/activate + synctl start # if not already running + ``` + 2. Run the following command: + ```sh + register_new_matrix_user -c homeserver.yaml http://localhost:8008 + ``` + +This will prompt you to add details for the new user, and will then connect to +the running Synapse to create the new user. For example: +``` New user localpart: erikj Password: Confirm password: diff --git a/README.rst b/README.rst
index 6a1e713590..1a5503572e 100644 --- a/README.rst +++ b/README.rst
@@ -314,6 +314,15 @@ Testing with SyTest is recommended for verifying that changes related to the Client-Server API are functioning correctly. See the `installation instructions <https://github.com/matrix-org/sytest#installing>`_ for details. + +Platform dependencies +===================== + +Synapse uses a number of platform dependencies such as Python and PostgreSQL, +and aims to follow supported upstream versions. See the +`<docs/deprecation_policy.md>`_ document for more details. + + Troubleshooting =============== @@ -384,12 +393,17 @@ massive excess of outgoing federation requests (see `discussion indicate that your server is also issuing far more outgoing federation requests than can be accounted for by your users' activity, this is a likely cause. The misbehavior can be worked around by setting -``use_presence: false`` in the Synapse config file. +the following in the Synapse config file: + +.. code-block:: yaml + + presence: + enabled: false People can't accept room invitations from me -------------------------------------------- -The typical failure mode here is that you send an invitation to someone +The typical failure mode here is that you send an invitation to someone to join a room or direct chat, but when they go to accept it, they get an error (typically along the lines of "Invalid signature"). They might see something like the following in their logs:: diff --git a/UPGRADE.rst b/UPGRADE.rst
index 8bc2ff91ab..ba488e1041 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst
@@ -98,9 +98,12 @@ will log a warning on each received request. To avoid the warning, administrators using a reverse proxy should ensure that the reverse proxy sets `X-Forwarded-Proto` header to `https` or `http` to -indicate the protocol used by the client. See the `reverse proxy documentation -<docs/reverse_proxy.md>`_, where the example configurations have been updated to -show how to set this header. +indicate the protocol used by the client. + +Synapse also requires the `Host` header to be preserved. + +See the `reverse proxy documentation <docs/reverse_proxy.md>`_, where the +example configurations have been updated to show how to set these headers. (Users of `Caddy <https://caddyserver.com/>`_ are unaffected, since we believe it sets `X-Forwarded-Proto` by default.) diff --git a/changelog.d/8553.docker b/changelog.d/8553.docker deleted file mode 100644
index f99c4207b8..0000000000 --- a/changelog.d/8553.docker +++ /dev/null
@@ -1 +0,0 @@ -Use jemalloc if available in docker. diff --git a/changelog.d/8926.bugfix b/changelog.d/8926.bugfix new file mode 100644
index 0000000000..aad7bd83ce --- /dev/null +++ b/changelog.d/8926.bugfix
@@ -0,0 +1 @@ +Prevent `synapse_forward_extremities` and `synapse_excess_extremity_events` Prometheus metrics from initially reporting zero-values after startup. diff --git a/changelog.d/9458.misc b/changelog.d/9458.misc deleted file mode 100644
index 8ceeed1352..0000000000 --- a/changelog.d/9458.misc +++ /dev/null
@@ -1 +0,0 @@ -Add tests to ResponseCache. \ No newline at end of file diff --git a/changelog.d/9473.bugfix b/changelog.d/9473.bugfix deleted file mode 100644
index 71fb487cf2..0000000000 --- a/changelog.d/9473.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. diff --git a/changelog.d/9491.feature b/changelog.d/9491.feature new file mode 100644
index 0000000000..8b56a95a44 --- /dev/null +++ b/changelog.d/9491.feature
@@ -0,0 +1 @@ +Add a Synapse module for routing presence updates between users. diff --git a/changelog.d/9508.doc b/changelog.d/9508.doc deleted file mode 100644
index a17a8faecf..0000000000 --- a/changelog.d/9508.doc +++ /dev/null
@@ -1 +0,0 @@ -Add relayd entry to reverse proxy example configurations. diff --git a/changelog.d/9510.feature b/changelog.d/9510.feature deleted file mode 100644
index 5214b50d41..0000000000 --- a/changelog.d/9510.feature +++ /dev/null
@@ -1 +0,0 @@ -Add prometheus metrics for number of users successfully registering and logging in. diff --git a/changelog.d/9511.feature b/changelog.d/9511.feature deleted file mode 100644
index 5214b50d41..0000000000 --- a/changelog.d/9511.feature +++ /dev/null
@@ -1 +0,0 @@ -Add prometheus metrics for number of users successfully registering and logging in. diff --git a/changelog.d/9520.misc b/changelog.d/9520.misc deleted file mode 100644
index 825ba5bbc1..0000000000 --- a/changelog.d/9520.misc +++ /dev/null
@@ -1 +0,0 @@ -Add type hints to purge room and server notice admin API. \ No newline at end of file diff --git a/changelog.d/9523.misc b/changelog.d/9523.misc deleted file mode 100644
index f03e939efb..0000000000 --- a/changelog.d/9523.misc +++ /dev/null
@@ -1 +0,0 @@ -Add extra logging to ObservableDeferred when callbacks throw exceptions. \ No newline at end of file diff --git a/changelog.d/9528.misc b/changelog.d/9528.misc deleted file mode 100644
index 14c7b78dd9..0000000000 --- a/changelog.d/9528.misc +++ /dev/null
@@ -1 +0,0 @@ -Fix incorrect type hints. diff --git a/changelog.d/9540.feature b/changelog.d/9540.feature deleted file mode 100644
index 5417e51b93..0000000000 --- a/changelog.d/9540.feature +++ /dev/null
@@ -1 +0,0 @@ -Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers. diff --git a/changelog.d/9540.removal b/changelog.d/9540.removal deleted file mode 100644
index d54f553cb9..0000000000 --- a/changelog.d/9540.removal +++ /dev/null
@@ -1 +0,0 @@ -The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. diff --git a/changelog.d/9541.misc b/changelog.d/9541.misc deleted file mode 100644
index a82bef3431..0000000000 --- a/changelog.d/9541.misc +++ /dev/null
@@ -1 +0,0 @@ -Add an additional test for purging a room. diff --git a/changelog.d/9542.bugfix b/changelog.d/9542.bugfix deleted file mode 100644
index 51b1876f3b..0000000000 --- a/changelog.d/9542.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. diff --git a/changelog.d/9543.misc b/changelog.d/9543.misc deleted file mode 100644
index 14c7b78dd9..0000000000 --- a/changelog.d/9543.misc +++ /dev/null
@@ -1 +0,0 @@ -Fix incorrect type hints. diff --git a/changelog.d/9549.feature b/changelog.d/9549.feature deleted file mode 100644
index 709e61eced..0000000000 --- a/changelog.d/9549.feature +++ /dev/null
@@ -1 +0,0 @@ -Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. diff --git a/changelog.d/9550.doc b/changelog.d/9550.doc deleted file mode 100644
index adbbeb0ae4..0000000000 --- a/changelog.d/9550.doc +++ /dev/null
@@ -1 +0,0 @@ -Improve the SAML2 upgrade notes for 1.27.0. diff --git a/changelog.d/9559.removal b/changelog.d/9559.removal deleted file mode 100644
index f97bf56dc0..0000000000 --- a/changelog.d/9559.removal +++ /dev/null
@@ -1 +0,0 @@ -Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. diff --git a/changelog.d/9560.misc b/changelog.d/9560.misc deleted file mode 100644
index 57a698f846..0000000000 --- a/changelog.d/9560.misc +++ /dev/null
@@ -1 +0,0 @@ -Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. diff --git a/changelog.d/9561.misc b/changelog.d/9561.misc deleted file mode 100644
index 6c529a82ee..0000000000 --- a/changelog.d/9561.misc +++ /dev/null
@@ -1 +0,0 @@ -Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. diff --git a/changelog.d/9562.misc b/changelog.d/9562.misc deleted file mode 100644
index 2f0133bff0..0000000000 --- a/changelog.d/9562.misc +++ /dev/null
@@ -1 +0,0 @@ -Fix spurious errors reported by the `config-lint.sh` script. \ No newline at end of file diff --git a/changelog.d/9563.misc b/changelog.d/9563.misc deleted file mode 100644
index 7a3493e4a1..0000000000 --- a/changelog.d/9563.misc +++ /dev/null
@@ -1 +0,0 @@ -Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. diff --git a/changelog.d/9567.bugfix b/changelog.d/9567.bugfix deleted file mode 100644
index e7322c2b5e..0000000000 --- a/changelog.d/9567.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Fix bug where federation requests were not correctly retried on 5xx responses. diff --git a/changelog.d/9568.misc b/changelog.d/9568.misc deleted file mode 100644
index 561963de93..0000000000 --- a/changelog.d/9568.misc +++ /dev/null
@@ -1 +0,0 @@ -Do not have mypy ignore type hints from unpaddedbase64. diff --git a/changelog.d/9571.doc b/changelog.d/9571.doc deleted file mode 100644
index 1bba72e7d0..0000000000 --- a/changelog.d/9571.doc +++ /dev/null
@@ -1 +0,0 @@ -Link to the "List user's media" admin API from the media admin API docs. diff --git a/changelog.d/9573.feature b/changelog.d/9573.feature deleted file mode 100644
index 5214b50d41..0000000000 --- a/changelog.d/9573.feature +++ /dev/null
@@ -1 +0,0 @@ -Add prometheus metrics for number of users successfully registering and logging in. diff --git a/changelog.d/9576.misc b/changelog.d/9576.misc deleted file mode 100644
index bc257d05b7..0000000000 --- a/changelog.d/9576.misc +++ /dev/null
@@ -1 +0,0 @@ -Improve efficiency of calculating the auth chain in large rooms. diff --git a/changelog.d/9580.doc b/changelog.d/9580.doc deleted file mode 100644
index f9c8b328b3..0000000000 --- a/changelog.d/9580.doc +++ /dev/null
@@ -1 +0,0 @@ -Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. diff --git a/changelog.d/9583.bugfix b/changelog.d/9583.bugfix deleted file mode 100644
index 51b1876f3b..0000000000 --- a/changelog.d/9583.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. diff --git a/changelog.d/9586.misc b/changelog.d/9586.misc deleted file mode 100644
index 2def9d5f55..0000000000 --- a/changelog.d/9586.misc +++ /dev/null
@@ -1 +0,0 @@ -Convert `synapse.types.Requester` to an `attrs` class. diff --git a/changelog.d/9587.bugfix b/changelog.d/9587.bugfix deleted file mode 100644
index d8f04c4f21..0000000000 --- a/changelog.d/9587.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Re-Activating account with admin API when local passwords are disabled. \ No newline at end of file diff --git a/changelog.d/9590.misc b/changelog.d/9590.misc deleted file mode 100644
index 186396c45b..0000000000 --- a/changelog.d/9590.misc +++ /dev/null
@@ -1 +0,0 @@ -Add logging for redis connection setup. diff --git a/changelog.d/9591.misc b/changelog.d/9591.misc deleted file mode 100644
index 14c7b78dd9..0000000000 --- a/changelog.d/9591.misc +++ /dev/null
@@ -1 +0,0 @@ -Fix incorrect type hints. diff --git a/changelog.d/9596.misc b/changelog.d/9596.misc deleted file mode 100644
index fc19a95f75..0000000000 --- a/changelog.d/9596.misc +++ /dev/null
@@ -1 +0,0 @@ -Improve logging when processing incoming transactions. diff --git a/changelog.d/9597.bugfix b/changelog.d/9597.bugfix deleted file mode 100644
index 349dc9d664..0000000000 --- a/changelog.d/9597.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. diff --git a/changelog.d/9601.feature b/changelog.d/9601.feature deleted file mode 100644
index 5078d63ffa..0000000000 --- a/changelog.d/9601.feature +++ /dev/null
@@ -1 +0,0 @@ -Optimise handling of incomplete room history for incoming federation. diff --git a/changelog.d/9604.doc b/changelog.d/9604.doc deleted file mode 100644
index d413e38b72..0000000000 --- a/changelog.d/9604.doc +++ /dev/null
@@ -1 +0,0 @@ -Clarify the sample configuration for `stats` settings. diff --git a/changelog.d/9604.misc b/changelog.d/9604.misc deleted file mode 100644
index 0583988588..0000000000 --- a/changelog.d/9604.misc +++ /dev/null
@@ -1 +0,0 @@ -Remove unused `stats.retention` setting, and emit a warning if stats are disabled. diff --git a/changelog.d/9608.misc b/changelog.d/9608.misc deleted file mode 100644
index 14c7b78dd9..0000000000 --- a/changelog.d/9608.misc +++ /dev/null
@@ -1 +0,0 @@ -Fix incorrect type hints. diff --git a/changelog.d/9617.feature b/changelog.d/9617.feature deleted file mode 100644
index b462a32b92..0000000000 --- a/changelog.d/9617.feature +++ /dev/null
@@ -1 +0,0 @@ -Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). diff --git a/changelog.d/9618.misc b/changelog.d/9618.misc deleted file mode 100644
index 14c7b78dd9..0000000000 --- a/changelog.d/9618.misc +++ /dev/null
@@ -1 +0,0 @@ -Fix incorrect type hints. diff --git a/changelog.d/9619.misc b/changelog.d/9619.misc deleted file mode 100644
index 50267bfbc4..0000000000 --- a/changelog.d/9619.misc +++ /dev/null
@@ -1 +0,0 @@ -Prevent attempting to bundle aggregations for state events in /context APIs. \ No newline at end of file diff --git a/changelog.d/9620.bugfix b/changelog.d/9620.bugfix deleted file mode 100644
index 427580f4ad..0000000000 --- a/changelog.d/9620.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. diff --git a/changelog.d/9623.bugfix b/changelog.d/9623.bugfix deleted file mode 100644
index ecccb46105..0000000000 --- a/changelog.d/9623.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. diff --git a/changelog.d/9626.feature b/changelog.d/9626.feature deleted file mode 100644
index eacba6201b..0000000000 --- a/changelog.d/9626.feature +++ /dev/null
@@ -1 +0,0 @@ -Tell spam checker modules about the SSO IdP a user registered through if one was used. diff --git a/changelog.d/9654.feature b/changelog.d/9654.feature new file mode 100644
index 0000000000..a54c96cf19 --- /dev/null +++ b/changelog.d/9654.feature
@@ -0,0 +1 @@ +Include request information in structured logging output. diff --git a/changelog.d/9685.misc b/changelog.d/9685.misc new file mode 100644
index 0000000000..0506d8af0c --- /dev/null +++ b/changelog.d/9685.misc
@@ -0,0 +1 @@ +Update `scripts-dev/complement.sh` to use a local checkout of Complement, allow running a subset of tests and have it use Synapse's Complement test blacklist. \ No newline at end of file diff --git a/changelog.d/9686.misc b/changelog.d/9686.misc new file mode 100644
index 0000000000..bb2335acf9 --- /dev/null +++ b/changelog.d/9686.misc
@@ -0,0 +1 @@ +Improve Jaeger tracing for `to_device` messages. diff --git a/changelog.d/9691.feature b/changelog.d/9691.feature new file mode 100644
index 0000000000..3c711db4f5 --- /dev/null +++ b/changelog.d/9691.feature
@@ -0,0 +1 @@ +Add `order_by` to the admin API `GET /_synapse/admin/v2/users`. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/9700.feature b/changelog.d/9700.feature new file mode 100644
index 0000000000..037de8367f --- /dev/null +++ b/changelog.d/9700.feature
@@ -0,0 +1 @@ +Replace the `room_invite_state_types` configuration setting with `room_prejoin_state`. diff --git a/changelog.d/9710.feature b/changelog.d/9710.feature new file mode 100644
index 0000000000..fce308cc41 --- /dev/null +++ b/changelog.d/9710.feature
@@ -0,0 +1 @@ +Experimental Spaces support: include `m.room.create` in the room state sent with room-invites. diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix new file mode 100644
index 0000000000..4ca3438d46 --- /dev/null +++ b/changelog.d/9711.bugfix
@@ -0,0 +1 @@ +Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. diff --git a/changelog.d/9717.feature b/changelog.d/9717.feature new file mode 100644
index 0000000000..c2c74f13d5 --- /dev/null +++ b/changelog.d/9717.feature
@@ -0,0 +1 @@ +Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. diff --git a/changelog.d/9718.removal b/changelog.d/9718.removal new file mode 100644
index 0000000000..6de7814217 --- /dev/null +++ b/changelog.d/9718.removal
@@ -0,0 +1 @@ +Replace deprecated `imp` module with successor `importlib`. Contributed by Cristina Muñoz. diff --git a/changelog.d/9719.doc b/changelog.d/9719.doc new file mode 100644
index 0000000000..f018606dd6 --- /dev/null +++ b/changelog.d/9719.doc
@@ -0,0 +1 @@ +Make the allowed_local_3pids regex example in the sample config stricter. diff --git a/changelog.d/9725.bugfix b/changelog.d/9725.bugfix new file mode 100644
index 0000000000..71283685c8 --- /dev/null +++ b/changelog.d/9725.bugfix
@@ -0,0 +1 @@ +Fix longstanding bug which caused `duplicate key value violates unique constraint "remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"` errors. diff --git a/changelog.d/9730.misc b/changelog.d/9730.misc new file mode 100644
index 0000000000..8063059b0b --- /dev/null +++ b/changelog.d/9730.misc
@@ -0,0 +1 @@ +Add type hints to expiring cache. diff --git a/changelog.d/9736.misc b/changelog.d/9736.misc new file mode 100644
index 0000000000..1e445e4344 --- /dev/null +++ b/changelog.d/9736.misc
@@ -0,0 +1 @@ +Convert various testcases to `HomeserverTestCase`. diff --git a/changelog.d/9742.misc b/changelog.d/9742.misc new file mode 100644
index 0000000000..681ab04df8 --- /dev/null +++ b/changelog.d/9742.misc
@@ -0,0 +1 @@ +Start linting mypy with `no_implicit_optional`. \ No newline at end of file diff --git a/changelog.d/9743.misc b/changelog.d/9743.misc new file mode 100644
index 0000000000..c2f75c1df9 --- /dev/null +++ b/changelog.d/9743.misc
@@ -0,0 +1 @@ +Add missing type hints to federation handler and server. diff --git a/changelog.d/9753.misc b/changelog.d/9753.misc new file mode 100644
index 0000000000..31184fe0bd --- /dev/null +++ b/changelog.d/9753.misc
@@ -0,0 +1 @@ +Check that a `ConfigError` is raised, rather than simply `Exception`, when appropriate in homeserver config file generation tests. \ No newline at end of file diff --git a/changelog.d/9770.bugfix b/changelog.d/9770.bugfix new file mode 100644
index 0000000000..baf93138de --- /dev/null +++ b/changelog.d/9770.bugfix
@@ -0,0 +1 @@ +Fix bug where sharded federation senders could get stuck repeatedly querying the DB in a loop, using lots of CPU. diff --git a/contrib/purge_api/purge_history.sh b/contrib/purge_api/purge_history.sh
index e7dd5d6468..c45136ff53 100644 --- a/contrib/purge_api/purge_history.sh +++ b/contrib/purge_api/purge_history.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # this script will use the api: # https://github.com/matrix-org/synapse/blob/master/docs/admin_api/purge_history_api.rst diff --git a/contrib/purge_api/purge_remote_media.sh b/contrib/purge_api/purge_remote_media.sh
index 77220d3bd5..4930d9529c 100644 --- a/contrib/purge_api/purge_remote_media.sh +++ b/contrib/purge_api/purge_remote_media.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash DOMAIN=yourserver.tld # add this user as admin in your home server: diff --git a/debian/changelog b/debian/changelog
index 3feefd8a1c..09602ff54b 100644 --- a/debian/changelog +++ b/debian/changelog
@@ -1,3 +1,21 @@ +matrix-synapse-py3 (1.31.0) stable; urgency=medium + + * New synapse release 1.31.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 06 Apr 2021 13:08:29 +0100 + +matrix-synapse-py3 (1.30.1) stable; urgency=medium + + * New synapse release 1.30.1. + + -- Synapse Packaging team <packages@matrix.org> Fri, 26 Mar 2021 12:01:28 +0000 + +matrix-synapse-py3 (1.30.0) stable; urgency=medium + + * New synapse release 1.30.0. + + -- Synapse Packaging team <packages@matrix.org> Mon, 22 Mar 2021 13:15:34 +0000 + matrix-synapse-py3 (1.29.0) stable; urgency=medium [ Jonathan de Jong ] diff --git a/demo/clean.sh b/demo/clean.sh
index 418ca9457e..6b809f6e83 100755 --- a/demo/clean.sh +++ b/demo/clean.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e diff --git a/demo/start.sh b/demo/start.sh
index f6b5ea137f..621a5698b8 100755 --- a/demo/start.sh +++ b/demo/start.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash DIR="$( cd "$( dirname "$0" )" && pwd )" diff --git a/demo/stop.sh b/demo/stop.sh
index 85a1d2c161..f9dddc5914 100755 --- a/demo/stop.sh +++ b/demo/stop.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash DIR="$( cd "$( dirname "$0" )" && pwd )" diff --git a/docker/Dockerfile b/docker/Dockerfile
index def4501541..5b7bf02776 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile
@@ -18,6 +18,11 @@ ARG PYTHON_VERSION=3.8 ### FROM docker.io/python:${PYTHON_VERSION}-slim as builder +LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse' +LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md' +LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git' +LABEL org.opencontainers.image.licenses='Apache-2.0' + # install the OS build deps RUN apt-get update && apt-get install -y \ build-essential \ @@ -28,33 +33,32 @@ RUN apt-get update && apt-get install -y \ libwebp-dev \ libxml++2.6-dev \ libxslt1-dev \ + openssl \ rustc \ zlib1g-dev \ - && rm -rf /var/lib/apt/lists/* + && rm -rf /var/lib/apt/lists/* -# Build dependencies that are not available as wheels, to speed up rebuilds -RUN pip install --prefix="/install" --no-warn-script-location \ - cryptography \ - frozendict \ - jaeger-client \ - opentracing \ - # Match the version constraints of Synapse - "prometheus_client>=0.4.0" \ - psycopg2 \ - pycparser \ - pyrsistent \ - pyyaml \ - simplejson \ - threadloop \ - thrift - -# now install synapse and all of the python deps to /install. -COPY synapse /synapse/synapse/ +# Copy just what we need to pip install COPY scripts /synapse/scripts/ COPY MANIFEST.in README.rst setup.py synctl /synapse/ +COPY synapse/__init__.py /synapse/synapse/__init__.py +COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py +# To speed up rebuilds, install all of the dependencies before we copy over +# the whole synapse project so that we this layer in the Docker cache can be +# used while you develop on the source +# +# This is aiming at installing the `install_requires` and `extras_require` from `setup.py` RUN pip install --prefix="/install" --no-warn-script-location \ - /synapse[all] + /synapse[all] + +# Copy over the rest of the project +COPY synapse /synapse/synapse/ + +# Install the synapse package itself and all of its children packages. +# +# This is aiming at installing only the `packages=find_packages(...)` from `setup.py +RUN pip install --prefix="/install" --no-deps --no-warn-script-location /synapse ### ### Stage 1: runtime @@ -70,7 +74,9 @@ RUN apt-get update && apt-get install -y \ libwebp6 \ xmlsec1 \ libjemalloc2 \ - && rm -rf /var/lib/apt/lists/* + libssl-dev \ + openssl \ + && rm -rf /var/lib/apt/lists/* COPY --from=builder /install /usr/local COPY ./docker/start.py /start.py @@ -83,4 +89,4 @@ EXPOSE 8008/tcp 8009/tcp 8448/tcp ENTRYPOINT ["/start.py"] HEALTHCHECK --interval=1m --timeout=5s \ - CMD curl -fSs http://localhost:8008/health || exit 1 + CMD curl -fSs http://localhost:8008/health || exit 1 diff --git a/docker/build_debian.sh b/docker/build_debian.sh
index f312f0715f..f426d2b77b 100644 --- a/docker/build_debian.sh +++ b/docker/build_debian.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # The script to build the Debian package, as ran inside the Docker image. diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index 0dea62a87d..a792899540 100644 --- a/docker/conf/homeserver.yaml +++ b/docker/conf/homeserver.yaml
@@ -173,18 +173,10 @@ report_stats: False ## API Configuration ## -room_invite_state_types: - - "m.room.join_rules" - - "m.room.canonical_alias" - - "m.room.avatar" - - "m.room.name" - {% if SYNAPSE_APPSERVICES %} app_service_config_files: {% for appservice in SYNAPSE_APPSERVICES %} - "{{ appservice }}" {% endfor %} -{% else %} -app_service_config_files: [] {% endif %} macaroon_secret_key: "{{ SYNAPSE_MACAROON_SECRET_KEY }}" diff --git a/docker/run_pg_tests.sh b/docker/run_pg_tests.sh
index d18d1e4c8e..1fd08cb62b 100755 --- a/docker/run_pg_tests.sh +++ b/docker/run_pg_tests.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # This script runs the PostgreSQL tests inside a Docker container. It expects # the relevant source files to be mounted into /src (done automatically by the diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 8d4ec5a6f9..a8a5a2628c 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst
@@ -111,35 +111,16 @@ List Accounts ============= This API returns all local user accounts. +By default, the response is ordered by ascending user ID. -The api is:: +The API is:: GET /_synapse/admin/v2/users?from=0&limit=10&guests=false To use it, you will need to authenticate by providing an ``access_token`` for a server admin: see `README.rst <README.rst>`_. -The parameter ``from`` is optional but used for pagination, denoting the -offset in the returned results. This should be treated as an opaque value and -not explicitly set to anything other than the return value of ``next_token`` -from a previous call. - -The parameter ``limit`` is optional but is used for pagination, denoting the -maximum number of items to return in this call. Defaults to ``100``. - -The parameter ``user_id`` is optional and filters to only return users with user IDs -that contain this value. This parameter is ignored when using the ``name`` parameter. - -The parameter ``name`` is optional and filters to only return users with user ID localparts -**or** displaynames that contain this value. - -The parameter ``guests`` is optional and if ``false`` will **exclude** guest users. -Defaults to ``true`` to include guest users. - -The parameter ``deactivated`` is optional and if ``true`` will **include** deactivated users. -Defaults to ``false`` to exclude deactivated users. - -A JSON body is returned with the following shape: +A response body like the following is returned: .. code:: json @@ -175,6 +156,66 @@ with ``from`` set to the value of ``next_token``. This will return a new page. If the endpoint does not return a ``next_token`` then there are no more users to paginate through. +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - Is optional and filters to only return users with user IDs + that contain this value. This parameter is ignored when using the ``name`` parameter. +- ``name`` - Is optional and filters to only return users with user ID localparts + **or** displaynames that contain this value. +- ``guests`` - string representing a bool - Is optional and if ``false`` will **exclude** guest users. + Defaults to ``true`` to include guest users. +- ``deactivated`` - string representing a bool - Is optional and if ``true`` will **include** deactivated users. + Defaults to ``false`` to exclude deactivated users. +- ``limit`` - string representing a positive integer - Is optional but is used for pagination, + denoting the maximum number of items to return in this call. Defaults to ``100``. +- ``from`` - string representing a positive integer - Is optional but used for pagination, + denoting the offset in the returned results. This should be treated as an opaque value and + not explicitly set to anything other than the return value of ``next_token`` from a previous call. + Defaults to ``0``. +- ``order_by`` - The method by which to sort the returned list of users. + If the ordered field has duplicates, the second order is always by ascending ``name``, + which guarantees a stable ordering. Valid values are: + + - ``name`` - Users are ordered alphabetically by ``name``. This is the default. + - ``is_guest`` - Users are ordered by ``is_guest`` status. + - ``admin`` - Users are ordered by ``admin`` status. + - ``user_type`` - Users are ordered alphabetically by ``user_type``. + - ``deactivated`` - Users are ordered by ``deactivated`` status. + - ``shadow_banned`` - Users are ordered by ``shadow_banned`` status. + - ``displayname`` - Users are ordered alphabetically by ``displayname``. + - ``avatar_url`` - Users are ordered alphabetically by avatar URL. + +- ``dir`` - Direction of media order. Either ``f`` for forwards or ``b`` for backwards. + Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``. + +Caution. The database only has indexes on the columns ``name`` and ``created_ts``. +This means that if a different sort order is used (``is_guest``, ``admin``, +``user_type``, ``deactivated``, ``shadow_banned``, ``avatar_url`` or ``displayname``), +this can cause a large load on the database, especially for large environments. + +**Response** + +The following fields are returned in the JSON response body: + +- ``users`` - An array of objects, each containing information about an user. + User objects contain the following fields: + + - ``name`` - string - Fully-qualified user ID (ex. `@user:server.com`). + - ``is_guest`` - bool - Status if that user is a guest account. + - ``admin`` - bool - Status if that user is a server administrator. + - ``user_type`` - string - Type of the user. Normal users are type ``None``. + This allows user type specific behaviour. There are also types ``support`` and ``bot``. + - ``deactivated`` - bool - Status if that user has been marked as deactivated. + - ``shadow_banned`` - bool - Status if that user has been marked as shadow banned. + - ``displayname`` - string - The user's display name if they have set one. + - ``avatar_url`` - string - The user's avatar URL if they have set one. + +- ``next_token``: string representing a positive integer - Indication for pagination. See above. +- ``total`` - integer - Total number of media. + + Query current sessions for a user ================================= diff --git a/docs/code_style.md b/docs/code_style.md
index 190f8ab2de..28fb7277c4 100644 --- a/docs/code_style.md +++ b/docs/code_style.md
@@ -128,6 +128,9 @@ Some guidelines follow: will be if no sub-options are enabled). - Lines should be wrapped at 80 characters. - Use two-space indents. +- `true` and `false` are spelt thus (as opposed to `True`, etc.) +- Use single quotes (`'`) rather than double-quotes (`"`) or backticks + (`` ` ``) to refer to configuration options. Example: diff --git a/docs/deprecation_policy.md b/docs/deprecation_policy.md new file mode 100644
index 0000000000..06ea340559 --- /dev/null +++ b/docs/deprecation_policy.md
@@ -0,0 +1,33 @@ +Deprecation Policy for Platform Dependencies +============================================ + +Synapse has a number of platform dependencies, including Python and PostgreSQL. +This document outlines the policy towards which versions we support, and when we +drop support for versions in the future. + + +Policy +------ + +Synapse follows the upstream support life cycles for Python and PostgreSQL, +i.e. when a version reaches End of Life Synapse will withdraw support for that +version in future releases. + +Details on the upstream support life cycles for Python and PostgreSQL are +documented at https://endoflife.date/python and +https://endoflife.date/postgresql. + + +Context +------- + +It is important for system admins to have a clear understanding of the platform +requirements of Synapse and its deprecation policies so that they can +effectively plan upgrading their infrastructure ahead of time. This is +especially important in contexts where upgrading the infrastructure requires +auditing and approval from a security team, or where otherwise upgrading is a +long process. + +By following the upstream support life cycles Synapse can ensure that its +dependencies continue to get security patches, while not requiring system admins +to constantly update their platform dependencies to the latest versions. diff --git a/docs/presence_router_module.md b/docs/presence_router_module.md new file mode 100644
index 0000000000..d6566d978d --- /dev/null +++ b/docs/presence_router_module.md
@@ -0,0 +1,235 @@ +# Presence Router Module + +Synapse supports configuring a module that can specify additional users +(local or remote) to should receive certain presence updates from local +users. + +Note that routing presence via Application Service transactions is not +currently supported. + +The presence routing module is implemented as a Python class, which will +be imported by the running Synapse. + +## Python Presence Router Class + +The Python class is instantiated with two objects: + +* A configuration object of some type (see below). +* An instance of `synapse.module_api.ModuleApi`. + +It then implements methods related to presence routing. + +Note that one method of `ModuleApi` that may be useful is: + +```python +async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None +``` + +which can be given a list of local or remote MXIDs to broadcast known, online user +presence to (for those users that the receiving user is considered interested in). +It does not include state for users who are currently offline, and it can only be +called on workers that support sending federation. + +### Module structure + +Below is a list of possible methods that can be implemented, and whether they are +required. + +#### `parse_config` + +```python +def parse_config(config_dict: dict) -> Any +``` + +**Required.** A static method that is passed a dictionary of config options, and + should return a validated config object. This method is described further in + [Configuration](#configuration). + +#### `get_users_for_states` + +```python +async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], +) -> Dict[str, Set[UserPresenceState]]: +``` + +**Required.** An asynchronous method that is passed an iterable of user presence +state. This method can determine whether a given presence update should be sent to certain +users. It does this by returning a dictionary with keys representing local or remote +Matrix User IDs, and values being a python set +of `synapse.handlers.presence.UserPresenceState` instances. + +Synapse will then attempt to send the specified presence updates to each user when +possible. + +#### `get_interested_users` + +```python +async def get_interested_users(self, user_id: str) -> Union[Set[str], str] +``` + +**Required.** An asynchronous method that is passed a single Matrix User ID. This +method is expected to return the users that the passed in user may be interested in the +presence of. Returned users may be local or remote. The presence routed as a result of +what this method returns is sent in addition to the updates already sent between users +that share a room together. Presence updates are deduplicated. + +This method should return a python set of Matrix User IDs, or the object +`synapse.events.presence_router.PresenceRouter.ALL_USERS` to indicate that the passed +user should receive presence information for *all* known users. + +For clarity, if the user `@alice:example.org` is passed to this method, and the Set +`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice +should receive presence updates sent by Bob and Charlie, regardless of whether these +users share a room. + +### Example + +Below is an example implementation of a presence router class. + +```python +from typing import Dict, Iterable, Set, Union +from synapse.events.presence_router import PresenceRouter +from synapse.handlers.presence import UserPresenceState +from synapse.module_api import ModuleApi + +class PresenceRouterConfig: + def __init__(self): + # Config options with their defaults + # A list of users to always send all user presence updates to + self.always_send_to_users = [] # type: List[str] + + # A list of users to ignore presence updates for. Does not affect + # shared-room presence relationships + self.blacklisted_users = [] # type: List[str] + +class ExamplePresenceRouter: + """An example implementation of synapse.presence_router.PresenceRouter. + Supports routing all presence to a configured set of users, or a subset + of presence from certain users to members of certain rooms. + + Args: + config: A configuration object. + module_api: An instance of Synapse's ModuleApi. + """ + def __init__(self, config: PresenceRouterConfig, module_api: ModuleApi): + self._config = config + self._module_api = module_api + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterConfig() + always_send_to_users = config_dict.get("always_send_to_users") + blacklisted_users = config_dict.get("blacklisted_users") + + # Do some validation of config options... otherwise raise a + # synapse.config.ConfigError. + config.always_send_to_users = always_send_to_users + config.blacklisted_users = blacklisted_users + + return config + + async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], + ) -> Dict[str, Set[UserPresenceState]]: + """Given an iterable of user presence updates, determine where each one + needs to go. Returned results will not affect presence updates that are + sent between users who share a room. + + Args: + state_updates: An iterable of user presence state updates. + + Returns: + A dictionary of user_id -> set of UserPresenceState that the user should + receive. + """ + destination_users = {} # type: Dict[str, Set[UserPresenceState] + + # Ignore any updates for blacklisted users + desired_updates = set() + for update in state_updates: + if update.state_key not in self._config.blacklisted_users: + desired_updates.add(update) + + # Send all presence updates to specific users + for user_id in self._config.always_send_to_users: + destination_users[user_id] = desired_updates + + return destination_users + + async def get_interested_users( + self, + user_id: str, + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + """ + Retrieve a list of users that `user_id` is interested in receiving the + presence of. This will be in addition to those they share a room with. + Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate + that this user should receive all incoming local and remote presence updates. + + Note that this method will only be called for local users. + + Args: + user_id: A user requesting presence updates. + + Returns: + A set of user IDs to return additional presence updates for, or + PresenceRouter.ALL_USERS to return presence updates for all other users. + """ + if user_id in self._config.always_send_to_users: + return PresenceRouter.ALL_USERS + + return set() +``` + +#### A note on `get_users_for_states` and `get_interested_users` + +Both of these methods are effectively two different sides of the same coin. The logic +regarding which users should receive updates for other users should be the same +between them. + +`get_users_for_states` is called when presence updates come in from either federation +or local users, and is used to either direct local presence to remote users, or to +wake up the sync streams of local users to collect remote presence. + +In contrast, `get_interested_users` is used to determine the users that presence should +be fetched for when a local user is syncing. This presence is then retrieved, before +being fed through `get_users_for_states` once again, with only the syncing user's +routing information pulled from the resulting dictionary. + +Their routing logic should thus line up, else you may run into unintended behaviour. + +## Configuration + +Once you've crafted your module and installed it into the same Python environment as +Synapse, amend your homeserver config file with the following. + +```yaml +presence: + routing_module: + module: my_module.ExamplePresenceRouter + config: + # Any configuration options for your module. The below is an example. + # of setting options for ExamplePresenceRouter. + always_send_to_users: ["@presence_gobbler:example.org"] + blacklisted_users: + - "@alice:example.com" + - "@bob:example.com" + ... +``` + +The contents of `config` will be passed as a Python dictionary to the static +`parse_config` method of your class. The object returned by this method will +then be passed to the `__init__` method of your module as `config`. diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 860afd5a04..cf1b835b9d 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md
@@ -104,10 +104,11 @@ example.com:8448 { ``` <VirtualHost *:443> SSLEngine on - ServerName matrix.example.com; + ServerName matrix.example.com RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME} AllowEncodedSlashes NoDecode + ProxyPreserveHost on ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix ProxyPass /_synapse/client http://127.0.0.1:8008/_synapse/client nocanon @@ -116,7 +117,7 @@ example.com:8448 { <VirtualHost *:8448> SSLEngine on - ServerName example.com; + ServerName example.com RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME} AllowEncodedSlashes NoDecode @@ -135,6 +136,8 @@ example.com:8448 { </IfModule> ``` +**NOTE 3**: Missing `ProxyPreserveHost on` can lead to a redirect loop. + ### HAProxy ``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 7de000f4a4..9182dcd987 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml
@@ -82,9 +82,28 @@ pid_file: DATADIR/homeserver.pid # #soft_file_limit: 0 -# Set to false to disable presence tracking on this homeserver. +# Presence tracking allows users to see the state (e.g online/offline) +# of other local and remote users. # -#use_presence: false +presence: + # Uncomment to disable presence tracking on this homeserver. This option + # replaces the previous top-level 'use_presence' option. + # + #enabled: false + + # Presence routers are third-party modules that can specify additional logic + # to where presence updates from users are routed. + # + presence_router: + # The custom module's class. Uncomment to use a custom presence router module. + # + #module: "my_custom_router.PresenceRouter" + + # Configuration options of the custom module. Refer to your module's + # documentation for available options. + # + #config: + # example_option: 'something' # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to @@ -869,10 +888,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" #rc_joins: # local: # per_second: 0.1 -# burst_count: 3 +# burst_count: 10 # remote: # per_second: 0.01 -# burst_count: 3 +# burst_count: 10 # #rc_3pid_validation: # per_second: 0.003 @@ -1246,9 +1265,9 @@ account_validity: # #allowed_local_3pids: # - medium: email -# pattern: '.*@matrix\.org' +# pattern: '^[^@]+@matrix\.org$' # - medium: email -# pattern: '.*@vector\.im' +# pattern: '^[^@]+@vector\.im$' # - medium: msisdn # pattern: '\+44' @@ -1451,14 +1470,31 @@ metrics_flags: ## API Configuration ## -# A list of event types that will be included in the room_invite_state +# Controls for the state that is shared with users who receive an invite +# to a room # -#room_invite_state_types: -# - "m.room.join_rules" -# - "m.room.canonical_alias" -# - "m.room.avatar" -# - "m.room.encryption" -# - "m.room.name" +room_prejoin_state: + # By default, the following state event types are shared with users who + # receive invites to the room: + # + # - m.room.join_rules + # - m.room.canonical_alias + # - m.room.avatar + # - m.room.encryption + # - m.room.name + # + # Uncomment the following to disable these defaults (so that only the event + # types listed in 'additional_event_types' are shared). Defaults to 'false'. + # + #disable_default_event_types: true + + # Additional state event types to share with users when they are invited + # to a room. + # + # By default, this list is empty (so only the default event types are shared). + # + #additional_event_types: + # - org.example.custom.event.type # A list of application service config files to use @@ -1758,6 +1794,9 @@ saml2_config: # Note that, if this is changed, users authenticating via that provider # will no longer be recognised as the same user! # +# (Use "oidc" here if you are migrating from an old "oidc_config" +# configuration.) +# # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # @@ -1873,6 +1912,24 @@ saml2_config: # which is set to the claims returned by the UserInfo Endpoint and/or # in the ID Token. # +# It is possible to configure Synapse to only allow logins if certain attributes +# match particular values in the OIDC userinfo. The requirements can be listed under +# `attribute_requirements` as shown below. All of the listed attributes must +# match for the login to be permitted. Additional attributes can be added to +# userinfo by expanding the `scopes` section of the OIDC config to retrieve +# additional information from the OIDC provider. +# +# If the OIDC claim is a list, then the attribute must match any value in the list. +# Otherwise, it must exactly match the value of the claim. Using the example +# below, the `family_name` claim MUST be "Stephensson", but the `groups` +# claim MUST contain "admin". +# +# attribute_requirements: +# - attribute: family_name +# value: "Stephensson" +# - attribute: groups +# value: "admin" +# # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md # for information on how to configure these options. # @@ -1905,34 +1962,9 @@ oidc_providers: # localpart_template: "{{ user.login }}" # display_name_template: "{{ user.name }}" # email_template: "{{ user.email }}" - - # For use with Keycloak - # - #- idp_id: keycloak - # idp_name: Keycloak - # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" - # client_id: "synapse" - # client_secret: "copy secret generated in Keycloak UI" - # scopes: ["openid", "profile"] - - # For use with Github - # - #- idp_id: github - # idp_name: Github - # idp_brand: github - # discover: false - # issuer: "https://github.com/" - # client_id: "your-client-id" # TO BE FILLED - # client_secret: "your-client-secret" # TO BE FILLED - # authorization_endpoint: "https://github.com/login/oauth/authorize" - # token_endpoint: "https://github.com/login/oauth/access_token" - # userinfo_endpoint: "https://api.github.com/user" - # scopes: ["read:user"] - # user_mapping_provider: - # config: - # subject_claim: "id" - # localpart_template: "{{ user.login }}" - # display_name_template: "{{ user.name }}" + # attribute_requirements: + # - attribute: userGroup + # value: "synapseUsers" # Enable Central Authentication Service (CAS) for registration and login. diff --git a/docs/workers.md b/docs/workers.md
index e7bf9b8ce4..c6282165b0 100644 --- a/docs/workers.md +++ b/docs/workers.md
@@ -232,7 +232,6 @@ expressions: # Registration/login requests ^/_matrix/client/(api/v1|r0|unstable)/login$ ^/_matrix/client/(r0|unstable)/register$ - ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$ # Event sending requests ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact @@ -276,7 +275,7 @@ using): Ensure that all SSO logins go to a single process. For multiple workers not handling the SSO endpoints properly, see -[#7530](https://github.com/matrix-org/synapse/issues/7530) and +[#7530](https://github.com/matrix-org/synapse/issues/7530) and [#9427](https://github.com/matrix-org/synapse/issues/9427). Note that a HTTP listener with `client` and `federation` resources must be diff --git a/mypy.ini b/mypy.ini
index e0685e097c..32e6197409 100644 --- a/mypy.ini +++ b/mypy.ini
@@ -1,12 +1,14 @@ [mypy] namespace_packages = True plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py -follow_imports = silent +follow_imports = normal check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs warn_unreachable = True +local_partial_types = True +no_implicit_optional = True # To find all folders that pass mypy you run: # @@ -20,8 +22,9 @@ files = synapse/crypto, synapse/event_auth.py, synapse/events/builder.py, - synapse/events/validator.py, synapse/events/spamcheck.py, + synapse/events/third_party_rules.py, + synapse/events/validator.py, synapse/federation, synapse/groups, synapse/handlers, @@ -38,6 +41,7 @@ files = synapse/push, synapse/replication, synapse/rest, + synapse/secrets.py, synapse/server.py, synapse/server_notices, synapse/spam_checker_api, @@ -71,6 +75,7 @@ files = synapse/util/metrics.py, synapse/util/macaroons.py, synapse/util/stringutils.py, + synapse/visibility.py, tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index 448cadb829..af6d32e332 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # A script which checks that an appropriate news file has been added on this # branch. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 3cde53f5c0..31cc20a826 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh
@@ -1,22 +1,49 @@ -#! /bin/bash -eu +#!/usr/bin/env bash # This script is designed for developers who want to test their code # against Complement. # # It makes a Synapse image which represents the current checkout, -# then downloads Complement and runs it with that image. +# builds a synapse-complement image on top, then runs tests with it. +# +# By default the script will fetch the latest Complement master branch and +# run tests with that. This can be overridden to use a custom Complement +# checkout by setting the COMPLEMENT_DIR environment variable to the +# filepath of a local Complement checkout. +# +# A regular expression of test method names can be supplied as the first +# argument to the script. Complement will then only run those tests. If +# no regex is supplied, all tests are run. For example; +# +# ./complement.sh "TestOutboundFederation(Profile|Send)" +# + +# Exit if a line returns a non-zero exit code +set -e +# Change to the repository root cd "$(dirname $0)/.." +# Check for a user-specified Complement checkout +if [[ -z "$COMPLEMENT_DIR" ]]; then + echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..." + wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz + tar -xzf master.tar.gz + COMPLEMENT_DIR=complement-master + echo "Checkout available at 'complement-master'" +fi + # Build the base Synapse image from the local checkout -docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile . +docker build -t matrixdotorg/synapse -f docker/Dockerfile . +# Build the Synapse monolith image from Complement, based on the above image we just built +docker build -t complement-synapse -f "$COMPLEMENT_DIR/dockerfiles/Synapse.Dockerfile" "$COMPLEMENT_DIR/dockerfiles" -# Download Complement -wget -N https://github.com/matrix-org/complement/archive/master.tar.gz -tar -xzf master.tar.gz -cd complement-master +cd "$COMPLEMENT_DIR" -# Build the Synapse image from Complement, based on the above image we just built -docker build -t complement-synapse -f dockerfiles/Synapse.Dockerfile ./dockerfiles +EXTRA_COMPLEMENT_ARGS="" +if [[ -n "$1" ]]; then + # A test name regex has been set, supply it to Complement + EXTRA_COMPLEMENT_ARGS+="-run $1 " +fi -# Run the tests on the resulting image! -COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -count=1 ./tests +# Run the tests! +COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist -count=1 $EXTRA_COMPLEMENT_ARGS ./tests diff --git a/scripts-dev/config-lint.sh b/scripts-dev/config-lint.sh
index 9132160463..8c6323e59a 100755 --- a/scripts-dev/config-lint.sh +++ b/scripts-dev/config-lint.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # Find linting errors in Synapse's default config file. # Exits with 0 if there are no problems, or another code otherwise. diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index abcec48c4f..6f76c08fcf 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py
@@ -22,8 +22,8 @@ import sys from typing import Any, Optional from urllib import parse as urlparse -import nacl.signing import requests +import signedjson.key import signedjson.types import srvlookup import yaml @@ -44,18 +44,6 @@ def encode_base64(input_bytes): return output_string -def decode_base64(input_string): - """Decode a base64 string to bytes inferring padding from the length of the - string.""" - - input_bytes = input_string.encode("ascii") - input_len = len(input_bytes) - padding = b"=" * (3 - ((input_len + 3) % 4)) - output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2 - output_bytes = base64.b64decode(input_bytes + padding) - return output_bytes[:output_len] - - def encode_canonical_json(value): return json.dumps( value, @@ -88,42 +76,6 @@ def sign_json( return json_object -NACL_ED25519 = "ed25519" - - -def decode_signing_key_base64(algorithm, version, key_base64): - """Decode a base64 encoded signing key - Args: - algorithm (str): The algorithm the key is for (currently "ed25519"). - version (str): Identifies this key out of the keys for this entity. - key_base64 (str): Base64 encoded bytes of the key. - Returns: - A SigningKey object. - """ - if algorithm == NACL_ED25519: - key_bytes = decode_base64(key_base64) - key = nacl.signing.SigningKey(key_bytes) - key.version = version - key.alg = NACL_ED25519 - return key - else: - raise ValueError("Unsupported algorithm %s" % (algorithm,)) - - -def read_signing_keys(stream): - """Reads a list of keys from a stream - Args: - stream : A stream to iterate for keys. - Returns: - list of SigningKey objects. - """ - keys = [] - for line in stream: - algorithm, version, key_base64 = line.split() - keys.append(decode_signing_key_base64(algorithm, version, key_base64)) - return keys - - def request( method: Optional[str], origin_name: str, @@ -223,23 +175,28 @@ def main(): parser.add_argument("--body", help="Data to send as the body of the HTTP request") parser.add_argument( - "path", help="request path. We will add '/_matrix/federation/v1/' to this." + "path", help="request path, including the '/_matrix/federation/...' prefix." ) args = parser.parse_args() - if not args.server_name or not args.signing_key_path: + args.signing_key = None + if args.signing_key_path: + with open(args.signing_key_path) as f: + args.signing_key = f.readline() + + if not args.server_name or not args.signing_key: read_args_from_config(args) - with open(args.signing_key_path) as f: - key = read_signing_keys(f)[0] + algorithm, version, key_base64 = args.signing_key.split() + key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64) result = request( args.method, args.server_name, key, args.destination, - "/_matrix/federation/v1/" + args.path, + args.path, content=args.body, ) @@ -255,10 +212,16 @@ def main(): def read_args_from_config(args): with open(args.config, "r") as fh: config = yaml.safe_load(fh) + if not args.server_name: args.server_name = config["server_name"] - if not args.signing_key_path: - args.signing_key_path = config["signing_key_path"] + + if not args.signing_key: + if "signing_key" in config: + args.signing_key = config["signing_key"] + else: + with open(config["signing_key_path"]) as f: + args.signing_key = f.readline() class MatrixConnectionAdapter(HTTPAdapter): diff --git a/scripts-dev/generate_sample_config b/scripts-dev/generate_sample_config
index 9cb4630a5c..02739894b5 100755 --- a/scripts-dev/generate_sample_config +++ b/scripts-dev/generate_sample_config
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Update/check the docs/sample_config.yaml diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index fe2965cd36..9761e97594 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Runs linting scripts over the local Synapse checkout # isort - sorts import statements diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh
index b8d1e636f1..bc8f978660 100755 --- a/scripts-dev/make_full_schema.sh +++ b/scripts-dev/make_full_schema.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # This script generates SQL files for creating a brand new Synapse DB with the latest # schema, on both SQLite3 and Postgres. diff --git a/scripts-dev/next_github_number.sh b/scripts-dev/next_github_number.sh
index 376280025a..00e9b14569 100755 --- a/scripts-dev/next_github_number.sh +++ b/scripts-dev/next_github_number.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e @@ -6,4 +6,4 @@ set -e # next PR number. CURRENT_NUMBER=`curl -s "https://api.github.com/repos/matrix-org/synapse/issues?state=all&per_page=1" | jq -r ".[0].number"` CURRENT_NUMBER=$((CURRENT_NUMBER+1)) -echo $CURRENT_NUMBER \ No newline at end of file +echo $CURRENT_NUMBER diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py
index ab2e763386..8477955a90 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py
@@ -51,7 +51,7 @@ def main(src_repo, dest_repo): parts = line.split("|") if len(parts) != 2: print("Unable to parse input line %s" % line, file=sys.stderr) - exit(1) + sys.exit(1) move_media(parts[0], parts[1], src_paths, dest_paths) diff --git a/setup.cfg b/setup.cfg
index 5e301c2cd7..7329eed213 100644 --- a/setup.cfg +++ b/setup.cfg
@@ -18,7 +18,8 @@ ignore = # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 +# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes) +ignore=W503,W504,E203,E731,E501,B006,B007,B008 [isort] line_length = 88 diff --git a/setup.py b/setup.py
index bbd9e7862a..29e9971dc1 100755 --- a/setup.py +++ b/setup.py
@@ -99,10 +99,11 @@ CONDITIONAL_REQUIREMENTS["lint"] = [ "isort==5.7.0", "black==20.8b1", "flake8-comprehensions", + "flake8-bugbear==21.3.2", "flake8", ] -CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.11"] +CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"] # Dependencies which are exclusively required by unit test code. This is # NOT a list of all modules that are necessary to run the unit tests. diff --git a/synapse/__init__.py b/synapse/__init__.py
index 56ca888862..1d2883acf6 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.29.0" +__version__ = "1.31.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e10e33fd23..7d9930ae7b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py
@@ -558,6 +558,9 @@ class Auth: Returns: bool: False if no access_token was given, True otherwise. """ + # This will always be set by the time Twisted calls us. + assert request.args is not None + query_params = request.args.get(b"access_token") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") return bool(query_params) or bool(auth_headers) @@ -574,6 +577,8 @@ class Auth: MissingClientTokenError: If there isn't a single access_token in the request """ + # This will always be set by the time Twisted calls us. + assert request.args is not None auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") query_params = request.args.get(b"access_token") diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 691f8f9adf..6856dab06c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py
@@ -51,6 +51,7 @@ class PresenceState: OFFLINE = "offline" UNAVAILABLE = "unavailable" ONLINE = "online" + BUSY = "org.matrix.msc3026.busy" class JoinRules: @@ -58,6 +59,8 @@ class JoinRules: KNOCK = "knock" INVITE = "invite" PRIVATE = "private" + # As defined for MSC3083. + MSC3083_RESTRICTED = "restricted" class LoginType: @@ -100,6 +103,9 @@ class EventTypes: Dummy = "org.matrix.dummy_event" + MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child" + MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent" + class EduTypes: Presence = "m.presence" @@ -160,6 +166,9 @@ class EventContentFields: # cf https://github.com/matrix-org/matrix-doc/pull/2228 SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" + # cf https://github.com/matrix-org/matrix-doc/pull/1772 + MSC1772_ROOM_TYPE = "org.matrix.msc1772.type" + class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index c3f07bc1a3..2244b8a340 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py
@@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -31,10 +32,13 @@ class Ratelimiter: burst_count: How many actions that can be performed before being limited. """ - def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + def __init__( + self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + ): self.clock = clock self.rate_hz = rate_hz self.burst_count = burst_count + self.store = store # A ordered dictionary keeping track of actions, when they were last # performed and how often. Each entry is a mapping from a key of arbitrary type @@ -46,45 +50,10 @@ class Ratelimiter: OrderedDict() ) # type: OrderedDict[Hashable, Tuple[float, int, float]] - def can_requester_do_action( - self, - requester: Requester, - rate_hz: Optional[float] = None, - burst_count: Optional[int] = None, - update: bool = True, - _time_now_s: Optional[int] = None, - ) -> Tuple[bool, float]: - """Can the requester perform the action? - - Args: - requester: The requester to key off when rate limiting. The user property - will be used. - rate_hz: The long term number of actions that can be performed in a second. - Overrides the value set during instantiation if set. - burst_count: How many actions that can be performed before being limited. - Overrides the value set during instantiation if set. - update: Whether to count this check as performing the action - _time_now_s: The current time. Optional, defaults to the current time according - to self.clock. Only used by tests. - - Returns: - A tuple containing: - * A bool indicating if they can perform the action now - * The reactor timestamp for when the action can be performed next. - -1 if rate_hz is less than or equal to zero - """ - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return True, -1.0 - - return self.can_do_action( - requester.user.to_string(), rate_hz, burst_count, update, _time_now_s - ) - - def can_do_action( + async def can_do_action( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -92,9 +61,16 @@ class Ratelimiter: ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: The key we should use when rate limiting. Can be a user ID - (when sending events), an IP address, etc. + requester: The requester that is doing the action, if any. Used to check + if the user has ratelimits disabled in the database. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -109,6 +85,30 @@ class Ratelimiter: * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + + if requester: + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + # Check if ratelimiting has been disabled for the user. + # + # Note that we don't use the returned rate/burst count, as the table + # is specifically for the event sending ratelimiter. Instead, we + # only use it to (somewhat cheekily) infer whether the user should + # be subject to any rate limiting or not. + override = await self.store.get_ratelimit_for_user( + requester.authenticated_entity + ) + if override and not override.messages_per_second: + return True, -1.0 + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz @@ -175,9 +175,10 @@ class Ratelimiter: else: del self.actions[key] - def ratelimit( + async def ratelimit( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -185,8 +186,16 @@ class Ratelimiter: ): """Checks if an action can be performed. If not, raises a LimitExceededError + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: An arbitrary key used to classify an action + requester: The requester that is doing the action, if any. Used to check for + if the user has ratelimits disabled. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -201,7 +210,8 @@ class Ratelimiter: """ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - allowed, time_allowed = self.can_do_action( + allowed, time_allowed = await self.can_do_action( + requester, key, rate_hz=rate_hz, burst_count=burst_count, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index de2cc15d33..87038d436d 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py
@@ -57,7 +57,7 @@ class RoomVersion: state_res = attr.ib(type=int) # one of the StateResolutionVersions enforce_key_validity = attr.ib(type=bool) - # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules + # Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules special_case_aliases_auth = attr.ib(type=bool) # Strictly enforce canonicaljson, do not allow: # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] @@ -69,6 +69,8 @@ class RoomVersion: limit_notifications_power_levels = attr.ib(type=bool) # MSC2174/MSC2176: Apply updated redaction rules algorithm. msc2176_redaction_rules = attr.ib(type=bool) + # MSC3083: Support the 'restricted' join_rule. + msc3083_join_rules = attr.ib(type=bool) class RoomVersions: @@ -82,6 +84,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V2 = RoomVersion( "2", @@ -93,6 +96,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V3 = RoomVersion( "3", @@ -104,6 +108,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V4 = RoomVersion( "4", @@ -115,6 +120,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V5 = RoomVersion( "5", @@ -126,6 +132,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V6 = RoomVersion( "6", @@ -137,6 +144,7 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -148,6 +156,19 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=True, + msc3083_join_rules=False, + ) + MSC3083 = RoomVersion( + "org.matrix.msc3083", + RoomDisposition.UNSTABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, ) @@ -162,4 +183,5 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V6, RoomVersions.MSC2176, ) + # Note that we do not include MSC3083 here unless it is enabled in the config. } # type: Dict[str, RoomVersion] diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index 4a9b0129c3..d1a2cd5e19 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py
@@ -22,7 +22,9 @@ logger = logging.getLogger(__name__) try: python_dependencies.check_requirements() except python_dependencies.DependencyException as e: - sys.stderr.writelines(e.message) + sys.stderr.writelines( + e.message # noqa: B306, DependencyException.message is a property + ) sys.exit(1) diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 43b1f1e94b..3912c8994c 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py
@@ -21,8 +21,10 @@ import signal import socket import sys import traceback +import warnings from typing import Awaitable, Callable, Iterable +from cryptography.utils import CryptographyDeprecationWarning from typing_extensions import NoReturn from twisted.internet import defer, error, reactor @@ -195,6 +197,25 @@ def listen_metrics(bind_addresses, port): start_http_server(port, addr=host, registry=RegistryProxy) +def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: dict): + # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing + # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so + # suppress the warning for now. + warnings.filterwarnings( + action="ignore", + category=CryptographyDeprecationWarning, + message="int_from_bytes is deprecated", + ) + + from synapse.util.manhole import manhole + + listen_tcp( + bind_addresses, + port, + manhole(username="matrix", password="rabbithole", globals=manhole_globals), + ) + + def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): """ Create a TCP socket for a port and several addresses diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 274d582d07..d1c2079233 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -147,7 +147,6 @@ from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") @@ -282,6 +281,7 @@ class GenericWorkerPresence(BasePresenceHandler): self.hs = hs self.is_mine_id = hs.is_mine_id + self.presence_router = hs.get_presence_router() self._presence_enabled = hs.config.use_presence # The number of ongoing syncs on this process, by user id. @@ -302,6 +302,8 @@ class GenericWorkerPresence(BasePresenceHandler): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) + self._busy_presence_enabled = hs.config.experimental.msc3026_enabled + hs.get_reactor().addSystemEventTrigger( "before", "shutdown", @@ -394,7 +396,7 @@ class GenericWorkerPresence(BasePresenceHandler): return _user_syncing() async def notify_from_replication(self, states, stream_id): - parties = await get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -439,8 +441,12 @@ class GenericWorkerPresence(BasePresenceHandler): PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, + PresenceState.BUSY, ) - if presence not in valid_presence: + + if presence not in valid_presence or ( + presence == PresenceState.BUSY and not self._busy_presence_enabled + ): raise SynapseError(400, "Invalid presence state") user_id = target_user.to_string() @@ -634,12 +640,8 @@ class GenericWorkerServer(HomeServer): if listener.type == "http": self._listen_http(listener) elif listener.type == "manhole": - _base.listen_tcp( - listener.bind_addresses, - listener.port, - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), + _base.listen_manhole( + listener.bind_addresses, listener.port, manhole_globals={"hs": self} ) elif listener.type == "metrics": if not self.get_config().enable_metrics: @@ -786,13 +788,6 @@ class FederationSenderHandler: self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - def on_start(self): - # There may be some events that are persisted but haven't been sent, - # so send them now. - self.federation_sender.notify_new_events( - self.store.get_room_max_stream_ordering() - ) - def wake_destination(self, server: str): self.federation_sender.wake_destination(server) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 244657cb88..3bfe9d507f 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -67,7 +67,6 @@ from synapse.storage import DataStore from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole from synapse.util.module_loader import load_module from synapse.util.versionstring import get_version_string @@ -288,12 +287,8 @@ class SynapseHomeServer(HomeServer): if listener.type == "http": self._listening_services.extend(self._listener_http(config, listener)) elif listener.type == "manhole": - listen_tcp( - listener.bind_addresses, - listener.port, - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), + _base.listen_manhole( + listener.bind_addresses, listener.port, manhole_globals={"hs": self} ) elif listener.type == "replication": services = listen_tcp( diff --git a/synapse/config/api.py b/synapse/config/api.py
index 74cd53a8ed..55c038c0c4 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py
@@ -1,4 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,38 +12,131 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typing import Iterable + from synapse.api.constants import EventTypes +from synapse.config._base import Config, ConfigError +from synapse.config._util import validate_config +from synapse.types import JsonDict -from ._base import Config +logger = logging.getLogger(__name__) class ApiConfig(Config): section = "api" - def read_config(self, config, **kwargs): - self.room_invite_state_types = config.get( - "room_invite_state_types", - [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, - ], + def read_config(self, config: JsonDict, **kwargs): + validate_config(_MAIN_SCHEMA, config, ()) + self.room_prejoin_state = list(self._get_prejoin_state_types(config)) + + def generate_config_section(cls, **kwargs) -> str: + formatted_default_state_types = "\n".join( + " # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES ) - def generate_config_section(cls, **kwargs): return """\ ## API Configuration ## - # A list of event types that will be included in the room_invite_state + # Controls for the state that is shared with users who receive an invite + # to a room # - #room_invite_state_types: - # - "{JoinRules}" - # - "{CanonicalAlias}" - # - "{RoomAvatar}" - # - "{RoomEncryption}" - # - "{Name}" - """.format( - **vars(EventTypes) - ) + room_prejoin_state: + # By default, the following state event types are shared with users who + # receive invites to the room: + # +%(formatted_default_state_types)s + # + # Uncomment the following to disable these defaults (so that only the event + # types listed in 'additional_event_types' are shared). Defaults to 'false'. + # + #disable_default_event_types: true + + # Additional state event types to share with users when they are invited + # to a room. + # + # By default, this list is empty (so only the default event types are shared). + # + #additional_event_types: + # - org.example.custom.event.type + """ % { + "formatted_default_state_types": formatted_default_state_types + } + + def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: + """Get the event types to include in the prejoin state + + Parses the config and returns an iterable of the event types to be included. + """ + room_prejoin_state_config = config.get("room_prejoin_state") or {} + + # backwards-compatibility support for room_invite_state_types + if "room_invite_state_types" in config: + # if both "room_invite_state_types" and "room_prejoin_state" are set, then + # we don't really know what to do. + if room_prejoin_state_config: + raise ConfigError( + "Can't specify both 'room_invite_state_types' and 'room_prejoin_state' " + "in config" + ) + + logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING) + + yield from config["room_invite_state_types"] + return + + if not room_prejoin_state_config.get("disable_default_event_types"): + yield from _DEFAULT_PREJOIN_STATE_TYPES + + if self.spaces_enabled: + # MSC1772 suggests adding m.room.create to the prejoin state + yield EventTypes.Create + + yield from room_prejoin_state_config.get("additional_event_types", []) + + +_ROOM_INVITE_STATE_TYPES_WARNING = """\ +WARNING: The 'room_invite_state_types' configuration setting is now deprecated, +and replaced with 'room_prejoin_state'. New features may not work correctly +unless 'room_invite_state_types' is removed. See the sample configuration file for +details of 'room_prejoin_state'. +-------------------------------------------------------------------------------- +""" + +_DEFAULT_PREJOIN_STATE_TYPES = [ + EventTypes.JoinRules, + EventTypes.CanonicalAlias, + EventTypes.RoomAvatar, + EventTypes.RoomEncryption, + EventTypes.Name, +] + + +# room_prejoin_state can either be None (as it is in the default config), or +# an object containing other config settings +_ROOM_PREJOIN_STATE_CONFIG_SCHEMA = { + "oneOf": [ + { + "type": "object", + "properties": { + "disable_default_event_types": {"type": "boolean"}, + "additional_event_types": { + "type": "array", + "items": {"type": "string"}, + }, + }, + }, + {"type": "null"}, + ] +} + +# the legacy room_invite_state_types setting +_ROOM_INVITE_STATE_TYPES_SCHEMA = {"type": "array", "items": {"type": "string"}} + +_MAIN_SCHEMA = { + "type": "object", + "properties": { + "room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA, + "room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA, + }, +} diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 8e03f14005..4e8abbf88a 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py
@@ -24,7 +24,7 @@ from ._base import Config, ConfigError _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" # Map from canonicalised cache name to cache. -_CACHES = {} +_CACHES = {} # type: Dict[str, Callable[[float], None]] # a lock on the contents of _CACHES _CACHES_LOCK = threading.Lock() @@ -59,7 +59,9 @@ def _canonicalise_cache_name(cache_name: str) -> str: return cache_name.lower() -def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): +def add_resizable_cache( + cache_name: str, cache_resize_callback: Callable[[float], None] +): """Register a cache that's size can dynamically change Args: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b1c1c51e4d..eb96ecda74 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py
@@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config._base import Config from synapse.types import JsonDict @@ -27,3 +28,11 @@ class ExperimentalConfig(Config): # MSC2858 (multiple SSO identity providers) self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool + + # Spaces (MSC1772, MSC2946, MSC3083, etc) + self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool + if self.spaces_enabled: + KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083 + + # MSC3026 (busy presence state) + self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool diff --git a/synapse/config/key.py b/synapse/config/key.py
index de964dff13..350ff1d665 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py
@@ -404,7 +404,11 @@ def _parse_key_servers(key_servers, federation_verify_certificates): try: jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA) except jsonschema.ValidationError as e: - raise ConfigError("Unable to parse 'trusted_key_servers': " + e.message) + raise ConfigError( + "Unable to parse 'trusted_key_servers': {}".format( + e.message # noqa: B306, jsonschema.ValidationError.message is a valid attribute + ) + ) for server in key_servers: server_name = server["server_name"] diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index dfd27e1523..2b289f4208 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py
@@ -56,7 +56,9 @@ class MetricsConfig(Config): try: check_requirements("sentry") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) self.sentry_dsn = config["sentry"].get("dsn") if not self.sentry_dsn: diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 2bfb537c15..05733ec41d 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py
@@ -15,11 +15,12 @@ # limitations under the License. from collections import Counter -from typing import Iterable, Mapping, Optional, Tuple, Type +from typing import Iterable, List, Mapping, Optional, Tuple, Type import attr from synapse.config._util import validate_config +from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module @@ -41,7 +42,9 @@ class OIDCConfig(Config): try: check_requirements("oidc") except DependencyException as e: - raise ConfigError(e.message) from e + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) from e # check we don't have any duplicate idp_ids now. (The SSO handler will also # check for duplicates when the REST listeners get registered, but that happens @@ -76,6 +79,9 @@ class OIDCConfig(Config): # Note that, if this is changed, users authenticating via that provider # will no longer be recognised as the same user! # + # (Use "oidc" here if you are migrating from an old "oidc_config" + # configuration.) + # # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # @@ -191,6 +197,24 @@ class OIDCConfig(Config): # which is set to the claims returned by the UserInfo Endpoint and/or # in the ID Token. # + # It is possible to configure Synapse to only allow logins if certain attributes + # match particular values in the OIDC userinfo. The requirements can be listed under + # `attribute_requirements` as shown below. All of the listed attributes must + # match for the login to be permitted. Additional attributes can be added to + # userinfo by expanding the `scopes` section of the OIDC config to retrieve + # additional information from the OIDC provider. + # + # If the OIDC claim is a list, then the attribute must match any value in the list. + # Otherwise, it must exactly match the value of the claim. Using the example + # below, the `family_name` claim MUST be "Stephensson", but the `groups` + # claim MUST contain "admin". + # + # attribute_requirements: + # - attribute: family_name + # value: "Stephensson" + # - attribute: groups + # value: "admin" + # # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md # for information on how to configure these options. # @@ -223,34 +247,9 @@ class OIDCConfig(Config): # localpart_template: "{{{{ user.login }}}}" # display_name_template: "{{{{ user.name }}}}" # email_template: "{{{{ user.email }}}}" - - # For use with Keycloak - # - #- idp_id: keycloak - # idp_name: Keycloak - # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" - # client_id: "synapse" - # client_secret: "copy secret generated in Keycloak UI" - # scopes: ["openid", "profile"] - - # For use with Github - # - #- idp_id: github - # idp_name: Github - # idp_brand: github - # discover: false - # issuer: "https://github.com/" - # client_id: "your-client-id" # TO BE FILLED - # client_secret: "your-client-secret" # TO BE FILLED - # authorization_endpoint: "https://github.com/login/oauth/authorize" - # token_endpoint: "https://github.com/login/oauth/access_token" - # userinfo_endpoint: "https://api.github.com/user" - # scopes: ["read:user"] - # user_mapping_provider: - # config: - # subject_claim: "id" - # localpart_template: "{{{{ user.login }}}}" - # display_name_template: "{{{{ user.name }}}}" + # attribute_requirements: + # - attribute: userGroup + # value: "synapseUsers" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) @@ -329,6 +328,10 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { }, "allow_existing_users": {"type": "boolean"}, "user_mapping_provider": {"type": ["object", "null"]}, + "attribute_requirements": { + "type": "array", + "items": SsoAttributeRequirement.JSON_SCHEMA, + }, }, } @@ -465,6 +468,11 @@ def _parse_oidc_config_dict( jwt_header=client_secret_jwt_key_config["jwt_header"], jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}), ) + # parse attribute_requirements from config (list of dicts) into a list of SsoAttributeRequirement + attribute_requirements = [ + SsoAttributeRequirement(**x) + for x in oidc_config.get("attribute_requirements", []) + ] return OidcProviderConfig( idp_id=idp_id, @@ -488,6 +496,7 @@ def _parse_oidc_config_dict( allow_existing_users=oidc_config.get("allow_existing_users", False), user_mapping_provider_class=user_mapping_provider_class, user_mapping_provider_config=user_mapping_provider_config, + attribute_requirements=attribute_requirements, ) @@ -577,3 +586,6 @@ class OidcProviderConfig: # the config of the user mapping provider user_mapping_provider_config = attr.ib() + + # required attributes to require in userinfo to allow login/registration + attribute_requirements = attr.ib(type=List[SsoAttributeRequirement]) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 847d25122c..3f3997f4e5 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py
@@ -95,11 +95,11 @@ class RatelimitConfig(Config): self.rc_joins_local = RateLimitConfig( config.get("rc_joins", {}).get("local", {}), - defaults={"per_second": 0.1, "burst_count": 3}, + defaults={"per_second": 0.1, "burst_count": 10}, ) self.rc_joins_remote = RateLimitConfig( config.get("rc_joins", {}).get("remote", {}), - defaults={"per_second": 0.01, "burst_count": 3}, + defaults={"per_second": 0.01, "burst_count": 10}, ) # Ratelimit cross-user key requests: @@ -187,10 +187,10 @@ class RatelimitConfig(Config): #rc_joins: # local: # per_second: 0.1 - # burst_count: 3 + # burst_count: 10 # remote: # per_second: 0.01 - # burst_count: 3 + # burst_count: 10 # #rc_3pid_validation: # per_second: 0.003 diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ead007ba5a..f27d1e14ac 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -298,9 +298,9 @@ class RegistrationConfig(Config): # #allowed_local_3pids: # - medium: email - # pattern: '.*@matrix\\.org' + # pattern: '^[^@]+@matrix\\.org$' # - medium: email - # pattern: '.*@vector\\.im' + # pattern: '^[^@]+@vector\\.im$' # - medium: msisdn # pattern: '\\+44' diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 69d9de5a43..061c4ec83f 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py
@@ -176,7 +176,9 @@ class ContentRepositoryConfig(Config): check_requirements("url_preview") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) if "url_preview_ip_range_blacklist" not in config: raise ConfigError( diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 4b494f217f..6db9cb5ced 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py
@@ -76,7 +76,9 @@ class SAML2Config(Config): try: check_requirements("saml2") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) self.saml2_enabled = True diff --git a/synapse/config/server.py b/synapse/config/server.py
index 5f8910b6e1..8decc9d10d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -27,6 +27,7 @@ import yaml from netaddr import AddrFormatError, IPNetwork, IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError @@ -238,7 +239,20 @@ class ServerConfig(Config): self.public_baseurl = config.get("public_baseurl") # Whether to enable user presence. - self.use_presence = config.get("use_presence", True) + presence_config = config.get("presence") or {} + self.use_presence = presence_config.get("enabled") + if self.use_presence is None: + self.use_presence = config.get("use_presence", True) + + # Custom presence router module + self.presence_router_module_class = None + self.presence_router_config = None + presence_router_config = presence_config.get("presence_router") + if presence_router_config: + ( + self.presence_router_module_class, + self.presence_router_config, + ) = load_module(presence_router_config, ("presence", "presence_router")) # Whether to update the user directory or not. This should be set to # false only if we are updating the user directory in a worker @@ -834,9 +848,28 @@ class ServerConfig(Config): # #soft_file_limit: 0 - # Set to false to disable presence tracking on this homeserver. + # Presence tracking allows users to see the state (e.g online/offline) + # of other local and remote users. # - #use_presence: false + presence: + # Uncomment to disable presence tracking on this homeserver. This option + # replaces the previous top-level 'use_presence' option. + # + #enabled: false + + # Presence routers are third-party modules that can specify additional logic + # to where presence updates from users are routed. + # + presence_router: + # The custom module's class. Uncomment to use a custom presence router module. + # + #module: "my_custom_router.PresenceRouter" + + # Configuration options of the custom module. Refer to your module's + # documentation for available options. + # + #config: + # example_option: 'something' # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 0c1a854f09..727a1e7008 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py
@@ -39,7 +39,9 @@ class TracerConfig(Config): try: check_requirements("opentracing") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) # The tracer is enabled so sanitize the config diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 14b21796d9..c644b4dfc5 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py
@@ -191,7 +191,7 @@ def _context_info_cb(ssl_connection, where, ret): # ... we further assume that SSLClientConnectionCreator has set the # '_synapse_tls_verifier' attribute to a ConnectionVerifier object. tls_protocol._synapse_tls_verifier.verify_context_info_cb(ssl_connection, where) - except: # noqa: E722, taken from the twisted implementation + except BaseException: # taken from the twisted implementation logger.exception("Error during info_callback") f = Failure() tls_protocol.failVerification(f) @@ -219,7 +219,7 @@ class SSLClientConnectionCreator: # ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the # tls_protocol so that the SSL context's info callback has something to # call to do the cert verification. - setattr(tls_protocol, "_synapse_tls_verifier", self._verifier) + tls_protocol._synapse_tls_verifier = self._verifier return connection diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 902128a23c..d5fb51513b 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py
@@ -57,7 +57,7 @@ from synapse.util.metrics import Measure from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 91ad5b3d3c..9863953f5c 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py
@@ -162,7 +162,7 @@ def check( logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) if event.type == EventTypes.Member: - _is_membership_change_allowed(event, auth_events) + _is_membership_change_allowed(room_version_obj, event, auth_events) logger.debug("Allowing! %s", event) return @@ -220,8 +220,19 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool: def _is_membership_change_allowed( - event: EventBase, auth_events: StateMap[EventBase] + room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase] ) -> None: + """ + Confirms that the event which changes membership is an allowed change. + + Args: + room_version: The version of the room. + event: The event to check. + auth_events: The current auth events of the room. + + Raises: + AuthError if the event is not allowed. + """ membership = event.content["membership"] # Check if this is the room creator joining: @@ -315,14 +326,19 @@ def _is_membership_change_allowed( if user_level < invite_level: raise AuthError(403, "You don't have permission to invite users") elif Membership.JOIN == membership: - # Joins are valid iff caller == target and they were: - # invited: They are accepting the invitation - # joined: It's a NOOP + # Joins are valid iff caller == target and: + # * They are not banned. + # * They are accepting a previously sent invitation. + # * They are already joined (it's a NOOP). + # * The room is public or restricted. if event.user_id != target_user_id: raise AuthError(403, "Cannot force another user to join.") elif target_banned: raise AuthError(403, "You are banned from this room") - elif join_rule == JoinRules.PUBLIC: + elif join_rule == JoinRules.PUBLIC or ( + room_version.msc3083_join_rules + and join_rule == JoinRules.MSC3083_RESTRICTED + ): pass elif join_rule == JoinRules.INVITE: if not caller_in_room and not caller_invited: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 3ec4120f85..8f6b955d17 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py
@@ -98,7 +98,7 @@ class DefaultDictProperty(DictProperty): class _EventInternalMetadata: - __slots__ = ["_dict", "stream_ordering"] + __slots__ = ["_dict", "stream_ordering", "outlier"] def __init__(self, internal_metadata_dict: JsonDict): # we have to copy the dict, because it turns out that the same dict is @@ -108,7 +108,10 @@ class _EventInternalMetadata: # the stream ordering of this event. None, until it has been persisted. self.stream_ordering = None # type: Optional[int] - outlier = DictProperty("outlier") # type: bool + # whether this event is an outlier (ie, whether we have the state at that point + # in the DAG) + self.outlier = False + out_of_band_membership = DictProperty("out_of_band_membership") # type: bool send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str recheck_redaction = DictProperty("recheck_redaction") # type: bool @@ -129,7 +132,7 @@ class _EventInternalMetadata: return dict(self._dict) def is_outlier(self) -> bool: - return self._dict.get("outlier", False) + return self.outlier def is_out_of_band_membership(self) -> bool: """Whether this is an out of band membership, like an invite or an invite diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py new file mode 100644
index 0000000000..24cd389d80 --- /dev/null +++ b/synapse/events/presence_router.py
@@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Iterable, Set, Union + +from synapse.api.presence import UserPresenceState + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class PresenceRouter: + """ + A module that the homeserver will call upon to help route user presence updates to + additional destinations. If a custom presence router is configured, calls will be + passed to that instead. + """ + + ALL_USERS = "ALL" + + def __init__(self, hs: "HomeServer"): + self.custom_presence_router = None + + # Check whether a custom presence router module has been configured + if hs.config.presence_router_module_class: + # Initialise the module + self.custom_presence_router = hs.config.presence_router_module_class( + config=hs.config.presence_router_config, module_api=hs.get_module_api() + ) + + # Ensure the module has implemented the required methods + required_methods = ["get_users_for_states", "get_interested_users"] + for method_name in required_methods: + if not hasattr(self.custom_presence_router, method_name): + raise Exception( + "PresenceRouter module '%s' must implement all required methods: %s" + % ( + hs.config.presence_router_module_class.__name__, + ", ".join(required_methods), + ) + ) + + async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], + ) -> Dict[str, Set[UserPresenceState]]: + """ + Given an iterable of user presence updates, determine where each one + needs to go. + + Args: + state_updates: An iterable of user presence state updates. + + Returns: + A dictionary of user_id -> set of UserPresenceState, indicating which + presence updates each user should receive. + """ + if self.custom_presence_router is not None: + # Ask the custom module + return await self.custom_presence_router.get_users_for_states( + state_updates=state_updates + ) + + # Don't include any extra destinations for presence updates + return {} + + async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]: + """ + Retrieve a list of users that `user_id` is interested in receiving the + presence of. This will be in addition to those they share a room with. + Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate + that this user should receive all incoming local and remote presence updates. + + Note that this method will only be called for local users, but can return users + that are local or remote. + + Args: + user_id: A user requesting presence updates. + + Returns: + A set of user IDs to return presence updates for, or ALL_USERS to return all + known updates. + """ + if self.custom_presence_router is not None: + # Ask the custom module for interested users + return await self.custom_presence_router.get_interested_users( + user_id=user_id + ) + + # A custom presence router is not defined. + # Don't report any additional interested users + return set() diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 02bce8b5c9..9767d23940 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py
@@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from typing import TYPE_CHECKING, Union from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Requester, StateMap +if TYPE_CHECKING: + from synapse.server import HomeServer + class ThirdPartyEventRules: """Allows server admins to provide a Python module implementing an extra @@ -28,7 +31,7 @@ class ThirdPartyEventRules: behaviours. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.third_party_rules = None self.store = hs.get_datastore() @@ -95,10 +98,9 @@ class ThirdPartyEventRules: if self.third_party_rules is None: return True - ret = await self.third_party_rules.on_create_room( + return await self.third_party_rules.on_create_room( requester, config, is_requester_admin ) - return ret async def check_threepid_can_be_invited( self, medium: str, address: str, room_id: str @@ -119,10 +121,9 @@ class ThirdPartyEventRules: state_events = await self._get_state_map_for_room(room_id) - ret = await self.third_party_rules.check_threepid_can_be_invited( + return await self.third_party_rules.check_threepid_can_be_invited( medium, address, state_events ) - return ret async def check_visibility_can_be_modified( self, room_id: str, new_visibility: str @@ -143,7 +144,7 @@ class ThirdPartyEventRules: check_func = getattr( self.third_party_rules, "check_visibility_can_be_modified", None ) - if not check_func or not isinstance(check_func, Callable): + if not check_func or not callable(check_func): return True state_events = await self._get_state_map_for_room(room_id) diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 7ca5c9940a..0f8a3b5ad8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py
@@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results +from synapse.util.frozenutils import unfreeze from . import EventBase @@ -54,6 +55,8 @@ def prune_event(event: EventBase) -> EventBase: event.internal_metadata.stream_ordering ) + pruned_event.internal_metadata.outlier = event.internal_metadata.outlier + # Mark the event as redacted pruned_event.internal_metadata.redacted = True @@ -400,10 +403,19 @@ class EventClientSerializer: # If there is an edit replace the content, preserving existing # relations. + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations relations = event.content.get("m.relates_to") - serialized_event["content"] = edit.content.get("m.new_content", {}) if relations: - serialized_event["content"]["m.relates_to"] = relations + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relations.copy() else: serialized_event["content"].pop("m.relates_to", None) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index bee81fc019..55533d7501 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -27,11 +27,13 @@ from typing import ( List, Mapping, Optional, + Sequence, Tuple, TypeVar, Union, ) +import attr from prometheus_client import Counter from twisted.internet import defer @@ -62,7 +64,7 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -100,7 +102,7 @@ class FederationClient(FederationBase): max_len=1000, expiry_ms=120 * 1000, reset_expiry_on_get=False, - ) + ) # type: ExpiringCache[str, EventBase] def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" @@ -455,6 +457,7 @@ class FederationClient(FederationBase): description: str, destinations: Iterable[str], callback: Callable[[str], Awaitable[T]], + failover_on_unknown_endpoint: bool = False, ) -> T: """Try an operation on a series of servers, until it succeeds @@ -474,6 +477,10 @@ class FederationClient(FederationBase): next server tried. Normally the stacktrace is logged but this is suppressed if the exception is an InvalidResponseError. + failover_on_unknown_endpoint: if True, we will try other servers if it looks + like a server doesn't support the endpoint. This is typically useful + if the endpoint in question is new or experimental. + Returns: The result of callback, if it succeeds @@ -493,16 +500,31 @@ class FederationClient(FederationBase): except UnsupportedRoomVersionError: raise except HttpResponseException as e: - if not 500 <= e.code < 600: - raise e.to_synapse_error() - else: - logger.warning( - "Failed to %s via %s: %i %s", - description, - destination, - e.code, - e.args[0], - ) + synapse_error = e.to_synapse_error() + failover = False + + if 500 <= e.code < 600: + failover = True + + elif failover_on_unknown_endpoint: + # there is no good way to detect an "unknown" endpoint. Dendrite + # returns a 404 (with no body); synapse returns a 400 + # with M_UNRECOGNISED. + if e.code == 404 or ( + e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED + ): + failover = True + + if not failover: + raise synapse_error from e + + logger.warning( + "Failed to %s via %s: %i %s", + description, + destination, + e.code, + e.args[0], + ) except Exception: logger.warning( "Failed to %s via %s", description, destination, exc_info=True @@ -1042,3 +1064,141 @@ class FederationClient(FederationBase): # If we don't manage to find it, return None. It's not an error if a # server doesn't give it to us. return None + + async def get_space_summary( + self, + destinations: Iterable[str], + room_id: str, + suggested_only: bool, + max_rooms_per_space: Optional[int], + exclude_rooms: List[str], + ) -> "FederationSpaceSummaryResult": + """ + Call other servers to get a summary of the given space + + + Args: + destinations: The remote servers. We will try them in turn, omitting any + that have been blacklisted. + + room_id: ID of the space to be queried + + suggested_only: If true, ask the remote server to only return children + with the "suggested" flag set + + max_rooms_per_space: A limit on the number of children to return for each + space + + exclude_rooms: A list of room IDs to tell the remote server to skip + + Returns: + a parsed FederationSpaceSummaryResult + + Raises: + SynapseError if we were unable to get a valid summary from any of the + remote servers + """ + + async def send_request(destination: str) -> FederationSpaceSummaryResult: + res = await self.transport_layer.get_space_summary( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + max_rooms_per_space=max_rooms_per_space, + exclude_rooms=exclude_rooms, + ) + + try: + return FederationSpaceSummaryResult.from_json_dict(res) + except ValueError as e: + raise InvalidResponseError(str(e)) + + return await self._try_destination_list( + "fetch space summary", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) + + +@attr.s(frozen=True, slots=True) +class FederationSpaceSummaryEventResult: + """Represents a single event in the result of a successful get_space_summary call. + + It's essentially just a serialised event object, but we do a bit of parsing and + validation in `from_json_dict` and store some of the validated properties in + object attributes. + """ + + event_type = attr.ib(type=str) + state_key = attr.ib(type=str) + via = attr.ib(type=Sequence[str]) + + # the raw data, including the above keys + data = attr.ib(type=JsonDict) + + @classmethod + def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": + """Parse an event within the result of a /spaces/ request + + Args: + d: json object to be parsed + + Raises: + ValueError if d is not a valid event + """ + + event_type = d.get("type") + if not isinstance(event_type, str): + raise ValueError("Invalid event: 'event_type' must be a str") + + state_key = d.get("state_key") + if not isinstance(state_key, str): + raise ValueError("Invalid event: 'state_key' must be a str") + + content = d.get("content") + if not isinstance(content, dict): + raise ValueError("Invalid event: 'content' must be a dict") + + via = content.get("via") + if not isinstance(via, Sequence): + raise ValueError("Invalid event: 'via' must be a list") + if any(not isinstance(v, str) for v in via): + raise ValueError("Invalid event: 'via' must be a list of strings") + + return cls(event_type, state_key, via, d) + + +@attr.s(frozen=True, slots=True) +class FederationSpaceSummaryResult: + """Represents the data returned by a successful get_space_summary call.""" + + rooms = attr.ib(type=Sequence[JsonDict]) + events = attr.ib(type=Sequence[FederationSpaceSummaryEventResult]) + + @classmethod + def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": + """Parse the result of a /spaces/ request + + Args: + d: json object to be parsed + + Raises: + ValueError if d is not a valid /spaces/ response + """ + rooms = d.get("rooms") + if not isinstance(rooms, Sequence): + raise ValueError("'rooms' must be a list") + if any(not isinstance(r, dict) for r in rooms): + raise ValueError("Invalid room in 'rooms' list") + + events = d.get("events") + if not isinstance(events, Sequence): + raise ValueError("'events' must be a list") + if any(not isinstance(e, dict) for e in events): + raise ValueError("Invalid event in 'events' list") + parsed_events = [ + FederationSpaceSummaryEventResult.from_json_dict(e) for e in events + ] + + return cls(rooms, parsed_events) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9839d3d016..b9f8d966a6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -35,7 +35,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import ( AuthError, Codes, @@ -63,7 +63,7 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -727,27 +727,6 @@ class FederationServer(FederationBase): if the event was unacceptable for any other reason (eg, too large, too many prev_events, couldn't find the prev_events) """ - # check that it's actually being sent from a valid destination to - # workaround bug #1753 in 0.18.5 and 0.18.6 - if origin != get_domain_from_id(pdu.sender): - # We continue to accept join events from any server; this is - # necessary for the federation join dance to work correctly. - # (When we join over federation, the "helper" server is - # responsible for sending out the join event, rather than the - # origin. See bug #1893. This is also true for some third party - # invites). - if not ( - pdu.type == "m.room.member" - and pdu.content - and pdu.content.get("membership", None) - in (Membership.JOIN, Membership.INVITE) - ): - logger.info( - "Discarding PDU %s from invalid origin %s", pdu.event_id, origin - ) - return - else: - logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point room_version = await self.store.get_room_version(pdu.room_id) @@ -760,22 +739,20 @@ class FederationServer(FederationBase): await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) - def __str__(self): + def __str__(self) -> str: return "<ReplicationLayer(%s)>" % self.server_name async def exchange_third_party_invite( self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict - ): - ret = await self.handler.exchange_third_party_invite( + ) -> None: + await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) - return ret - async def on_exchange_third_party_invite_request(self, event_dict: Dict): - ret = await self.handler.on_exchange_third_party_invite_request(event_dict) - return ret + async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None: + await self.handler.on_exchange_third_party_invite_request(event_dict) - async def check_server_matches_acl(self, server_name: str, room_id: str): + async def check_server_matches_acl(self, server_name: str, room_id: str) -> None: """Check if the given server is allowed by the server ACLs in the room Args: @@ -891,6 +868,7 @@ class FederationHandlerRegistry: # A rate limiter for incoming room key requests per origin. self._room_key_request_rate_limiter = Ratelimiter( + store=hs.get_datastore(), clock=self.clock, rate_hz=self.config.rc_key_requests.per_second, burst_count=self.config.rc_key_requests.burst_count, @@ -898,7 +876,7 @@ class FederationHandlerRegistry: def register_edu_handler( self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] - ): + ) -> None: """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. @@ -917,7 +895,7 @@ class FederationHandlerRegistry: def register_query_handler( self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] - ): + ) -> None: """Sets the handler callable that will be used to handle an incoming federation query of the given type. @@ -935,15 +913,17 @@ class FederationHandlerRegistry: self.query_handlers[query_type] = handler - def register_instance_for_edu(self, edu_type: str, instance_name: str): + def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None: """Register that the EDU handler is on a different instance than master.""" self._edu_type_to_instance[edu_type] = [instance_name] - def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): + def register_instances_for_edu( + self, edu_type: str, instance_names: List[str] + ) -> None: """Register that the EDU handler is on multiple instances.""" self._edu_type_to_instance[edu_type] = instance_names - async def on_edu(self, edu_type: str, origin: str, content: dict): + async def on_edu(self, edu_type: str, origin: str, content: dict) -> None: if not self.config.use_presence and edu_type == EduTypes.Presence: return @@ -951,7 +931,9 @@ class FederationHandlerRegistry: # the limit, drop them. if ( edu_type == EduTypes.RoomKeyRequest - and not self._room_key_request_rate_limiter.can_do_action(origin) + and not await self._room_key_request_rate_limiter.can_do_action( + None, origin + ) ): return diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 3e993b428b..0c18c49abb 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py
@@ -31,25 +31,39 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple -from typing import Dict, List, Tuple, Type +from typing import ( + TYPE_CHECKING, + Dict, + Hashable, + Iterable, + List, + Optional, + Sized, + Tuple, + Type, +) from sortedcontainers import SortedDict -from twisted.internet import defer - from synapse.api.presence import UserPresenceState +from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.metrics import LaterGauge +from synapse.replication.tcp.streams.federation import FederationStream +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken from synapse.util.metrics import Measure from .units import Edu +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) -class FederationRemoteSendQueue: +class FederationRemoteSendQueue(AbstractFederationSender): """A drop in replacement for FederationSender""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -58,7 +72,7 @@ class FederationRemoteSendQueue: # We may have multiple federation sender instances, so we need to track # their positions separately. self._sender_instances = hs.config.worker.federation_shard_config.instances - self._sender_positions = {} + self._sender_positions = {} # type: Dict[str, int] # Pending presence map user_id -> UserPresenceState self.presence_map = {} # type: Dict[str, UserPresenceState] @@ -71,7 +85,7 @@ class FederationRemoteSendQueue: # Stream position -> (user_id, destinations) self.presence_destinations = ( SortedDict() - ) # type: SortedDict[int, Tuple[str, List[str]]] + ) # type: SortedDict[int, Tuple[str, Iterable[str]]] # (destination, key) -> EDU self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] @@ -94,7 +108,7 @@ class FederationRemoteSendQueue: # we make a new function, so we need to make a new function so the inner # lambda binds to the queue rather than to the name of the queue which # changes. ARGH. - def register(name, queue): + def register(name: str, queue: Sized) -> None: LaterGauge( "synapse_federation_send_queue_%s_size" % (queue_name,), "", @@ -115,13 +129,13 @@ class FederationRemoteSendQueue: self.clock.looping_call(self._clear_queue, 30 * 1000) - def _next_pos(self): + def _next_pos(self) -> int: pos = self.pos self.pos += 1 self.pos_time[self.clock.time_msec()] = pos return pos - def _clear_queue(self): + def _clear_queue(self) -> None: """Clear the queues for anything older than N minutes""" FIVE_MINUTES_AGO = 5 * 60 * 1000 @@ -138,7 +152,7 @@ class FederationRemoteSendQueue: self._clear_queue_before_pos(position_to_delete) - def _clear_queue_before_pos(self, position_to_delete): + def _clear_queue_before_pos(self, position_to_delete: int) -> None: """Clear all the queues from before a given position""" with Measure(self.clock, "send_queue._clear"): # Delete things out of presence maps @@ -188,13 +202,18 @@ class FederationRemoteSendQueue: for key in keys[:i]: del self.edus[key] - def notify_new_events(self, max_token): + def notify_new_events(self, max_token: RoomStreamToken) -> None: """As per FederationSender""" - # We don't need to replicate this as it gets sent down a different - # stream. - pass + # This should never get called. + raise NotImplementedError() - def build_and_send_edu(self, destination, edu_type, content, key=None): + def build_and_send_edu( + self, + destination: str, + edu_type: str, + content: JsonDict, + key: Optional[Hashable] = None, + ) -> None: """As per FederationSender""" if destination == self.server_name: logger.info("Not sending EDU to ourselves") @@ -218,38 +237,39 @@ class FederationRemoteSendQueue: self.notifier.on_new_replication_data() - def send_read_receipt(self, receipt): + async def send_read_receipt(self, receipt: ReadReceipt) -> None: """As per FederationSender Args: - receipt (synapse.types.ReadReceipt): + receipt: """ # nothing to do here: the replication listener will handle it. - return defer.succeed(None) - def send_presence(self, states): + def send_presence(self, states: List[UserPresenceState]) -> None: """As per FederationSender Args: - states (list(UserPresenceState)) + states """ pos = self._next_pos() # We only want to send presence for our own users, so lets always just # filter here just in case. - local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states)) + local_states = [s for s in states if self.is_mine_id(s.user_id)] self.presence_map.update({state.user_id: state for state in local_states}) self.presence_changed[pos] = [state.user_id for state in local_states] self.notifier.on_new_replication_data() - def send_presence_to_destinations(self, states, destinations): + def send_presence_to_destinations( + self, states: Iterable[UserPresenceState], destinations: Iterable[str] + ) -> None: """As per FederationSender Args: - states (list[UserPresenceState]) - destinations (list[str]) + states + destinations """ for state in states: pos = self._next_pos() @@ -258,15 +278,18 @@ class FederationRemoteSendQueue: self.notifier.on_new_replication_data() - def send_device_messages(self, destination): + def send_device_messages(self, destination: str) -> None: """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. - def get_current_token(self): + def wake_destination(self, server: str) -> None: + pass + + def get_current_token(self) -> int: return self.pos - 1 - def federation_ack(self, instance_name, token): + def federation_ack(self, instance_name: str, token: int) -> None: if self._sender_instances: # If we have configured multiple federation sender instances we need # to track their positions separately, and only clear the queue up @@ -504,13 +527,16 @@ ParsedFederationStreamData = namedtuple( ) -def process_rows_for_federation(transaction_queue, rows): +def process_rows_for_federation( + transaction_queue: FederationSender, + rows: List[FederationStream.FederationStreamRow], +) -> None: """Parse a list of rows from the federation stream and put them in the transaction queue ready for sending to the relevant homeservers. Args: - transaction_queue (FederationSender) - rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow)) + transaction_queue + rows """ # The federation stream contains a bunch of different types of diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 24ebc4b803..d821dcbf6a 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py
@@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging -from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter from twisted.internet import defer -import synapse import synapse.metrics from synapse.api.presence import UserPresenceState from synapse.events import EventBase @@ -40,9 +40,13 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import ReadReceipt, RoomStreamToken +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken from synapse.util.metrics import Measure, measure_func +if TYPE_CHECKING: + from synapse.events.presence_router import PresenceRouter + from synapse.server import HomeServer + logger = logging.getLogger(__name__) sent_pdus_destination_dist_count = Counter( @@ -65,8 +69,91 @@ CATCH_UP_STARTUP_DELAY_SEC = 15 CATCH_UP_STARTUP_INTERVAL_SEC = 5 -class FederationSender: - def __init__(self, hs: "synapse.server.HomeServer"): +class AbstractFederationSender(metaclass=abc.ABCMeta): + @abc.abstractmethod + def notify_new_events(self, max_token: RoomStreamToken) -> None: + """This gets called when we have some new events we might want to + send out to other servers. + """ + raise NotImplementedError() + + @abc.abstractmethod + async def send_read_receipt(self, receipt: ReadReceipt) -> None: + """Send a RR to any other servers in the room + + Args: + receipt: receipt to be sent + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_presence(self, states: List[UserPresenceState]) -> None: + """Send the new presence states to the appropriate destinations. + + This actually queues up the presence states ready for sending and + triggers a background task to process them and send out the transactions. + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_presence_to_destinations( + self, states: Iterable[UserPresenceState], destinations: Iterable[str] + ) -> None: + """Send the given presence states to the given destinations. + + Args: + destinations: + """ + raise NotImplementedError() + + @abc.abstractmethod + def build_and_send_edu( + self, + destination: str, + edu_type: str, + content: JsonDict, + key: Optional[Hashable] = None, + ) -> None: + """Construct an Edu object, and queue it for sending + + Args: + destination: name of server to send to + edu_type: type of EDU to send + content: content of EDU + key: clobbering key for this edu + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_device_messages(self, destination: str) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def wake_destination(self, destination: str) -> None: + """Called when we want to retry sending transactions to a remote. + + This is mainly useful if the remote server has been down and we think it + might have come back. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_current_token(self) -> int: + raise NotImplementedError() + + @abc.abstractmethod + def federation_ack(self, instance_name: str, token: int) -> None: + raise NotImplementedError() + + @abc.abstractmethod + async def get_replication_rows( + self, instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: + raise NotImplementedError() + + +class FederationSender(AbstractFederationSender): + def __init__(self, hs: "HomeServer"): self.hs = hs self.server_name = hs.hostname @@ -76,6 +163,7 @@ class FederationSender: self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id + self._presence_router = None # type: Optional[PresenceRouter] self._transaction_manager = TransactionManager(hs) self._instance_name = hs.get_instance_name() @@ -432,7 +520,7 @@ class FederationSender: queue.flush_read_receipts_for_room(room_id) @preserve_fn # the caller should not yield on this - async def send_presence(self, states: List[UserPresenceState]): + async def send_presence(self, states: List[UserPresenceState]) -> None: """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and @@ -494,11 +582,26 @@ class FederationSender: self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") - async def _process_presence_inner(self, states: List[UserPresenceState]): + async def _process_presence_inner(self, states: List[UserPresenceState]) -> None: """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ - hosts_and_states = await get_interested_remotes(self.store, states, self.state) + # We pull the presence router here instead of __init__ + # to prevent a dependency cycle: + # + # AuthHandler -> Notifier -> FederationSender + # -> PresenceRouter -> ModuleApi -> AuthHandler + if self._presence_router is None: + self._presence_router = self.hs.get_presence_router() + + assert self._presence_router is not None + + hosts_and_states = await get_interested_remotes( + self.store, + self._presence_router, + states, + self.state, + ) for destinations, states in hosts_and_states: for destination in destinations: @@ -516,9 +619,9 @@ class FederationSender: self, destination: str, edu_type: str, - content: dict, + content: JsonDict, key: Optional[Hashable] = None, - ): + ) -> None: """Construct an Edu object, and queue it for sending Args: @@ -545,7 +648,7 @@ class FederationSender: self.send_edu(edu, key) - def send_edu(self, edu: Edu, key: Optional[Hashable]): + def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None: """Queue an EDU for sending Args: @@ -563,7 +666,7 @@ class FederationSender: else: queue.send_edu(edu) - def send_device_messages(self, destination: str): + def send_device_messages(self, destination: str) -> None: if destination == self.server_name: logger.warning("Not sending device update to ourselves") return @@ -575,7 +678,7 @@ class FederationSender: self._get_per_destination_queue(destination).attempt_new_transaction() - def wake_destination(self, destination: str): + def wake_destination(self, destination: str) -> None: """Called when we want to retry sending transactions to a remote. This is mainly useful if the remote server has been down and we think it @@ -599,6 +702,10 @@ class FederationSender: # to a worker. return 0 + def federation_ack(self, instance_name: str, token: int) -> None: + # It is not expected that this gets called on FederationSender. + raise NotImplementedError() + @staticmethod async def get_replication_rows( instance_name: str, from_token: int, to_token: int, target_row_count: int @@ -607,7 +714,7 @@ class FederationSender: # to a worker. return [], 0, False - async def _wake_destinations_needing_catchup(self): + async def _wake_destinations_needing_catchup(self) -> None: """ Wakes up destinations that need catch-up and are not currently being backed off from. @@ -627,16 +734,18 @@ class FederationSender: self._catchup_after_startup_timer = None break + last_processed = destinations_to_wake[-1] + destinations_to_wake = [ d for d in destinations_to_wake if self._federation_shard_config.should_handle(self._instance_name, d) ] - for last_processed in destinations_to_wake: + for destination in destinations_to_wake: logger.info( "Destination %s has outstanding catch-up, waking up.", last_processed, ) - self.wake_destination(last_processed) + self.wake_destination(destination) await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index cc0d765e5f..e9c8a9f20a 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,7 +15,7 @@ # limitations under the License. import datetime import logging -from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple import attr from prometheus_client import Counter @@ -29,6 +29,7 @@ from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state +from synapse.logging.opentracing import SynapseTags, set_tag from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ReadReceipt @@ -77,6 +78,7 @@ class PerDestinationQueue: self._transaction_manager = transaction_manager self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config + self._state = hs.get_state_handler() self._should_send_on_this_instance = True if not self._federation_shard_config.should_handle( @@ -415,22 +417,97 @@ class PerDestinationQueue: "This should not happen." % event_ids ) - if logger.isEnabledFor(logging.INFO): - rooms = [p.room_id for p in catchup_pdus] - logger.info("Catching up rooms to %s: %r", self._destination, rooms) + # We send transactions with events from one room only, as its likely + # that the remote will have to do additional processing, which may + # take some time. It's better to give it small amounts of work + # rather than risk the request timing out and repeatedly being + # retried, and not making any progress. + # + # Note: `catchup_pdus` will have exactly one PDU per room. + for pdu in catchup_pdus: + # The PDU from the DB will be the last PDU in the room from + # *this server* that wasn't sent to the remote. However, other + # servers may have sent lots of events since then, and we want + # to try and tell the remote only about the *latest* events in + # the room. This is so that it doesn't get inundated by events + # from various parts of the DAG, which all need to be processed. + # + # Note: this does mean that in large rooms a server coming back + # online will get sent the same events from all the different + # servers, but the remote will correctly deduplicate them and + # handle it only once. + + # Step 1, fetch the current extremities + extrems = await self._store.get_prev_events_for_room(pdu.room_id) + + if pdu.event_id in extrems: + # If the event is in the extremities, then great! We can just + # use that without having to do further checks. + room_catchup_pdus = [pdu] + else: + # If not, fetch the extremities and figure out which we can + # send. + extrem_events = await self._store.get_events_as_list(extrems) + + new_pdus = [] + for p in extrem_events: + # We pulled this from the DB, so it'll be non-null + assert p.internal_metadata.stream_ordering + + # Filter out events that happened before the remote went + # offline + if ( + p.internal_metadata.stream_ordering + < self._last_successful_stream_ordering + ): + continue - await self._transaction_manager.send_new_transaction( - self._destination, catchup_pdus, [] - ) + # Filter out events where the server is not in the room, + # e.g. it may have left/been kicked. *Ideally* we'd pull + # out the kick and send that, but it's a rare edge case + # so we don't bother for now (the server that sent the + # kick should send it out if its online). + hosts = await self._state.get_hosts_in_room_at_events( + p.room_id, [p.event_id] + ) + if self._destination not in hosts: + continue - sent_transactions_counter.inc() - final_pdu = catchup_pdus[-1] - self._last_successful_stream_ordering = cast( - int, final_pdu.internal_metadata.stream_ordering - ) - await self._store.set_destination_last_successful_stream_ordering( - self._destination, self._last_successful_stream_ordering - ) + new_pdus.append(p) + + # If we've filtered out all the extremities, fall back to + # sending the original event. This should ensure that the + # server gets at least some of missed events (especially if + # the other sending servers are up). + if new_pdus: + room_catchup_pdus = new_pdus + else: + room_catchup_pdus = [pdu] + + logger.info( + "Catching up rooms to %s: %r", self._destination, pdu.room_id + ) + + await self._transaction_manager.send_new_transaction( + self._destination, room_catchup_pdus, [] + ) + + sent_transactions_counter.inc() + + # We pulled this from the DB, so it'll be non-null + assert pdu.internal_metadata.stream_ordering + + # Note that we mark the last successful stream ordering as that + # from the *original* PDU, rather than the PDU(s) we actually + # send. This is because we use it to mark our position in the + # queue of missed PDUs to process. + self._last_successful_stream_ordering = ( + pdu.internal_metadata.stream_ordering + ) + + await self._store.set_destination_last_successful_stream_ordering( + self._destination, self._last_successful_stream_ordering + ) def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: @@ -481,6 +558,13 @@ class PerDestinationQueue: contents, stream_id = await self._store.get_new_device_msgs_for_remote( self._destination, last_device_stream_id, to_device_stream_id, limit ) + for content in contents: + message_id = content.get("message_id") + if not message_id: + continue + + set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + edus = [ Edu( origin=self._server_name, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 10c4747f97..6aee47c431 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -16,7 +16,7 @@ import logging import urllib -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.api.constants import Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError @@ -26,6 +26,7 @@ from synapse.api.urls import ( FEDERATION_V2_PREFIX, ) from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -978,6 +979,38 @@ class TransportLayerClient: return self.client.get_json(destination=destination, path=path) + async def get_space_summary( + self, + destination: str, + room_id: str, + suggested_only: bool, + max_rooms_per_space: Optional[int], + exclude_rooms: List[str], + ) -> JsonDict: + """ + Args: + destination: The remote server + room_id: The room ID to ask about. + suggested_only: if True, only suggested rooms will be returned + max_rooms_per_space: an optional limit to the number of children to be + returned per space + exclude_rooms: a list of any rooms we can skip + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/spaces/%s", room_id + ) + + params = { + "suggested_only": suggested_only, + "exclude_rooms": exclude_rooms, + } + if max_rooms_per_space is not None: + params["max_rooms_per_space"] = max_rooms_per_space + + return await self.client.post_json( + destination=destination, path=path, data=params + ) + def _create_path(federation_prefix, path, *args): """ diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2cf935f38d..5ef0556ef7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -18,7 +18,7 @@ import functools import logging import re -from typing import Optional, Tuple, Type +from typing import Container, Mapping, Optional, Sequence, Tuple, Type import synapse from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH @@ -29,7 +29,7 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) -from synapse.http.server import JsonResource +from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( parse_boolean_from_args, parse_integer_from_args, @@ -44,7 +44,8 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.server import HomeServer -from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.types import JsonDict, ThirdPartyInstanceID, get_domain_from_id +from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.versionstring import get_version_string @@ -619,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id): - content = await self.handler.on_exchange_third_party_invite_request(content) - return 200, content + await self.handler.on_exchange_third_party_invite_request(content) + return 200, {} class FederationClientKeysQueryServlet(BaseFederationServlet): @@ -1376,6 +1377,40 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): return 200, new_content +class FederationSpaceSummaryServlet(BaseFederationServlet): + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" + PATH = "/spaces/(?P<room_id>[^/]*)" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + suggested_only = content.get("suggested_only", False) + if not isinstance(suggested_only, bool): + raise SynapseError( + 400, "'suggested_only' must be a boolean", Codes.BAD_JSON + ) + + exclude_rooms = content.get("exclude_rooms", []) + if not isinstance(exclude_rooms, list) or any( + not isinstance(x, str) for x in exclude_rooms + ): + raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) + + max_rooms_per_space = content.get("max_rooms_per_space") + if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON + ) + + return 200, await self.handler.federation_space_summary( + room_id, suggested_only, max_rooms_per_space, exclude_rooms + ) + + class RoomComplexityServlet(BaseFederationServlet): """ Indicates to other servers how complex (and therefore likely @@ -1474,18 +1509,24 @@ DEFAULT_SERVLET_GROUPS = ( ) -def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None): +def register_servlets( + hs: HomeServer, + resource: HttpServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + servlet_groups: Optional[Container[str]] = None, +): """Initialize and register servlet classes. Will by default register all servlets. For custom behaviour, pass in a list of servlet_groups to register. Args: - hs (synapse.server.HomeServer): homeserver - resource (JsonResource): resource class to register to - authenticator (Authenticator): authenticator to use - ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use - servlet_groups (list[str], optional): List of servlet groups to register. + hs: homeserver + resource: resource class to register to + authenticator: authenticator to use + ratelimiter: ratelimiter to use + servlet_groups: List of servlet groups to register. Defaults to ``DEFAULT_SERVLET_GROUPS``. """ if not servlet_groups: @@ -1500,6 +1541,14 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N server_name=hs.hostname, ).register(resource) + if hs.config.experimental.spaces_enabled: + FederationSpaceSummaryServlet( + handler=hs.get_space_summary_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + if "openid" in servlet_groups: for servletclass in OPENID_SERVLET_CLASSES: servletclass( diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index a3f8d92d08..368c44708d 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py
@@ -46,7 +46,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, get_domain_from_id if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index f9a0f40221..4b16a4ac29 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py
@@ -25,7 +25,7 @@ from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d29b066a56..fb899aa90d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py
@@ -24,7 +24,7 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ class BaseHandler: # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 ) self._rc_message = self.hs.config.rc_message @@ -57,6 +57,7 @@ class BaseHandler: # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: self.admin_redaction_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, @@ -91,11 +92,6 @@ class BaseHandler: if app_service is not None: return # do not ratelimit app service senders - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return - messages_per_second = self._rc_message.per_second burst_count = self._rc_message.burst_count @@ -113,11 +109,11 @@ class BaseHandler: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) else: # Override rate and burst count per-user - self.request_ratelimiter.ratelimit( - user_id, + await self.request_ratelimiter.ratelimit( + requester, rate_hz=messages_per_second, burst_count=burst_count, update=update, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index b1a5df9638..1ce6d697ed 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py
@@ -25,7 +25,7 @@ from synapse.replication.http.account_data import ( from synapse.types import JsonDict, UserID if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer class AccountDataHandler: diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 664d09da1c..bee1447c2e 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py
@@ -18,7 +18,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable @@ -27,7 +27,7 @@ from synapse.types import UserID from synapse.util import stringutils if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -241,7 +241,10 @@ class AccountValidityHandler: return True async def renew_account_for_user( - self, user_id: str, expiration_ts: int = None, email_sent: bool = False + self, + user_id: str, + expiration_ts: Optional[int] = None, + email_sent: bool = False, ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 132be238dd..2a25af6288 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py
@@ -24,7 +24,7 @@ from twisted.web.resource import Resource from synapse.app import check_bind_error if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index db68c94c50..c494de49a3 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py
@@ -25,7 +25,7 @@ from synapse.visibility import filter_events_for_client from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index deab8ff2d0..996f9e5deb 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -38,7 +38,7 @@ from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, User from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fb5f8118f0..08e413bc98 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -70,7 +70,7 @@ from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -238,6 +238,7 @@ class AuthHandler(BaseHandler): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -248,6 +249,7 @@ class AuthHandler(BaseHandler): # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -352,7 +354,7 @@ class AuthHandler(BaseHandler): requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) + await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -373,7 +375,9 @@ class AuthHandler(BaseHandler): ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) + await self._failed_uia_attempts_ratelimiter.can_do_action( + requester, + ) raise # find the completed login type @@ -886,6 +890,19 @@ class AuthHandler(BaseHandler): ) return result + def can_change_password(self) -> bool: + """Get whether users on this server are allowed to change or set a password. + + Both `config.password_enabled` and `config.password_localdb_enabled` must be true. + + Note that any account (even SSO accounts) are allowed to add passwords if the above + is true. + + Returns: + Whether users on this server are allowed to change or set a password + """ + return self._password_enabled and self._password_localdb_enabled + def get_supported_login_types(self) -> Iterable[str]: """Get a the login types supported for the /login API @@ -969,8 +986,8 @@ class AuthHandler(BaseHandler): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - (medium, address), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, (medium, address), update=False ) # Check for login providers that support 3pid login types @@ -1003,8 +1020,8 @@ class AuthHandler(BaseHandler): # this code path, which is fine as then the per-user ratelimit # will kick in below. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - (medium, address) + await self._failed_login_attempts_ratelimiter.can_do_action( + None, (medium, address) ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -1026,8 +1043,8 @@ class AuthHandler(BaseHandler): # Check if we've hit the failed ratelimit (but don't update it) if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, qualified_user_id.lower(), update=False ) try: @@ -1038,8 +1055,8 @@ class AuthHandler(BaseHandler): # exception and masking the LoginError. The actual ratelimiting # should have happened above. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - qualified_user_id.lower() + await self._failed_login_attempts_ratelimiter.can_do_action( + None, qualified_user_id.lower() ) raise diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index cb67589f7d..5060936f94 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py
@@ -27,7 +27,7 @@ from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 3886d3124d..2bcd8f5435 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -23,7 +23,7 @@ from synapse.types import Requester, UserID, create_requester from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index df3cdc8fba..7e76db3e2a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -45,7 +45,7 @@ from synapse.util.retryutils import NotRetryingDestination from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -166,7 +166,7 @@ class DeviceWorkerHandler(BaseHandler): # Fetch the current state at the time. try: - event_ids = await self.store.get_forward_extremeties_for_room( + event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering( room_id, stream_ordering=stream_ordering ) except errors.StoreError: @@ -631,7 +631,7 @@ class DeviceListUpdater: max_len=10000, expiry_ms=30 * 60 * 1000, iterable=True, - ) + ) # type: ExpiringCache[str, Set[str]] # Attempt to resync out of sync device lists every 30s. self._resync_retry_in_progress = False @@ -760,7 +760,7 @@ class DeviceListUpdater: """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ - seen_updates = self._seen_updates.get(user_id, set()) + seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str] extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) @@ -907,6 +907,7 @@ class DeviceListUpdater: master_key = result.get("master_key") self_signing_key = result.get("self_signing_key") + ignore_devices = False # If the remote server has more than ~1000 devices for this user # we assume that something is going horribly wrong (e.g. a bot # that logs in and creates a new device every time it tries to @@ -925,6 +926,12 @@ class DeviceListUpdater: len(devices), ) devices = [] + ignore_devices = True + else: + cached_devices = await self.store.get_cached_devices_for_user(user_id) + if cached_devices == {d["device_id"]: d for d in devices}: + devices = [] + ignore_devices = True for device in devices: logger.debug( @@ -934,7 +941,10 @@ class DeviceListUpdater: stream_id, ) - await self.store.update_remote_device_list_cache(user_id, devices, stream_id) + if not ignore_devices: + await self.store.update_remote_device_list_cache( + user_id, devices, stream_id + ) device_ids = [device["device_id"] for device in devices] # Handle cross-signing keys. @@ -945,7 +955,8 @@ class DeviceListUpdater: ) device_ids = device_ids + cross_signing_device_ids - await self.device_handler.notify_device_update(user_id, device_ids) + if device_ids: + await self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given # point. @@ -973,14 +984,17 @@ class DeviceListUpdater: """ device_ids = [] - if master_key: + current_keys_map = await self.store.get_e2e_cross_signing_keys_bulk([user_id]) + current_keys = current_keys_map.get(user_id) or {} + + if master_key and master_key != current_keys.get("master"): await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) _, verify_key = get_verify_key_from_cross_signing_key(master_key) # verify_key is a VerifyKey from signedjson, which uses # .version to denote the portion of the key ID after the # algorithm and colon, which is the device ID device_ids.append(verify_key.version) - if self_signing_key: + if self_signing_key and self_signing_key != current_keys.get("self_signing"): await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 7db4f48965..c971eeb4d2 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -21,10 +21,10 @@ from synapse.api.errors import SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( + SynapseTags, get_active_span_text_map, log_kv, set_tag, - start_active_span, ) from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import JsonDict, Requester, UserID, get_domain_from_id @@ -32,7 +32,7 @@ from synapse.util import json_encoder from synapse.util.stringutils import random_string if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -81,6 +81,7 @@ class DeviceMessageHandler: ) self._ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.rc_key_requests.per_second, burst_count=hs.config.rc_key_requests.burst_count, @@ -182,7 +183,10 @@ class DeviceMessageHandler: ) -> None: sender_user_id = requester.user.to_string() - set_tag("number_of_messages", len(messages)) + message_id = random_string(16) + set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + + log_kv({"number_of_to_device_messages": len(messages)}) set_tag("sender", sender_user_id) local_messages = {} remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] @@ -191,8 +195,8 @@ class DeviceMessageHandler: if ( message_type == EduTypes.RoomKeyRequest and user_id != sender_user_id - and self._ratelimiter.can_do_action( - (sender_user_id, requester.device_id) + and await self._ratelimiter.can_do_action( + requester, (sender_user_id, requester.device_id) ) ): continue @@ -204,32 +208,35 @@ class DeviceMessageHandler: "content": message_content, "type": message_type, "sender": sender_user_id, + "message_id": message_id, } for device_id, message_content in by_device.items() } if messages_by_device: local_messages[user_id] = messages_by_device + log_kv( + { + "user_id": user_id, + "device_id": list(messages_by_device), + } + ) else: destination = get_domain_from_id(user_id) remote_messages.setdefault(destination, {})[user_id] = by_device - message_id = random_string(16) - context = get_active_span_text_map() remote_edu_contents = {} for destination, messages in remote_messages.items(): - with start_active_span("to_device_for_user"): - set_tag("destination", destination) - remote_edu_contents[destination] = { - "messages": messages, - "sender": sender_user_id, - "type": message_type, - "message_id": message_id, - "org.matrix.opentracing_context": json_encoder.encode(context), - } + log_kv({"destination": destination}) + remote_edu_contents[destination] = { + "messages": messages, + "sender": sender_user_id, + "type": message_type, + "message_id": message_id, + "org.matrix.opentracing_context": json_encoder.encode(context), + } - log_kv({"local_messages": local_messages}) stream_id = await self.store.add_messages_to_device_inbox( local_messages, remote_edu_contents ) @@ -238,7 +245,6 @@ class DeviceMessageHandler: "to_device_key", stream_id, users=local_messages.keys() ) - log_kv({"remote_messages": remote_messages}) if self.federation_sender: for destination in remote_messages.keys(): # Enqueue a new federation transaction to send the new diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9a946a3cfe..92b18378fc 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -38,11 +38,10 @@ from synapse.types import ( ) from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer -from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -1008,7 +1007,7 @@ class E2eKeysHandler: return signature_list, failures async def _get_e2e_cross_signing_verify_key( - self, user_id: str, key_type: str, from_user_id: str = None + self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. @@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater: # user_id -> list of updates waiting to be handled. self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] - # Recently seen stream ids. We don't bother keeping these in the DB, - # but they're useful to have them about to reduce the number of spurious - # resyncs. - self._seen_updates = ExpiringCache( - cache_name="signing_key_update_edu", - clock=self.clock, - max_len=10000, - expiry_ms=30 * 60 * 1000, - iterable=True, - ) - async def incoming_signing_key_update( self, origin: str, edu_content: JsonDict ) -> None: diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 622cae23be..a910d246d6 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py
@@ -29,7 +29,7 @@ from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 598a66f74c..5ea8a7b603 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -21,7 +21,17 @@ import itertools import logging from collections.abc import Container from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import attr from signedjson.key import decode_verify_key_bytes @@ -171,15 +181,17 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: + async def on_receive_pdu( + self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False + ) -> None: """Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events Args: - origin (str): server which initiated the /send/ transaction. Will + origin: server which initiated the /send/ transaction. Will be used to fetch missing events or state. - pdu (FrozenEvent): received PDU - sent_to_us_directly (bool): True if this event was pushed to us; False if + pdu: received PDU + sent_to_us_directly: True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. """ @@ -411,13 +423,15 @@ class FederationHandler(BaseHandler): await self._process_received_pdu(origin, pdu, state=state) - async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu( + self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int + ) -> None: """ Args: - origin (str): Origin of the pdu. Will be called to get the missing events + origin: Origin of the pdu. Will be called to get the missing events pdu: received pdu - prevs (set(str)): List of event ids which we are missing - min_depth (int): Minimum depth of events to return. + prevs: List of event ids which we are missing + min_depth: Minimum depth of events to return. """ room_id = pdu.room_id @@ -778,7 +792,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - ): + ) -> None: """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -887,7 +901,9 @@ class FederationHandler(BaseHandler): logger.exception("Failed to resync device for %s", sender) @log_function - async def backfill(self, dest, room_id, limit, extremities): + async def backfill( + self, dest: str, room_id: str, limit: int, extremities: List[str] + ) -> List[EventBase]: """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler): curr_state = await self.state_handler.get_current_state(room_id) - def get_domains_from_state(state): + def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: """Get joined domains from state Args: - state (dict[tuple, FrozenEvent]): State map from type/state - key to event. + state: State map from type/state key to event. Returns: - list[tuple[str, int]]: Returns a list of servers with the - lowest depth of their joins. Sorted by lowest depth first. + Returns a list of servers with the lowest depth of their joins. + Sorted by lowest depth first. """ joined_users = [ (state_key, int(event.depth)) @@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler): domain for domain, depth in curr_domains if domain != self.server_name ] - async def try_backfill(domains): + async def try_backfill(domains: List[str]) -> bool: # TODO: Should we try multiple of these at a time? for dom in domains: try: @@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler): } for e_id, _ in sorted_extremeties_tuple: - likely_domains = get_domains_from_state(states[e_id]) + likely_extremeties_domains = get_domains_from_state(states[e_id]) success = await try_backfill( - [dom for dom, _ in likely_domains if dom not in tried_domains] + [ + dom + for dom, _ in likely_extremeties_domains + if dom not in tried_domains + ] ) if success: return True - tried_domains.update(dom for dom, _ in likely_domains) + tried_domains.update(dom for dom, _ in likely_extremeties_domains) return False async def _get_events_and_persist( self, destination: str, room_id: str, events: Iterable[str] - ): + ) -> None: """Fetch the given events from a server, and persist them as outliers. This function *does not* recursively get missing auth events of the @@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler): event_infos, ) - def _sanity_check_event(self, ev): + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event @@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler): or cascade of event fetches. Args: - ev (synapse.events.EventBase): event to be checked - - Returns: None + ev: event to be checked Raises: SynapseError if the event does not pass muster @@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler): ) raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") - async def send_invite(self, target_host, event): + async def send_invite(self, target_host: str, event: EventBase) -> EventBase: """Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. @@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler): run_in_background(self._handle_queued_pdus, room_queue) - async def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus( + self, room_queue: List[Tuple[EventBase, str]] + ) -> None: """Process PDUs which got queued up while we were busy send_joining. Args: - room_queue (list[FrozenEvent, str]): list of PDUs to be processed - and the servers that sent them + room_queue: list of PDUs to be processed and the servers that sent them """ for p, origin in room_queue: try: @@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler): return event - async def on_send_join_request(self, origin, pdu): + async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict: """We have received a join event for a room. Fully process it and respond with the current state and auth chains. """ @@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler): async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion - ): + ) -> EventBase: """We've got an invite event. Process and persist it. Sign it. Respond with the now signed event. @@ -1711,7 +1729,7 @@ class FederationHandler(BaseHandler): member_handler = self.hs.get_room_member_handler() # We don't rate limit based on room ID, as that should be done by # sending server. - member_handler.ratelimit_invite(None, event.state_key) + await member_handler.ratelimit_invite(None, None, event.state_key) # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a @@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler): return event - async def on_send_leave_request(self, origin, pdu): + async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None: """ We have received a leave event for a room. Fully process it.""" event = pdu @@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler): else: return None - async def get_min_depth_for_context(self, context): + async def get_min_depth_for_context(self, context: str) -> int: return await self.store.get_min_depth(context) async def _handle_new_event( - self, origin, event, state=None, auth_events=None, backfilled=False - ): + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]] = None, + auth_events: Optional[MutableStateMap[EventBase]] = None, + backfilled: bool = False, + ) -> EventContext: context = await self._prep_event( origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) @@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler): logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True - async def on_query_auth( - self, origin, event_id, room_id, remote_auth_chain, rejects, missing - ): - in_room = await self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - - event = await self.store.get_event(event_id, check_room_id=room_id) - - # Just go through and process each event in `remote_auth_chain`. We - # don't want to fall into the trap of `missing` being wrong. - for e in remote_auth_chain: - try: - await self._handle_new_event(origin, e) - except AuthError: - pass - - # Now get the current auth_chain for the event. - local_auth_chain = await self.store.get_auth_chain( - room_id, list(event.auth_event_ids()), include_given=True - ) - - # TODO: Check if we would now reject event_id. If so we need to tell - # everyone. - - ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain) - - logger.debug("on_query_auth returning: %s", ret) - - return ret - async def on_get_missing_events( - self, origin, room_id, earliest_events, latest_events, limit - ): + self, + origin: str, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[EventBase]: in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler): assumes that we have already processed all events in remote_auth Params: - local_auth (list) - remote_auth (list) + local_auth + remote_auth Returns: dict @@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler): @log_function async def exchange_third_party_invite( - self, sender_user_id, target_user_id, room_id, signed - ): + self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict + ) -> None: third_party_invite = {"signed": signed} event_dict = { @@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler): await member_handler.send_membership_event(None, event, context) async def add_display_name_to_third_party_invite( - self, room_version, event_dict, event, context - ): + self, + room_version: str, + event_dict: JsonDict, + event: EventBase, + context: EventContext, + ) -> Tuple[EventBase, EventContext]: key = ( EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"], @@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler): EventValidator().validate_new(event, self.config) return (event, context) - async def _check_signature(self, event, context): + async def _check_signature(self, event: EventBase, context: EventContext) -> None: """ Checks that the signature in the event is consistent with its invite. Args: - event (Event): The m.room.member event to check - context (EventContext): + event: The m.room.member event to check + context: Raises: AuthError: if signature didn't match any keys, or key has been @@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler): raise last_exception - async def _check_key_revocation(self, public_key, url): + async def _check_key_revocation(self, public_key: str, url: str) -> None: """ Checks whether public_key has been revoked. Args: - public_key (str): base-64 encoded public key. - url (str): Key revocation URL. + public_key: base-64 encoded public key. + url: Key revocation URL. Raises: AuthError: if they key has been revoked. diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index bfb95e3eee..a41ca5df9c 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py
@@ -21,7 +21,7 @@ from synapse.api.errors import HttpResponseException, RequestSendFailed, Synapse from synapse.types import GroupID, JsonDict, get_domain_from_id if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 5f346f6d6d..d89fa5fb30 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler): # Ratelimiters for `/requestToken` endpoints. self._3pid_validation_ratelimiter_ip = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) self._3pid_validation_ratelimiter_address = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) - def ratelimit_request_token_requests( + async def ratelimit_request_token_requests( self, request: SynapseRequest, medium: str, @@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler): address: The actual threepid ID, e.g. the phone number or email address """ - self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) - self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + await self._3pid_validation_ratelimiter_ip.ratelimit( + None, (medium, request.getClientIP()) + ) + await self._3pid_validation_ratelimiter_address.ratelimit( + None, (medium, address) + ) async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 41ded62d21..577de9b550 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -385,7 +385,7 @@ class EventCreationHandler: self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() - self.room_invite_state_types = self.hs.config.room_invite_state_types + self.room_invite_state_types = self.hs.config.api.room_prejoin_state self.membership_types_to_include_profile_data_in = ( {Membership.JOIN, Membership.INVITE} diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 6d8551a6d6..6624212d6f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py
@@ -149,6 +149,9 @@ class OidcHandler: Args: request: the incoming request from the browser. """ + # This will always be set by the time Twisted calls us. + assert request.args is not None + # The provider might redirect with an error. # In that case, just display it as-is. if b"error" in request.args: @@ -280,6 +283,7 @@ class OidcProvider: self._config = provider self._callback_url = hs.config.oidc_callback_url # type: str + self._oidc_attribute_requirements = provider.attribute_requirements self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method @@ -859,6 +863,18 @@ class OidcProvider: ) # otherwise, it's a login + logger.debug("Userinfo for OIDC login: %s", userinfo) + + # Ensure that the attributes of the logged in user meet the required + # attributes by checking the userinfo against attribute_requirements + # In order to deal with the fact that OIDC userinfo can contain many + # types of data, we wrap non-list values in lists. + if not self._sso_handler.check_required_attributes( + request, + {k: v if isinstance(v, list) else [v] for k, v in userinfo.items()}, + self._oidc_attribute_requirements, + ): + return # Call the mapper to register/login the user try: diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index 6c635cc31b..92cefa11aa 100644 --- a/synapse/handlers/password_policy.py +++ b/synapse/handlers/password_policy.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING from synapse.api.errors import Codes, PasswordRefusedError if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 54631b4ee2..c817f2952d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -25,7 +25,17 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import ( + TYPE_CHECKING, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) from prometheus_client import Counter from typing_extensions import ContextManager @@ -34,6 +44,7 @@ import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState +from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge @@ -42,7 +53,7 @@ from synapse.state import StateHandler from synapse.storage.databases.main import DataStore from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -104,6 +115,8 @@ class BasePresenceHandler(abc.ABC): self.clock = hs.get_clock() self.store = hs.get_datastore() + self._busy_presence_enabled = hs.config.experimental.msc3026_enabled + active_presence = self.store.take_presence_startup_info() self.user_to_current_state = {state.user_id: state for state in active_presence} @@ -207,6 +220,7 @@ class PresenceHandler(BasePresenceHandler): self.notifier = hs.get_notifier() self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() + self.presence_router = hs.get_presence_router() self._presence_enabled = hs.config.use_presence federation_registry = hs.get_federation_registry() @@ -651,7 +665,7 @@ class PresenceHandler(BasePresenceHandler): """ stream_id, max_token = await self.store.update_presence(states) - parties = await get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -730,8 +744,12 @@ class PresenceHandler(BasePresenceHandler): PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, + PresenceState.BUSY, ) - if presence not in valid_presence: + + if presence not in valid_presence or ( + presence == PresenceState.BUSY and not self._busy_presence_enabled + ): raise SynapseError(400, "Invalid presence state") user_id = target_user.to_string() @@ -744,7 +762,9 @@ class PresenceHandler(BasePresenceHandler): msg = status_msg if presence != PresenceState.OFFLINE else None new_fields["status_msg"] = msg - if presence == PresenceState.ONLINE: + if presence == PresenceState.ONLINE or ( + presence == PresenceState.BUSY and self._busy_presence_enabled + ): new_fields["last_active_ts"] = self.clock.time_msec() await self._update_states([prev_state.copy_and_replace(**new_fields)]) @@ -1033,7 +1053,12 @@ class PresenceEventSource: # # Presence -> Notifier -> PresenceEventSource -> Presence # + # Same with get_module_api, get_presence_router + # + # AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler self.get_presence_handler = hs.get_presence_handler + self.get_module_api = hs.get_module_api + self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -1047,7 +1072,7 @@ class PresenceEventSource: include_offline=True, explicit_room_id=None, **kwargs - ): + ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1060,7 +1085,17 @@ class PresenceEventSource: # We don't try and limit the presence updates by the current token, as # sending down the rare duplicate is not a concern. + user_id = user.to_string() + stream_change_cache = self.store.presence_stream_cache + with Measure(self.clock, "presence.get_new_events"): + if user_id in self.get_module_api()._send_full_presence_to_local_users: + # This user has been specified by a module to receive all current, online + # user presence. Removing from_key and setting include_offline to false + # will do effectively this. + from_key = None + include_offline = False + if from_key is not None: from_key = int(from_key) @@ -1083,59 +1118,209 @@ class PresenceEventSource: # doesn't return. C.f. #5503. return [], max_token - presence = self.get_presence_handler() - stream_change_cache = self.store.presence_stream_cache - + # Figure out which other users this user should receive updates for users_interested_in = await self._get_interested_in(user, explicit_room_id) - user_ids_changed = set() # type: Collection[str] - changed = None - if from_key: - changed = stream_change_cache.get_all_entities_changed(from_key) + # We have a set of users that we're interested in the presence of. We want to + # cross-reference that with the users that have actually changed their presence. - if changed is not None and len(changed) < 500: - assert isinstance(user_ids_changed, set) + # Check whether this user should see all user updates - # For small deltas, its quicker to get all changes and then - # work out if we share a room or they're in our presence list - get_updates_counter.labels("stream").inc() - for other_user_id in changed: - if other_user_id in users_interested_in: - user_ids_changed.add(other_user_id) - else: - # Too many possible updates. Find all users we can see and check - # if any of them have changed. - get_updates_counter.labels("full").inc() + if users_interested_in == PresenceRouter.ALL_USERS: + # Provide presence state for all users + presence_updates = await self._filter_all_presence_updates_for_user( + user_id, include_offline, from_key + ) - if from_key: - user_ids_changed = stream_change_cache.get_entities_changed( - users_interested_in, from_key + # Remove the user from the list of users to receive all presence + if user_id in self.get_module_api()._send_full_presence_to_local_users: + self.get_module_api()._send_full_presence_to_local_users.remove( + user_id ) + + return presence_updates, max_token + + # Make mypy happy. users_interested_in should now be a set + assert not isinstance(users_interested_in, str) + + # The set of users that we're interested in and that have had a presence update. + # We'll actually pull the presence updates for these users at the end. + interested_and_updated_users = ( + set() + ) # type: Union[Set[str], FrozenSet[str]] + + if from_key: + # First get all users that have had a presence update + updated_users = stream_change_cache.get_all_entities_changed(from_key) + + # Cross-reference users we're interested in with those that have had updates. + # Use a slightly-optimised method for processing smaller sets of updates. + if updated_users is not None and len(updated_users) < 500: + # For small deltas, it's quicker to get all changes and then + # cross-reference with the users we're interested in + get_updates_counter.labels("stream").inc() + for other_user_id in updated_users: + if other_user_id in users_interested_in: + # mypy thinks this variable could be a FrozenSet as it's possibly set + # to one in the `get_entities_changed` call below, and `add()` is not + # method on a FrozenSet. That doesn't affect us here though, as + # `interested_and_updated_users` is clearly a set() above. + interested_and_updated_users.add(other_user_id) # type: ignore else: - user_ids_changed = users_interested_in + # Too many possible updates. Find all users we can see and check + # if any of them have changed. + get_updates_counter.labels("full").inc() + + interested_and_updated_users = ( + stream_change_cache.get_entities_changed( + users_interested_in, from_key + ) + ) + else: + # No from_key has been specified. Return the presence for all users + # this user is interested in + interested_and_updated_users = users_interested_in + + # Retrieve the current presence state for each user + users_to_state = await self.get_presence_handler().current_state_for_users( + interested_and_updated_users + ) + presence_updates = list(users_to_state.values()) - updates = await presence.current_state_for_users(user_ids_changed) + # Remove the user from the list of users to receive all presence + if user_id in self.get_module_api()._send_full_presence_to_local_users: + self.get_module_api()._send_full_presence_to_local_users.remove(user_id) - if include_offline: - return (list(updates.values()), max_token) + if not include_offline: + # Filter out offline presence states + presence_updates = self._filter_offline_presence_state(presence_updates) + + return presence_updates, max_token + + async def _filter_all_presence_updates_for_user( + self, + user_id: str, + include_offline: bool, + from_key: Optional[int] = None, + ) -> List[UserPresenceState]: + """ + Computes the presence updates a user should receive. + + First pulls presence updates from the database. Then consults PresenceRouter + for whether any updates should be excluded by user ID. + + Args: + user_id: The User ID of the user to compute presence updates for. + include_offline: Whether to include offline presence states from the results. + from_key: The minimum stream ID of updates to pull from the database + before filtering. + + Returns: + A list of presence states for the given user to receive. + """ + if from_key: + # Only return updates since the last sync + updated_users = self.store.presence_stream_cache.get_all_entities_changed( + from_key + ) + if not updated_users: + updated_users = [] + + # Get the actual presence update for each change + users_to_state = await self.get_presence_handler().current_state_for_users( + updated_users + ) + presence_updates = list(users_to_state.values()) + + if not include_offline: + # Filter out offline states + presence_updates = self._filter_offline_presence_state(presence_updates) else: - return ( - [s for s in updates.values() if s.state != PresenceState.OFFLINE], - max_token, + users_to_state = await self.store.get_presence_for_all_users( + include_offline=include_offline ) + presence_updates = list(users_to_state.values()) + + # TODO: This feels wildly inefficient, and it's unfortunate we need to ask the + # module for information on a number of users when we then only take the info + # for a single user + + # Filter through the presence router + users_to_state_set = await self.get_presence_router().get_users_for_states( + presence_updates + ) + + # We only want the mapping for the syncing user + presence_updates = list(users_to_state_set[user_id]) + + # Return presence information for all users + return presence_updates + + def _filter_offline_presence_state( + self, presence_updates: Iterable[UserPresenceState] + ) -> List[UserPresenceState]: + """Given an iterable containing user presence updates, return a list with any offline + presence states removed. + + Args: + presence_updates: Presence states to filter + + Returns: + A new list with any offline presence states removed. + """ + return [ + update + for update in presence_updates + if update.state != PresenceState.OFFLINE + ] + def get_current_key(self): return self.store.get_current_presence_token() @cached(num_args=2, cache_context=True) - async def _get_interested_in(self, user, explicit_room_id, cache_context): + async def _get_interested_in( + self, + user: UserID, + explicit_room_id: Optional[str] = None, + cache_context: Optional[_CacheContext] = None, + ) -> Union[Set[str], str]: """Returns the set of users that the given user should see presence - updates for + updates for. + + Args: + user: The user to retrieve presence updates for. + explicit_room_id: The users that are in the room will be returned. + + Returns: + A set of user IDs to return presence updates for, or "ALL" to return all + known updates. """ user_id = user.to_string() users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence + # cache_context isn't likely to ever be None due to the @cached decorator, + # but we can't have a non-optional argument after the optional argument + # explicit_room_id either. Assert cache_context is not None so we can use it + # without mypy complaining. + assert cache_context + + # Check with the presence router whether we should poll additional users for + # their presence information + additional_users = await self.get_presence_router().get_interested_users( + user.to_string() + ) + if additional_users == PresenceRouter.ALL_USERS: + # If the module requested that this user see the presence updates of *all* + # users, then simply return that instead of calculating what rooms this + # user shares + return PresenceRouter.ALL_USERS + + # Add the additional users from the router + users_interested_in.update(additional_users) + + # Find the users who share a room with this user users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) @@ -1306,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): async def get_interested_parties( - store: DataStore, states: List[UserPresenceState] + store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState] ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: """Given a list of states return which entities (rooms, users) are interested in the given states. Args: - store - states + store: The homeserver's data store. + presence_router: A module for augmenting the destinations for presence updates. + states: A list of incoming user presence updates. Returns: A 2-tuple of `(room_ids_to_states, users_to_states)`, @@ -1329,11 +1515,22 @@ async def get_interested_parties( # Always notify self users_to_states.setdefault(state.user_id, []).append(state) + # Ask a presence routing module for any additional parties if one + # is loaded. + router_users_to_states = await presence_router.get_users_for_states(states) + + # Update the dictionaries with additional destinations and state to send + for user_id, user_states in router_users_to_states.items(): + users_to_states.setdefault(user_id, []).extend(user_states) + return room_ids_to_states, users_to_states async def get_interested_remotes( - store: DataStore, states: List[UserPresenceState], state_handler: StateHandler + store: DataStore, + presence_router: PresenceRouter, + states: List[UserPresenceState], + state_handler: StateHandler, ) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1341,9 +1538,10 @@ async def get_interested_remotes( All the presence states should be for local users only. Args: - store - states - state_handler + store: The homeserver's data store. + presence_router: A module for augmenting the destinations for presence updates. + states: A list of incoming user presence updates. + state_handler: Returns: A list of 2-tuples of destinations and states, where for @@ -1355,7 +1553,9 @@ async def get_interested_remotes( # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. - room_ids_to_states, users_to_states = await get_interested_parties(store, states) + room_ids_to_states, users_to_states = await get_interested_parties( + store, presence_router, states + ) for room_id, states in room_ids_to_states.items(): hosts = await state_handler.get_current_hosts_in_room(room_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dd59392bda..a755363c3f 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -36,7 +36,7 @@ from synapse.types import ( from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 6bb2fd936b..a54fe1968e 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py
@@ -21,7 +21,7 @@ from synapse.util.async_helpers import Linearizer from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6a6c528849..dbfe9bfaca 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py
@@ -20,7 +20,7 @@ from synapse.handlers._base import BaseHandler from synapse.types import JsonDict, ReadReceipt, get_domain_from_id if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 1abc8875cb..9701b76d0f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -38,7 +38,7 @@ from synapse.types import RoomAlias, UserID, create_requester from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler): Raises: SynapseError if there was a problem registering. """ - self.check_registration_ratelimit(address) + await self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( threepid, @@ -437,10 +437,10 @@ class RegistrationHandler(BaseHandler): if RoomAlias.is_valid(r): ( - room_id, + room, remote_room_hosts, ) = await room_member_handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() + room_id = room.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (r,) @@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address: Optional[str]) -> None: + async def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler): if not address: return - self.ratelimiter.ratelimit(address) + await self.ratelimiter.ratelimit(None, address) async def register_with_store( self, diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 8c5b60e6be..ef82c3baa9 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py
@@ -29,7 +29,7 @@ from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index c6a33251f2..03699b4019 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -76,22 +76,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.allow_per_room_profiles = self.config.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) self._join_rate_limiter_remote = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) self._invites_per_room_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, ) self._invites_per_user_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, @@ -156,15 +160,24 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() - def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): + @abc.abstractmethod + async def forget(self, user: UserID, room_id: str) -> None: + raise NotImplementedError() + + async def ratelimit_invite( + self, + requester: Optional[Requester], + room_id: Optional[str], + invitee_user_id: str, + ): """Ratelimit invites by room and by target user. If room ID is missing then we just rate limit by target user. """ if room_id: - self._invites_per_room_limiter.ratelimit(room_id) + await self._invites_per_room_limiter.ratelimit(requester, room_id) - self._invites_per_user_limiter.ratelimit(invitee_user_id) + await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) async def _local_membership_update( self, @@ -234,7 +247,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ( allowed, time_allowed, - ) = self._join_rate_limiter_local.can_requester_do_action(requester) + ) = await self._join_rate_limiter_local.can_do_action(requester) if not allowed: raise LimitExceededError( @@ -437,9 +450,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if effective_membership_state == Membership.INVITE: target_id = target.to_string() if ratelimit: - # Don't ratelimit application services. - if not requester.app_service or requester.app_service.is_rate_limited(): - self.ratelimit_invite(room_id, target_id) + await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: @@ -550,7 +561,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ( allowed, time_allowed, - ) = self._join_rate_limiter_remote.can_requester_do_action( + ) = await self._join_rate_limiter_remote.can_do_action( requester, ) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 108730a7a1..3a90fc0c16 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -25,11 +25,14 @@ from synapse.replication.http.membership import ( ) from synapse.types import Requester, UserID +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class RoomMemberWorkerHandler(RoomMemberHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs) @@ -83,3 +86,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler): await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) + + async def forget(self, target: UserID, room_id: str) -> None: + raise RuntimeError("Cannot forget rooms on workers.") diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 94062e79cb..d742dfbd53 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py
@@ -30,7 +30,7 @@ from synapse.visibility import filter_events_for_client from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 84af2dde7e..f98a338ec5 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py
@@ -21,7 +21,7 @@ from synapse.types import Requester from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler): logout_devices: bool, requester: Optional[Requester] = None, ) -> None: - if not self.hs.config.password_localdb_enabled: + if not self._auth_handler.can_change_password(): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) try: diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py new file mode 100644
index 0000000000..5d9418969d --- /dev/null +++ b/synapse/handlers/space_summary.py
@@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +from collections import deque +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple, cast + +import attr + +from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility +from synapse.api.errors import AuthError +from synapse.events import EventBase +from synapse.events.utils import format_event_for_client_v2 +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# number of rooms to return. We'll stop once we hit this limit. +# TODO: allow clients to reduce this with a request param. +MAX_ROOMS = 50 + +# max number of events to return per room. +MAX_ROOMS_PER_SPACE = 50 + +# max number of federation servers to hit per room +MAX_SERVERS_PER_SPACE = 3 + + +class SpaceSummaryHandler: + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._auth = hs.get_auth() + self._room_list_handler = hs.get_room_list_handler() + self._state_handler = hs.get_state_handler() + self._store = hs.get_datastore() + self._event_serializer = hs.get_event_client_serializer() + self._server_name = hs.hostname + self._federation_client = hs.get_federation_client() + + async def get_space_summary( + self, + requester: str, + room_id: str, + suggested_only: bool = False, + max_rooms_per_space: Optional[int] = None, + ) -> JsonDict: + """ + Implementation of the space summary C-S API + + Args: + requester: user id of the user making this request + + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. This does not apply to the root room (ie, room_id), and + is overridden by MAX_ROOMS_PER_SPACE. + + Returns: + summary dict to return + """ + # first of all, check that the user is in the room in question (or it's + # world-readable) + await self._auth.check_user_in_room_or_world_readable(room_id, requester) + + # the queue of rooms to process + room_queue = deque((_RoomQueueEntry(room_id, ()),)) + + # rooms we have already processed + processed_rooms = set() # type: Set[str] + + # events we have already processed. We don't necessarily have their event ids, + # so instead we key on (room id, state key) + processed_events = set() # type: Set[Tuple[str, str]] + + rooms_result = [] # type: List[JsonDict] + events_result = [] # type: List[JsonDict] + + while room_queue and len(rooms_result) < MAX_ROOMS: + queue_entry = room_queue.popleft() + room_id = queue_entry.room_id + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + + # The client-specified max_rooms_per_space limit doesn't apply to the + # room_id specified in the request, so we ignore it if this is the + # first room we are processing. + max_children = max_rooms_per_space if processed_rooms else None + + if is_in_room: + rooms, events = await self._summarize_local_room( + requester, room_id, suggested_only, max_children + ) + else: + rooms, events = await self._summarize_remote_room( + queue_entry, + suggested_only, + max_children, + exclude_rooms=processed_rooms, + ) + + logger.debug( + "Query of %s returned rooms %s, events %s", + queue_entry.room_id, + [room.get("room_id") for room in rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], + ) + + rooms_result.extend(rooms) + + # any rooms returned don't need visiting again + processed_rooms.update(cast(str, room.get("room_id")) for room in rooms) + + # the room we queried may or may not have been returned, but don't process + # it again, anyway. + processed_rooms.add(room_id) + + # XXX: is it ok that we blindly iterate through any events returned by + # a remote server, whether or not they actually link to any rooms in our + # tree? + for ev in events: + # remote servers might return events we have already processed + # (eg, Dendrite returns inward pointers as well as outward ones), so + # we need to filter them out, to avoid returning duplicate links to the + # client. + ev_key = (ev["room_id"], ev["state_key"]) + if ev_key in processed_events: + continue + events_result.append(ev) + + # add the child to the queue. we have already validated + # that the vias are a list of server names. + room_queue.append( + _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) + ) + processed_events.add(ev_key) + + return {"rooms": rooms_result, "events": events_result} + + async def federation_space_summary( + self, + room_id: str, + suggested_only: bool, + max_rooms_per_space: Optional[int], + exclude_rooms: Iterable[str], + ) -> JsonDict: + """ + Implementation of the space summary Federation API + + Args: + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. Unlike the C-S API, this applies to the root room (room_id). + It is clipped to MAX_ROOMS_PER_SPACE. + + exclude_rooms: a list of rooms to skip over (presumably because the + calling server has already seen them). + + Returns: + summary dict to return + """ + # the queue of rooms to process + room_queue = deque((room_id,)) + + # the set of rooms that we should not walk further. Initialise it with the + # excluded-rooms list; we will add other rooms as we process them so that + # we do not loop. + processed_rooms = set(exclude_rooms) # type: Set[str] + + rooms_result = [] # type: List[JsonDict] + events_result = [] # type: List[JsonDict] + + while room_queue and len(rooms_result) < MAX_ROOMS: + room_id = room_queue.popleft() + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + rooms, events = await self._summarize_local_room( + None, room_id, suggested_only, max_rooms_per_space + ) + + processed_rooms.add(room_id) + + rooms_result.extend(rooms) + events_result.extend(events) + + # add any children to the queue + room_queue.extend(edge_event["state_key"] for edge_event in events) + + return {"rooms": rooms_result, "events": events_result} + + async def _summarize_local_room( + self, + requester: Optional[str], + room_id: str, + suggested_only: bool, + max_children: Optional[int], + ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + if not await self._is_room_accessible(room_id, requester): + return (), () + + room_entry = await self._build_room_entry(room_id) + + # look for child rooms/spaces. + child_events = await self._get_child_events(room_id) + + if suggested_only: + # we only care about suggested children + child_events = filter(_is_suggested_child_event, child_events) + + if max_children is None or max_children > MAX_ROOMS_PER_SPACE: + max_children = MAX_ROOMS_PER_SPACE + + now = self._clock.time_msec() + events_result = [] # type: List[JsonDict] + for edge_event in itertools.islice(child_events, max_children): + events_result.append( + await self._event_serializer.serialize_event( + edge_event, + time_now=now, + event_format=format_event_for_client_v2, + ) + ) + return (room_entry,), events_result + + async def _summarize_remote_room( + self, + room: "_RoomQueueEntry", + suggested_only: bool, + max_children: Optional[int], + exclude_rooms: Iterable[str], + ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + room_id = room.room_id + logger.info("Requesting summary for %s via %s", room_id, room.via) + + # we need to make the exclusion list json-serialisable + exclude_rooms = list(exclude_rooms) + + via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) + try: + res = await self._federation_client.get_space_summary( + via, + room_id, + suggested_only=suggested_only, + max_rooms_per_space=max_children, + exclude_rooms=exclude_rooms, + ) + except Exception as e: + logger.warning( + "Unable to get summary of %s via federation: %s", + room_id, + e, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return (), () + + return res.rooms, tuple( + ev.data + for ev in res.events + if ev.event_type == EventTypes.MSC1772_SPACE_CHILD + ) + + async def _is_room_accessible(self, room_id: str, requester: Optional[str]) -> bool: + # if we have an authenticated requesting user, first check if they are in the + # room + if requester: + try: + await self._auth.check_user_in_room(room_id, requester) + return True + except AuthError: + pass + + # otherwise, check if the room is peekable + hist_vis_ev = await self._state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if hist_vis_ev: + hist_vis = hist_vis_ev.content.get("history_visibility") + if hist_vis == HistoryVisibility.WORLD_READABLE: + return True + + logger.info( + "room %s is unpeekable and user %s is not a member, omitting from summary", + room_id, + requester, + ) + return False + + async def _build_room_entry(self, room_id: str) -> JsonDict: + """Generate en entry suitable for the 'rooms' list in the summary response""" + stats = await self._store.get_room_with_stats(room_id) + + # currently this should be impossible because we call + # check_user_in_room_or_world_readable on the room before we get here, so + # there should always be an entry + assert stats is not None, "unable to retrieve stats for %s" % (room_id,) + + current_state_ids = await self._store.get_current_state_ids(room_id) + create_event = await self._store.get_event( + current_state_ids[(EventTypes.Create, "")] + ) + + # TODO: update once MSC1772 lands + room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE) + + entry = { + "room_id": stats["room_id"], + "name": stats["name"], + "topic": stats["topic"], + "canonical_alias": stats["canonical_alias"], + "num_joined_members": stats["joined_members"], + "avatar_url": stats["avatar"], + "world_readable": ( + stats["history_visibility"] == HistoryVisibility.WORLD_READABLE + ), + "guest_can_join": stats["guest_access"] == "can_join", + "room_type": room_type, + } + + # Filter out Nones – rather omit the field altogether + room_entry = {k: v for k, v in entry.items() if v is not None} + + return room_entry + + async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: + # look for child rooms/spaces. + current_state_ids = await self._store.get_current_state_ids(room_id) + + events = await self._store.get_events_as_list( + [ + event_id + for key, event_id in current_state_ids.items() + # TODO: update once MSC1772 lands + if key[0] == EventTypes.MSC1772_SPACE_CHILD + ] + ) + + # filter out any events without a "via" (which implies it has been redacted) + return (e for e in events if _has_valid_via(e)) + + +@attr.s(frozen=True, slots=True) +class _RoomQueueEntry: + room_id = attr.ib(type=str) + via = attr.ib(type=Sequence[str]) + + +def _has_valid_via(e: EventBase) -> bool: + via = e.content.get("via") + if not via or not isinstance(via, Sequence): + return False + for v in via: + if not isinstance(v, str): + logger.debug("Ignoring edge event %s with invalid via entry", e.event_id) + return False + return True + + +def _is_suggested_child_event(edge_event: EventBase) -> bool: + suggested = edge_event.content.get("suggested") + if isinstance(suggested, bool) and suggested: + return True + logger.debug("Ignorning not-suggested child %s", edge_event.state_key) + return False diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index b3f9875358..ee8f87e59a 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py
@@ -17,7 +17,7 @@ import logging from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 924281144c..8730f99d03 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py
@@ -24,7 +24,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a65299bd22..ceed40f90d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -24,6 +24,7 @@ from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.events import EventBase from synapse.logging.context import current_context +from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter @@ -81,7 +82,7 @@ class SyncConfig: filter_collection = attr.ib(type=FilterCollection) is_guest = attr.ib(type=bool) request_key = attr.ib(type=Tuple[Any, ...]) - device_id = attr.ib(type=str) + device_id = attr.ib(type=Optional[str]) @attr.s(slots=True, frozen=True) @@ -252,13 +253,13 @@ class SyncHandler: self.storage = hs.get_storage() self.state_store = self.storage.state - # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) + # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) self.lazy_loaded_members_cache = ExpiringCache( "lazy_loaded_members_cache", self.clock, max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, - ) + ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]] async def wait_for_sync_for_user( self, @@ -341,7 +342,14 @@ class SyncHandler: full_state: bool = False, ) -> SyncResult: """Get the sync for client needed to match what the server has now.""" - return await self.generate_sync_result(sync_config, since_token, full_state) + with start_active_span("current_sync_for_user"): + log_kv({"since_token": since_token}) + sync_result = await self.generate_sync_result( + sync_config, since_token, full_state + ) + + set_tag(SynapseTags.SYNC_RESULT, bool(sync_result)) + return sync_result async def push_rules_for_user(self, user: UserID) -> JsonDict: user_id = user.to_string() @@ -724,8 +732,12 @@ class SyncHandler: return summary - def get_lazy_loaded_members_cache(self, cache_key: Tuple[str, str]) -> LruCache: - cache = self.lazy_loaded_members_cache.get(cache_key) + def get_lazy_loaded_members_cache( + self, cache_key: Tuple[str, Optional[str]] + ) -> LruCache[str, str]: + cache = self.lazy_loaded_members_cache.get( + cache_key + ) # type: Optional[LruCache[str, str]] if cache is None: logger.debug("creating LruCache for %r", cache_key) cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) @@ -963,6 +975,7 @@ class SyncHandler: # to query up to a given point. # Always use the `now_token` in `SyncResultBuilder` now_token = self.event_sources.get_current_token() + log_kv({"now_token": now_token}) logger.debug( "Calculating sync response for %r between %s and %s", @@ -1224,6 +1237,13 @@ class SyncHandler: user_id, device_id, since_stream_id, now_token.to_device_key ) + for message in messages: + # We pop here as we shouldn't be sending the message ID down + # `/sync` + message_id = message.pop("message_id", None) + if message_id: + set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", len(messages), @@ -1980,8 +2000,10 @@ class SyncHandler: logger.info("User joined room after current token: %s", room_id) - extrems = await self.store.get_forward_extremeties_for_room( - room_id, event_pos.stream + extrems = ( + await self.store.get_forward_extremities_for_room_at_stream_ordering( + room_id, event_pos.stream + ) ) users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 096d199f4c..bb35af099d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -19,7 +19,10 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.replication.tcp.streams import TypingStream from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -86,6 +89,7 @@ class FollowerTypingHandler: self._member_last_federation_poke = {} self.wheel_timer = WheelTimer(bucket_size=5000) + @wrap_as_background_process("typing._handle_timeouts") def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 1a8340000a..b121286d95 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -25,7 +25,7 @@ from synapse.types import JsonDict from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/http/client.py b/synapse/http/client.py
index 1e01e0a9f2..e691ba6d88 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -77,7 +77,7 @@ from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -590,7 +590,7 @@ class SimpleHttpClient: uri: str, json_body: Any, args: Optional[QueryParams] = None, - headers: RawHeaders = None, + headers: Optional[RawHeaders] = None, ) -> Any: """Puts some json to the given URI. diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 856e28454f..b797e3ce80 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py
@@ -19,9 +19,10 @@ from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import ConnectError -from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.internet.protocol import connectionDone +from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint +from twisted.internet.protocol import ClientFactory, Protocol, connectionDone from twisted.web import http +from twisted.web.http_headers import Headers logger = logging.getLogger(__name__) @@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint: Args: reactor: the Twisted reactor to use for the connection - proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the - proxy - host (bytes): hostname that we want to CONNECT to - port (int): port that we want to connect to + proxy_endpoint: the endpoint to use to connect to the proxy + host: hostname that we want to CONNECT to + port: port that we want to connect to + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, reactor, proxy_endpoint, host, port): + def __init__( + self, + reactor: IReactorCore, + proxy_endpoint: IStreamClientEndpoint, + host: bytes, + port: int, + headers: Headers, + ): self._reactor = reactor self._proxy_endpoint = proxy_endpoint self._host = host self._port = port + self._headers = headers def __repr__(self): return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) - def connect(self, protocolFactory): - f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + def connect(self, protocolFactory: ClientFactory): + f = HTTPProxiedClientFactory( + self._host, self._port, protocolFactory, self._headers + ) d = self._proxy_endpoint.connect(f) # once the tcp socket connects successfully, we need to wait for the # CONNECT to complete. @@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): HTTP Protocol object and run the rest of the connection. Args: - dst_host (bytes): hostname that we want to CONNECT to - dst_port (int): port that we want to connect to - wrapped_factory (protocol.ClientFactory): The original Factory + dst_host: hostname that we want to CONNECT to + dst_port: port that we want to connect to + wrapped_factory: The original Factory + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, dst_host, dst_port, wrapped_factory): + def __init__( + self, + dst_host: bytes, + dst_port: int, + wrapped_factory: ClientFactory, + headers: Headers, + ): self.dst_host = dst_host self.dst_port = dst_port self.wrapped_factory = wrapped_factory + self.headers = headers self.on_connection = defer.Deferred() def startedConnecting(self, connector): @@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): wrapped_protocol = self.wrapped_factory.buildProtocol(addr) return HTTPConnectProtocol( - self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + self.dst_host, + self.dst_port, + wrapped_protocol, + self.on_connection, + self.headers, ) def clientConnectionFailed(self, connector, reason): @@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol): """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect Args: - host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + host: The original HTTP(s) hostname or IPv4 or IPv6 address literal to put in the CONNECT request - port (int): The original HTTP(s) port to put in the CONNECT request + port: The original HTTP(s) port to put in the CONNECT request - wrapped_protocol (interfaces.IProtocol): the original protocol (probably - HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + wrapped_protocol: the original protocol (probably HTTPChannel or + TLSMemoryBIOProtocol, but could be anything really) - connected_deferred (Deferred): a Deferred which will be callbacked with + connected_deferred: a Deferred which will be callbacked with wrapped_protocol when the CONNECT completes + + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, host, port, wrapped_protocol, connected_deferred): + def __init__( + self, + host: bytes, + port: int, + wrapped_protocol: Protocol, + connected_deferred: defer.Deferred, + headers: Headers, + ): self.host = host self.port = port self.wrapped_protocol = wrapped_protocol self.connected_deferred = connected_deferred - self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.headers = headers + + self.http_setup_client = HTTPConnectSetupClient( + self.host, self.port, self.headers + ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) def connectionMade(self): @@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol): if buf: self.wrapped_protocol.dataReceived(buf) - def dataReceived(self, data): + def dataReceived(self, data: bytes): # if we've set up the HTTP protocol, we can send the data there if self.wrapped_protocol.connected: return self.wrapped_protocol.dataReceived(data) @@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient): """HTTPClient protocol to send a CONNECT message for proxies and read the response. Args: - host (bytes): The hostname to send in the CONNECT message - port (int): The port to send in the CONNECT message + host: The hostname to send in the CONNECT message + port: The port to send in the CONNECT message + headers: Extra headers to send with the CONNECT message """ - def __init__(self, host, port): + def __init__(self, host: bytes, port: int, headers: Headers): self.host = host self.port = port + self.headers = headers self.on_connected = defer.Deferred() def connectionMade(self): logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + + # Send any additional specified headers + for name, values in self.headers.getAllRawHeaders(): + for value in values: + self.sendHeader(name, value) + self.endHeaders() - def handleStatus(self, version, status, message): + def handleStatus(self, version: bytes, status: bytes, message: bytes): logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index ecd63e6596..ce4079f15c 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py
@@ -71,8 +71,10 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3 logger = logging.getLogger(__name__) -_well_known_cache = TTLCache("well-known") -_had_valid_well_known_cache = TTLCache("had-valid-well-known") +_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]] +_had_valid_well_known_cache = TTLCache( + "had-valid-well-known" +) # type: TTLCache[bytes, bool] @attr.s(slots=True, frozen=True) @@ -88,8 +90,8 @@ class WellKnownResolver: reactor: IReactorTime, agent: IAgent, user_agent: bytes, - well_known_cache: Optional[TTLCache] = None, - had_well_known_cache: Optional[TTLCache] = None, + well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None, + had_well_known_cache: Optional[TTLCache[bytes, bool]] = None, ): self._reactor = reactor self._clock = Clock(reactor) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 3d553ae236..16ec850064 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py
@@ -12,10 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging import re +from typing import Optional, Tuple from urllib.request import getproxies_environment, proxy_bypass_environment +import attr from zope.interface import implementer from twisted.internet import defer @@ -23,6 +26,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.python.failure import Failure from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase from twisted.web.error import SchemeNotSupported +from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint @@ -32,6 +36,22 @@ logger = logging.getLogger(__name__) _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") +@attr.s +class ProxyCredentials: + username_password = attr.ib(type=bytes) + + def as_proxy_authorization_value(self) -> bytes: + """ + Return the value for a Proxy-Authorization header (i.e. 'Basic abdef=='). + + Returns: + A transformation of the authentication string the encoded value for + a Proxy-Authorization header. + """ + # Encode as base64 and prepend the authorization type + return b"Basic " + base64.encodebytes(self.username_password) + + @implementer(IAgent) class ProxyAgent(_AgentBase): """An Agent implementation which will use an HTTP proxy if one was requested @@ -96,6 +116,9 @@ class ProxyAgent(_AgentBase): https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None + # Parse credentials from https proxy connection string if present + self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) + self.http_proxy_endpoint = _http_proxy_endpoint( http_proxy, self.proxy_reactor, **self._endpoint_kwargs ) @@ -175,11 +198,22 @@ class ProxyAgent(_AgentBase): and self.https_proxy_endpoint and not should_skip_proxy ): + connect_headers = Headers() + + # Determine whether we need to set Proxy-Authorization headers + if self.https_proxy_creds: + # Set a Proxy-Authorization header + connect_headers.addRawHeader( + b"Proxy-Authorization", + self.https_proxy_creds.as_proxy_authorization_value(), + ) + endpoint = HTTPConnectProxyEndpoint( self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, + headers=connect_headers, ) else: # not using a proxy @@ -208,12 +242,16 @@ class ProxyAgent(_AgentBase): ) -def _http_proxy_endpoint(proxy, reactor, **kwargs): +def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs): """Parses an http proxy setting and returns an endpoint for the proxy Args: - proxy (bytes|None): the proxy setting + proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>] + Note that compared to other apps, this function currently lacks support + for specifying a protocol schema (i.e. protocol://...). + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint Returns: @@ -223,16 +261,43 @@ def _http_proxy_endpoint(proxy, reactor, **kwargs): if proxy is None: return None - # currently we only support hostname:port. Some apps also support - # protocol://<host>[:port], which allows a way of requiring a TLS connection to the - # proxy. - + # Parse the connection string host, port = parse_host_port(proxy, default_port=1080) return HostnameEndpoint(reactor, host, port, **kwargs) -def parse_host_port(hostport, default_port=None): - # could have sworn we had one of these somewhere else... +def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]: + """ + Parses the username and password from a proxy declaration e.g + username:password@hostname:port. + + Args: + proxy: The proxy connection string. + + Returns + An instance of ProxyCredentials and the proxy connection string with any credentials + stripped, i.e u:p@host:port -> host:port. If no credentials were found, the + ProxyCredentials instance is replaced with None. + """ + if proxy and b"@" in proxy: + # We use rsplit here as the password could contain an @ character + credentials, proxy_without_credentials = proxy.rsplit(b"@", 1) + return ProxyCredentials(credentials), proxy_without_credentials + + return None, proxy + + +def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]: + """ + Parse the hostname and port from a proxy connection byte string. + + Args: + hostport: The proxy connection string. Must be in the form 'host[:port]'. + default_port: The default port to return if one is not found in `hostport`. + + Returns: + A tuple containing the hostname and port. Uses `default_port` if one was not found. + """ if b":" in hostport: host, port = hostport.rsplit(b":", 1) try: diff --git a/synapse/http/site.py b/synapse/http/site.py
index 47754aff43..c0c873ce32 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Optional, Type, Union +from typing import Optional, Tuple, Type, Union import attr from zope.interface import implementer @@ -26,7 +26,11 @@ from twisted.web.server import Request, Site from synapse.config.server import ListenerConfig from synapse.http import get_request_user_agent, redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + ContextRequest, + LoggingContext, + PreserveLoggingContext, +) from synapse.types import Requester logger = logging.getLogger(__name__) @@ -63,7 +67,7 @@ class SynapseRequest(Request): # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. - self.requester = None # type: Optional[Union[Requester, str]] + self._requester = None # type: Optional[Union[Requester, str]] # we can't yet create the logcontext, as we don't know the method. self.logcontext = None # type: Optional[LoggingContext] @@ -93,6 +97,31 @@ class SynapseRequest(Request): self.site.site_tag, ) + @property + def requester(self) -> Optional[Union[Requester, str]]: + return self._requester + + @requester.setter + def requester(self, value: Union[Requester, str]) -> None: + # Store the requester, and update some properties based on it. + + # This should only be called once. + assert self._requester is None + + self._requester = value + + # A logging context should exist by now (and have a ContextRequest). + assert self.logcontext is not None + assert self.logcontext.request is not None + + ( + requester, + authenticated_entity, + ) = self.get_authenticated_entity() + self.logcontext.request.requester = requester + # If there's no authenticated entity, it was the requester. + self.logcontext.request.authenticated_entity = authenticated_entity or requester + def get_request_id(self): return "%s-%i" % (self.get_method(), self.request_seq) @@ -126,13 +155,60 @@ class SynapseRequest(Request): return self.method.decode("ascii") return method + def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]: + """ + Get the "authenticated" entity of the request, which might be the user + performing the action, or a user being puppeted by a server admin. + + Returns: + A tuple: + The first item is a string representing the user making the request. + + The second item is a string or None representing the user who + authenticated when making this request. See + Requester.authenticated_entity. + """ + # Convert the requester into a string that we can log + if isinstance(self._requester, str): + return self._requester, None + elif isinstance(self._requester, Requester): + requester = self._requester.user.to_string() + authenticated_entity = self._requester.authenticated_entity + + # If this is a request where the target user doesn't match the user who + # authenticated (e.g. and admin is puppetting a user) then we return both. + if self._requester.user.to_string() != authenticated_entity: + return requester, authenticated_entity + + return requester, None + elif self._requester is not None: + # This shouldn't happen, but we log it so we don't lose information + # and can see that we're doing something wrong. + return repr(self._requester), None # type: ignore[unreachable] + + return None, None + def render(self, resrc): # this is called once a Resource has been found to serve the request; in our # case the Resource in question will normally be a JsonResource. # create a LogContext for this request request_id = self.get_request_id() - self.logcontext = LoggingContext(request_id, request=request_id) + self.logcontext = LoggingContext( + request_id, + request=ContextRequest( + request_id=request_id, + ip_address=self.getClientIP(), + site_tag=self.site.site_tag, + # The requester is going to be unknown at this point. + requester=None, + authenticated_entity=None, + method=self.get_method(), + url=self.get_redacted_uri(), + protocol=self.clientproto.decode("ascii", errors="replace"), + user_agent=get_request_user_agent(self), + ), + ) # override the Server header which is set by twisted self.setHeader("Server", self.site.server_version_string) @@ -277,25 +353,6 @@ class SynapseRequest(Request): # to the client (nb may be negative) response_send_time = self.finish_time - self._processing_finished_time - # Convert the requester into a string that we can log - authenticated_entity = None - if isinstance(self.requester, str): - authenticated_entity = self.requester - elif isinstance(self.requester, Requester): - authenticated_entity = self.requester.authenticated_entity - - # If this is a request where the target user doesn't match the user who - # authenticated (e.g. and admin is puppetting a user) then we log both. - if self.requester.user.to_string() != authenticated_entity: - authenticated_entity = "{},{}".format( - authenticated_entity, - self.requester.user.to_string(), - ) - elif self.requester is not None: - # This shouldn't happen, but we log it so we don't lose information - # and can see that we're doing something wrong. - authenticated_entity = repr(self.requester) # type: ignore[unreachable] - user_agent = get_request_user_agent(self, "-") code = str(self.code) @@ -305,6 +362,13 @@ class SynapseRequest(Request): code += "!" log_level = logging.INFO if self._should_log_request() else logging.DEBUG + + # If this is a request where the target user doesn't match the user who + # authenticated (e.g. and admin is puppetting a user) then we log both. + requester, authenticated_entity = self.get_authenticated_entity() + if authenticated_entity: + requester = "{}.{}".format(authenticated_entity, requester) + self.site.access_logger.log( log_level, "%s - %s - {%s}" @@ -312,7 +376,7 @@ class SynapseRequest(Request): ' %sB %s "%s %s %s" "%s" [%d dbevts]', self.getClientIP(), self.site.site_tag, - authenticated_entity, + requester, processing_time, response_send_time, usage.ru_utime, diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 1a7ea4fa96..e78343f554 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -22,7 +22,6 @@ them. See doc/log_contexts.rst for details on how this works. """ - import inspect import logging import threading @@ -30,6 +29,7 @@ import types import warnings from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union +import attr from typing_extensions import Literal from twisted.internet import defer, threads @@ -181,6 +181,29 @@ class ContextResourceUsage: return res +@attr.s(slots=True) +class ContextRequest: + """ + A bundle of attributes from the SynapseRequest object. + + This exists to: + + * Avoid a cycle between LoggingContext and SynapseRequest. + * Be a single variable that can be passed from parent LoggingContexts to + their children. + """ + + request_id = attr.ib(type=str) + ip_address = attr.ib(type=str) + site_tag = attr.ib(type=str) + requester = attr.ib(type=Optional[str]) + authenticated_entity = attr.ib(type=Optional[str]) + method = attr.ib(type=str) + url = attr.ib(type=str) + protocol = attr.ib(type=str) + user_agent = attr.ib(type=str) + + LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] @@ -256,7 +279,7 @@ class LoggingContext: self, name: Optional[str] = None, parent_context: "Optional[LoggingContext]" = None, - request: Optional[str] = None, + request: Optional[ContextRequest] = None, ) -> None: self.previous_context = current_context() self.name = name @@ -281,7 +304,11 @@ class LoggingContext: self.parent_context = parent_context if self.parent_context is not None: - self.parent_context.copy_to(self) + # we track the current request_id + self.request = self.parent_context.request + + # we also track the current scope: + self.scope = self.parent_context.scope if request is not None: # the request param overrides the request from the parent context @@ -289,7 +316,7 @@ class LoggingContext: def __str__(self) -> str: if self.request: - return str(self.request) + return self.request.request_id return "%s@%x" % (self.name, id(self)) @classmethod @@ -556,8 +583,23 @@ class LoggingContextFilter(logging.Filter): # we end up in a death spiral of infinite loops, so let's check, for # robustness' sake. if context is not None: - # Logging is interested in the request. - record.request = context.request # type: ignore + # Logging is interested in the request ID. Note that for backwards + # compatibility this is stored as the "request" on the record. + record.request = str(context) # type: ignore + + # Add some data from the HTTP request. + request = context.request + if request is None: + return True + + record.ip_address = request.ip_address # type: ignore + record.site_tag = request.site_tag # type: ignore + record.requester = request.requester # type: ignore + record.authenticated_entity = request.authenticated_entity # type: ignore + record.method = request.method # type: ignore + record.url = request.url # type: ignore + record.protocol = request.protocol # type: ignore + record.user_agent = request.user_agent # type: ignore return True @@ -630,8 +672,8 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe def nested_logging_context(suffix: str) -> LoggingContext: """Creates a new logging context as a child of another. - The nested logging context will have a 'request' made up of the parent context's - request, plus the given suffix. + The nested logging context will have a 'name' made up of the parent context's + name, plus the given suffix. CPU/db usage stats will be added to the parent context's on exit. @@ -641,7 +683,7 @@ def nested_logging_context(suffix: str) -> LoggingContext: # ... do stuff Args: - suffix: suffix to add to the parent context's 'request'. + suffix: suffix to add to the parent context's 'name'. Returns: LoggingContext: new logging context. @@ -653,11 +695,17 @@ def nested_logging_context(suffix: str) -> LoggingContext: ) parent_context = None prefix = "" + request = None else: assert isinstance(curr_context, LoggingContext) parent_context = curr_context - prefix = str(parent_context.request) - return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix) + prefix = str(parent_context.name) + request = parent_context.request + return LoggingContext( + prefix + "-" + suffix, + parent_context=parent_context, + request=request, + ) def preserve_fn(f): @@ -689,7 +737,7 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: current = current_context() try: res = f(*args, **kwargs) - except: # noqa: E722 + except Exception: # the assumption here is that the caller doesn't want to be disturbed # by synchronous exceptions, so let's turn them into Failures. return defer.fail() diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 10bd4a1461..b8081f197e 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -169,7 +169,7 @@ import inspect import logging import re from functools import wraps -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING, Dict, Optional, Pattern, Type import attr @@ -259,10 +259,18 @@ except ImportError: logger = logging.getLogger(__name__) +class SynapseTags: + # The message ID of any to_device message processed + TO_DEVICE_MESSAGE_ID = "to_device.message_id" + + # Whether the sync response has new data to be returned to the client. + SYNC_RESULT = "sync.new_data" + + # Block everything by default # A regex which matches the server_names to expose traces for. # None means 'block everything'. -_homeserver_whitelist = None +_homeserver_whitelist = None # type: Optional[Pattern[str]] # Util methods diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 3b499efc07..13a5bc4558 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py
@@ -214,7 +214,12 @@ class GaugeBucketCollector: Prometheus, and optimise for that case. """ - __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric") + __slots__ = ( + "_name", + "_documentation", + "_bucket_bounds", + "_metric", + ) def __init__( self, @@ -242,11 +247,16 @@ class GaugeBucketCollector: if self._bucket_bounds[-1] != float("inf"): self._bucket_bounds.append(float("inf")) - self._metric = self._values_to_metric([]) + # We initially set this to None. We won't report metrics until + # this has been initialised after a successful data update + self._metric = None # type: Optional[GaugeHistogramMetricFamily] + registry.register(self) def collect(self): - yield self._metric + # Don't report metrics unless we've already collected some data + if self._metric is not None: + yield self._metric def update_data(self, values: Iterable[float]): """Update the data to be reported by the metric diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index b56986d8e7..e8a9096c03 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py
@@ -16,7 +16,7 @@ import logging import threading from functools import wraps -from typing import TYPE_CHECKING, Dict, Optional, Set +from typing import TYPE_CHECKING, Dict, Optional, Set, Union from prometheus_client.core import REGISTRY, Counter, Gauge @@ -199,11 +199,11 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar _background_process_start_count.labels(desc).inc() _background_process_in_flight_count.labels(desc).inc() - with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context: + with BackgroundProcessLoggingContext(desc, count) as context: try: ctx = noop_context_manager() if bg_start_span: - ctx = start_active_span(desc, tags={"request_id": context.request}) + ctx = start_active_span(desc, tags={"request_id": str(context)}) with ctx: return await maybe_awaitable(func(*args, **kwargs)) except Exception: @@ -242,13 +242,19 @@ class BackgroundProcessLoggingContext(LoggingContext): processes. """ - __slots__ = ["_proc"] + __slots__ = ["_id", "_proc"] - def __init__(self, name: str, request: Optional[str] = None): - super().__init__(name, request=request) + def __init__(self, name: str, id: Optional[Union[int, str]] = None): + super().__init__(name) + self._id = id self._proc = _BackgroundProcess(name, self) + def __str__(self) -> str: + if self._id is not None: + return "%s-%s" % (self.name, self._id) + return "%s@%x" % (self.name, id(self)) + def start(self, rusage: "Optional[resource._RUsage]"): """Log context has started running (again).""" diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 781e02fbbb..3ecd46c038 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -50,11 +50,20 @@ class ModuleApi: self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname + self._presence_stream = hs.get_event_sources().sources["presence"] # We expose these as properties below in order to attach a helpful docstring. self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._public_room_list_manager = PublicRoomListManager(hs) + # The next time these users sync, they will receive the current presence + # state of all local users. Users are added by send_local_online_presence_to, + # and removed after a successful sync. + # + # We make this a private variable to deter modules from accessing it directly, + # though other classes in Synapse will still do so. + self._send_full_presence_to_local_users = set() + @property def http_client(self): """Allows making outbound HTTP requests to remote resources. @@ -385,6 +394,47 @@ class ModuleApi: return event + async def send_local_online_presence_to(self, users: Iterable[str]) -> None: + """ + Forces the equivalent of a presence initial_sync for a set of local or remote + users. The users will receive presence for all currently online users that they + are considered interested in. + + Updates to remote users will be sent immediately, whereas local users will receive + them on their next sync attempt. + + Note that this method can only be run on the main or federation_sender worker + processes. + """ + if not self._hs.should_send_federation(): + raise Exception( + "send_local_online_presence_to can only be run " + "on processes that send federation", + ) + + for user in users: + if self._hs.is_mine_id(user): + # Modify SyncHandler._generate_sync_entry_for_presence to call + # presence_source.get_new_events with an empty `from_key` if + # that user's ID were in a list modified by ModuleApi somewhere. + # That user would then get all presence state on next incremental sync. + + # Force a presence initial_sync for this user next time + self._send_full_presence_to_local_users.add(user) + else: + # Retrieve presence state for currently online users that this user + # is considered interested in + presence_events, _ = await self._presence_stream.get_new_events( + UserID.from_string(user), from_key=None, include_offline=False + ) + + # Send to remote destinations + await make_deferred_yieldable( + # We pull the federation sender here as we can only do so on workers + # that support sending presence + self._hs.get_federation_sender().send_presence(presence_events) + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1374aae490..c178db57e3 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py
@@ -39,6 +39,7 @@ from synapse.api.errors import AuthError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import PreserveLoggingContext +from synapse.logging.opentracing import log_kv, start_active_span from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.streams.config import PaginationConfig @@ -136,6 +137,15 @@ class _NotifierUserStream: self.last_notified_ms = time_now_ms noify_deferred = self.notify_deferred + log_kv( + { + "notify": self.user_id, + "stream": stream_key, + "stream_id": stream_id, + "listeners": self.count_listeners(), + } + ) + users_woken_by_stream_counter.labels(stream_key).inc() with PreserveLoggingContext(): @@ -404,6 +414,13 @@ class Notifier: with Measure(self.clock, "on_new_event"): user_streams = set() + log_kv( + { + "waking_up_explicit_users": len(users), + "waking_up_explicit_rooms": len(rooms), + } + ) + for user in users: user_stream = self.user_to_user_stream.get(str(user)) if user_stream is not None: @@ -476,12 +493,34 @@ class Notifier: (end_time - now) / 1000.0, self.hs.get_reactor(), ) - with PreserveLoggingContext(): - await listener.deferred + + with start_active_span("wait_for_events.deferred"): + log_kv( + { + "wait_for_events": "sleep", + "token": prev_token, + } + ) + + with PreserveLoggingContext(): + await listener.deferred + + log_kv( + { + "wait_for_events": "woken", + "token": user_stream.current_token, + } + ) current_token = user_stream.current_token result = await callback(prev_token, current_token) + log_kv( + { + "wait_for_events": "result", + "result": bool(result), + } + ) if result: break @@ -489,8 +528,10 @@ class Notifier: # has happened between the old prev_token and the current_token prev_token = current_token except defer.TimeoutError: + log_kv({"wait_for_events": "timeout"}) break except defer.CancelledError: + log_kv({"wait_for_events": "cancelled"}) break if result is None: @@ -507,7 +548,7 @@ class Notifier: pagination_config: PaginationConfig, timeout: int, is_guest: bool = False, - explicit_room_id: str = None, + explicit_room_id: Optional[str] = None, ) -> EventStreamResult: """For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index f4f7ec96f8..9fc3da49a2 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py
@@ -21,7 +21,7 @@ import attr from synapse.types import JsonDict, RoomStreamToken if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer @attr.s(slots=True) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index aaed28650d..38a47a600f 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py
@@ -22,7 +22,7 @@ from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c016a83909..1897f59153 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -33,7 +33,7 @@ from synapse.util.caches.lrucache import LruCache from .push_rule_evaluator import PushRuleEvaluatorForEvent if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 3dc06a79e8..c0968dc7a1 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py
@@ -24,7 +24,7 @@ from synapse.push import Pusher, PusherConfig, ThrottleParams from synapse.push.mailer import Mailer if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index d31043fec7..6eac57914e 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py
@@ -31,7 +31,7 @@ from synapse.push import Pusher, PusherConfig, PusherConfigException from . import push_rule_evaluator, push_tools if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -295,7 +295,7 @@ class HttpPusher(Pusher): if rejected is False: return False - if isinstance(rejected, list) or isinstance(rejected, tuple): + if isinstance(rejected, (list, tuple)): for pk in rejected: if pk != self.pushkey: # for sanity, we only remove the pushkey if it diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index d10201b6b3..2e5161de2c 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py
@@ -40,7 +40,7 @@ from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 2aa7918fb4..cb94127850 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py
@@ -22,7 +22,7 @@ from synapse.push.httppusher import HttpPusher from synapse.push.mailer import Mailer if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 321a333820..2a1c925ee8 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py
@@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging from typing import List, Set @@ -82,6 +83,9 @@ REQUIREMENTS = [ "Jinja2>=2.9", "bleach>=1.4.3", "typing-extensions>=3.7.4", + # We enforce that we have a `cryptography` version that bundles an `openssl` + # with the latest security patches. + "cryptography>=3.4.7;python_version>='3.6'", ] CONDITIONAL_REQUIREMENTS = { @@ -98,7 +102,7 @@ CONDITIONAL_REQUIREMENTS = { "txacme>=0.9.2", # txacme depends on eliot. Eliot 1.8.0 is incompatible with # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418 - 'eliot<1.8.0;python_version<"3.5.3"', + "eliot<1.8.0;python_version<'3.5.3'", ], "saml2": [ # pysaml2 6.4.0 is incompatible with Python 3.5 (see https://github.com/IdentityPython/pysaml2/issues/749) @@ -128,6 +132,18 @@ for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS +# ensure there are no double-quote characters in any of the deps (otherwise the +# 'pip install' incantation in DependencyException will break) +for dep in itertools.chain( + REQUIREMENTS, + *CONDITIONAL_REQUIREMENTS.values(), +): + if '"' in dep: + raise Exception( + "Dependency `%s` contains double-quote; use single-quotes instead" % (dep,) + ) + + def list_requirements(): return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS) @@ -147,7 +163,7 @@ class DependencyException(Exception): @property def dependencies(self): for i in self.args[0]: - yield "'" + i + "'" + yield '"' + i + '"' def check_requirements(for_feature=None): diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 8af53b4f28..82ea3b895f 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py
@@ -40,6 +40,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): // containing the event "event_format_version": .., // 1,2,3 etc: the event format version "internal_metadata": { .. serialized internal_metadata .. }, + "outlier": true|false, "rejected_reason": .., // The event.rejected_reason field "context": { .. serialized event context .. }, }], @@ -84,6 +85,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): "room_version": event.room_version.identifier, "event_format_version": event.format_version, "internal_metadata": event.internal_metadata.get_dict(), + "outlier": event.internal_metadata.is_outlier(), "rejected_reason": event.rejected_reason, "context": serialized_context, } @@ -116,6 +118,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event = make_event_from_dict( event_dict, room_ver, internal_metadata, rejected_reason ) + event.internal_metadata.outlier = event_payload["outlier"] context = EventContext.deserialize( self.storage, event_payload["context"] diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index d005f38767..73d7477854 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py
@@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - self.registration_handler.check_registration_ratelimit(content["address"]) + await self.registration_handler.check_registration_ratelimit(content["address"]) await self.registration_handler.register_with_store( user_id=user_id, diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 8fa104c8d3..a4c5b44292 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py
@@ -40,6 +40,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): // containing the event "event_format_version": .., // 1,2,3 etc: the event format version "internal_metadata": { .. serialized internal_metadata .. }, + "outlier": true|false, "rejected_reason": .., // The event.rejected_reason field "context": { .. serialized event context .. }, "requester": { .. serialized requester .. }, @@ -79,7 +80,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ratelimit (bool) extra_users (list(UserID)): Any extra users to notify about event """ - serialized_context = await context.serialize(event, store) payload = { @@ -87,6 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "room_version": event.room_version.identifier, "event_format_version": event.format_version, "internal_metadata": event.internal_metadata.get_dict(), + "outlier": event.internal_metadata.is_outlier(), "rejected_reason": event.rejected_reason, "context": serialized_context, "requester": requester.serialize(), @@ -108,6 +109,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = make_event_from_dict( event_dict, room_ver, internal_metadata, rejected_reason ) + event.internal_metadata.outlier = content["outlier"] requester = Requester.deserialize(self.store, content["requester"]) context = EventContext.deserialize(self.storage, content["context"]) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 045bd014da..93161c3dfb 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py
@@ -24,7 +24,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index bb447f75b4..8abed1f52d 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -312,16 +312,16 @@ class FederationAckCommand(Command): NAME = "FEDERATION_ACK" - def __init__(self, instance_name, token): + def __init__(self, instance_name: str, token: int): self.instance_name = instance_name self.token = token @classmethod - def from_line(cls, line): + def from_line(cls, line: str) -> "FederationAckCommand": instance_name, token = line.split(" ") return cls(instance_name, int(token)) - def to_line(self): + def to_line(self) -> str: return "%s %s" % (self.instance_name, self.token) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 825900f64c..d10d574246 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -104,7 +104,7 @@ tcp_outbound_commands_counter = Counter( # A list of all connected protocols. This allows us to send metrics about the # connections. -connected_connections = [] +connected_connections = [] # type: List[BaseReplicationStreamProtocol] logger = logging.getLogger(__name__) @@ -184,8 +184,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. - ctx_name = "replication-conn-%s" % self.conn_id - self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name) + self._logging_context = BackgroundProcessLoggingContext( + "replication-conn", self.conn_id + ) def connectionMade(self): logger.info("[%s] Connection established", self.id()) diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 2f4d407f94..98bdeb0ec6 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -60,7 +60,7 @@ class ConstantProperty(Generic[T, V]): constant = attr.ib() # type: V - def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V: + def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V: return self.constant def __set__(self, obj: Optional[T], value: V): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f45e7a8c89..3dfee76743 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -33,7 +33,7 @@ import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates if TYPE_CHECKING: - import synapse.server + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -299,20 +299,23 @@ class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow - def __init__(self, hs): - typing_handler = hs.get_typing_handler() - + def __init__(self, hs: "HomeServer"): writer_instance = hs.config.worker.writers.typing if writer_instance == hs.get_instance_name(): # On the writer, query the typing handler - update_function = typing_handler.get_all_typing_updates + typing_writer_handler = hs.get_typing_writer_handler() + update_function = ( + typing_writer_handler.get_all_typing_updates + ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] + current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) + current_token_function = hs.get_typing_handler().get_current_token super().__init__( hs.get_instance_name(), - current_token_without_instance(typing_handler.get_current_token), + current_token_without_instance(current_token_function), update_function, ) @@ -509,7 +512,7 @@ class AccountDataStream(Stream): NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9bcd13b009..9bb8e9e177 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py
@@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple from synapse.replication.tcp.streams._base import ( Stream, @@ -21,6 +22,9 @@ from synapse.replication.tcp.streams._base import ( make_http_update_function, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation @@ -38,7 +42,7 @@ class FederationStream(Stream): NAME = "federation" ROW_TYPE = FederationStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): if hs.config.worker_app is None: # master process: get updates from the FederationRemoteSendQueue. # (if the master is configured to send federation itself, federation_sender @@ -48,7 +52,9 @@ class FederationStream(Stream): current_token = current_token_without_instance( federation_sender.get_current_token ) - update_function = federation_sender.get_replication_rows + update_function = ( + federation_sender.get_replication_rows + ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] elif hs.should_send_federation(): # federation sender: Query master process @@ -69,5 +75,7 @@ class FederationStream(Stream): return 0 @staticmethod - async def _stub_update_function(instance_name, from_token, upto_token, limit): + async def _stub_update_function( + instance_name: str, from_token: int, upto_token: int, limit: int + ) -> Tuple[list, int, bool]: return [], upto_token, False diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 7fcc48a9d7..40646ef241 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py
@@ -28,7 +28,7 @@ from synapse.rest.admin._base import ( from synapse.types import JsonDict if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 263d8ec076..cfe1bebb91 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -390,6 +390,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): async def on_POST( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2c89b62e25..fa7804583a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -36,6 +36,7 @@ from synapse.rest.admin._base import ( ) from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.databases.main.media_repository import MediaSortOrder +from synapse.storage.databases.main.stats import UserSortOrder from synapse.types import JsonDict, UserID if TYPE_CHECKING: @@ -117,8 +118,26 @@ class UsersRestServletV2(RestServlet): guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) + order_by = parse_string( + request, + "order_by", + default=UserSortOrder.NAME.value, + allowed_values=( + UserSortOrder.NAME.value, + UserSortOrder.DISPLAYNAME.value, + UserSortOrder.GUEST.value, + UserSortOrder.ADMIN.value, + UserSortOrder.DEACTIVATED.value, + UserSortOrder.USER_TYPE.value, + UserSortOrder.AVATAR_URL.value, + UserSortOrder.SHADOW_BANNED.value, + ), + ) + + direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + users, total = await self.store.get_users_paginate( - start, limit, user_id, name, guests, deactivated + start, limit, user_id, name, guests, deactivated, order_by, direction ) ret = {"users": users, "total": total} if (start + limit) < total: @@ -271,7 +290,7 @@ class UserRestServletV2(RestServlet): elif not deactivate and user["deactivated"]: if ( "password" not in body - and self.hs.config.password_localdb_enabled + and self.auth_handler.can_change_password() ): raise SynapseError( 400, "Must provide a password to re-activate an account." @@ -833,6 +852,9 @@ class UserMediaRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index e4c352f572..3151e72d4f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, @@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientIP() + ) result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_token_login(login_submission) else: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet): # too often. This happens here rather than before as we don't # necessarily know the user before now. if ratelimit: - self._account_ratelimiter.ratelimit(user_id.lower()) + await self._account_ratelimiter.ratelimit(None, user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 5884daea6d..525efdf221 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -18,7 +18,7 @@ import logging import re -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from urllib import parse as urlparse from synapse.api.constants import EventTypes, Membership @@ -35,21 +35,30 @@ from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.logging.opentracing import set_tag from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types import ( + JsonDict, + RoomAlias, + RoomID, + StreamToken, + ThirdPartyInstanceID, + UserID, +) from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: - import synapse.server + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -846,10 +855,10 @@ class RoomTypingRestServlet(RestServlet): "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() + self.hs = hs self.presence_handler = hs.get_presence_handler() - self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() # If we're not on the typing writer instance we should scream if we get @@ -874,16 +883,19 @@ class RoomTypingRestServlet(RestServlet): # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) + # Defer getting the typing handler since it will raise on workers. + typing_handler = self.hs.get_typing_writer_handler() + try: if content["typing"]: - await self.typing_handler.started_typing( + await typing_handler.started_typing( target_user=target_user, requester=requester, room_id=room_id, timeout=timeout, ) else: - await self.typing_handler.stopped_typing( + await typing_handler.stopped_typing( target_user=target_user, requester=requester, room_id=room_id ) except ShadowBanError: @@ -901,7 +913,7 @@ class RoomAliasListServlet(RestServlet): ), ] - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.directory_handler = hs.get_directory_handler() @@ -984,7 +996,58 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): ) -def register_servlets(hs, http_server, is_worker=False): +class RoomSpaceSummaryRestServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2946" + "/rooms/(?P<room_id>[^/]*)/spaces$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._space_summary_handler = hs.get_space_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + + return 200, await self._space_summary_handler.get_space_summary( + requester.user.to_string(), + room_id, + suggested_only=parse_boolean(request, "suggested_only", default=False), + max_rooms_per_space=parse_integer(request, "max_rooms_per_space"), + ) + + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + content = parse_json_object_from_request(request) + + suggested_only = content.get("suggested_only", False) + if not isinstance(suggested_only, bool): + raise SynapseError( + 400, "'suggested_only' must be a boolean", Codes.BAD_JSON + ) + + max_rooms_per_space = content.get("max_rooms_per_space") + if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON + ) + + return 200, await self._space_summary_handler.get_space_summary( + requester.user.to_string(), + room_id, + suggested_only=suggested_only, + max_rooms_per_space=max_rooms_per_space, + ) + + +def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -998,6 +1061,9 @@ def register_servlets(hs, http_server, is_worker=False): RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) + if hs.config.experimental.spaces_enabled: + RoomSpaceSummaryRestServlet(hs).register(http_server) + # Some servlets only get registered for the main process. if not is_worker: RoomCreateRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index adf1d39728..411fb57c47 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py
@@ -45,7 +45,7 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to @@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) if next_link: # Raise if the provided next_link value isn't valid @@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index 76879ac559..44ccf10ed4 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -13,12 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -27,21 +33,16 @@ class CapabilitiesRestServlet(RestServlet): PATTERNS = client_patterns("/capabilities$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.config = hs.config self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user = await self.store.get_user_by_id(requester.user.to_string()) - change_password = bool(user["password_hash"]) + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + change_password = self.auth_handler.can_change_password() response = { "capabilities": { @@ -58,5 +59,5 @@ class CapabilitiesRestServlet(RestServlet): return 200, response -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): CapabilitiesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5901432fad..08fb6b2b06 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py
@@ -38,7 +38,7 @@ from synapse.types import GroupID, JsonDict from ._base import client_patterns if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 8f68d8dfc8..c212da0cb2 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email @@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) @@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - self.ratelimiter.ratelimit(client_addr, update=False) + await self.ratelimiter.ratelimit(None, client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 8e52e4cca4..3481770c83 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py
@@ -15,6 +15,7 @@ import itertools import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError @@ -26,11 +27,15 @@ from synapse.events.utils import ( from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string -from synapse.types import StreamToken +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -73,7 +78,7 @@ class SyncRestServlet(RestServlet): PATTERNS = client_patterns("/sync$") ALLOWED_PRESENCE = {"online", "offline", "unavailable"} - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -85,7 +90,10 @@ class SyncRestServlet(RestServlet): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index d24a199318..3e3d8839f4 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
@@ -81,6 +81,8 @@ class VersionsRestServlet(RestServlet): "io.element.e2ee_forced.public": self.e2ee_forced_public, "io.element.e2ee_forced.private": self.e2ee_forced_private, "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private, + # Supports the busy presence state described in MSC3026. + "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, }, }, ) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 1eff98ef14..c41a7ab412 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py
@@ -23,7 +23,7 @@ from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.site import SynapseRequest if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer class MediaConfigResource(DirectServeJsonResource): diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 8a43581f1f..5dadaeaf57 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py
@@ -24,8 +24,8 @@ from synapse.http.servlet import parse_boolean from ._base import parse_media_id, respond_404 if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 8b4841ed5d..0c041b542d 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -58,7 +58,7 @@ from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b8895aeaa9..814145a04a 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -54,8 +54,8 @@ from ._base import FileInfo if TYPE_CHECKING: from lxml import etree - from synapse.app.homeserver import HomeServer from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -175,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource): clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=ONE_HOUR, - ) + ) # type: ExpiringCache[str, ObservableDeferred] if self._worker_run_media_background_jobs: self._cleaner_loop = self.clock.looping_call( @@ -187,6 +187,8 @@ class PreviewUrlResource(DirectServeJsonResource): respond_with_json(request, 200, {}, send_cors=True) async def _async_render_GET(self, request: SynapseRequest) -> None: + # This will always be set by the time Twisted calls us. + assert request.args is not None # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index e92006faa9..031947557d 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py
@@ -29,7 +29,7 @@ from .media_storage import FileResponder logger = logging.getLogger(__name__) if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer class StorageProvider(metaclass=abc.ABCMeta): diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index fbcd50f1e2..af802bc0b1 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -34,8 +34,8 @@ from ._base import ( ) if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index ae5aef2f7f..0138b2e2d1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py
@@ -26,8 +26,8 @@ from synapse.http.site import SynapseRequest from synapse.rest.media.v1.media_storage import SpamMediaException if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index 51acaa9a92..d9ffe84489 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py
@@ -104,6 +104,9 @@ class AccountDetailsResource(DirectServeHtmlResource): respond_with_html(request, 200, html) async def _async_render_POST(self, request: SynapseRequest): + # This will always be set by the time Twisted calls us. + assert request.args is not None + try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/secrets.py b/synapse/secrets.py
index fb6d90a3b7..7939db75e7 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py
@@ -26,10 +26,10 @@ if sys.version_info[0:2] >= (3, 6): import secrets class Secrets: - def token_bytes(self, nbytes=32): + def token_bytes(self, nbytes: int = 32) -> bytes: return secrets.token_bytes(nbytes) - def token_hex(self, nbytes=32): + def token_hex(self, nbytes: int = 32) -> str: return secrets.token_hex(nbytes) @@ -38,8 +38,8 @@ else: import os class Secrets: - def token_bytes(self, nbytes=32): + def token_bytes(self, nbytes: int = 32) -> bytes: return os.urandom(nbytes) - def token_hex(self, nbytes=32): + def token_hex(self, nbytes: int = 32) -> str: return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii") diff --git a/synapse/server.py b/synapse/server.py
index 48ac87a124..cfb55c230d 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -51,6 +51,7 @@ from synapse.crypto import context_factory from synapse.crypto.context_factory import RegularPolicyForHTTPS from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory +from synapse.events.presence_router import PresenceRouter from synapse.events.spamcheck import SpamChecker from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.events.utils import EventClientSerializer @@ -60,7 +61,7 @@ from synapse.federation.federation_server import ( FederationServer, ) from synapse.federation.send_queue import FederationRemoteSendQueue -from synapse.federation.sender import FederationSender +from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler @@ -96,10 +97,11 @@ from synapse.handlers.room import ( RoomShutdownHandler, ) from synapse.handlers.room_list import RoomListHandler -from synapse.handlers.room_member import RoomMemberMasterHandler +from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.search import SearchHandler from synapse.handlers.set_password import SetPasswordHandler +from synapse.handlers.space_summary import SpaceSummaryHandler from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler @@ -328,6 +330,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( + store=self.get_datastore(), clock=self.get_clock(), rate_hz=self.config.rc_registration.per_second, burst_count=self.config.rc_registration.burst_count, @@ -417,10 +420,23 @@ class HomeServer(metaclass=abc.ABCMeta): return PresenceHandler(self) @cache_in_self - def get_typing_handler(self): + def get_typing_writer_handler(self) -> TypingWriterHandler: if self.config.worker.writers.typing == self.get_instance_name(): return TypingWriterHandler(self) else: + raise Exception("Workers cannot write typing") + + @cache_in_self + def get_presence_router(self) -> PresenceRouter: + return PresenceRouter(self) + + @cache_in_self + def get_typing_handler(self) -> FollowerTypingHandler: + if self.config.worker.writers.typing == self.get_instance_name(): + # Use get_typing_writer_handler to ensure that we use the same + # cached version. + return self.get_typing_writer_handler() + else: return FollowerTypingHandler(self) @cache_in_self @@ -561,7 +577,7 @@ class HomeServer(metaclass=abc.ABCMeta): return TransportLayerClient(self) @cache_in_self - def get_federation_sender(self): + def get_federation_sender(self) -> AbstractFederationSender: if self.should_send_federation(): return FederationSender(self) elif not self.config.worker_app: @@ -630,7 +646,7 @@ class HomeServer(metaclass=abc.ABCMeta): return ThirdPartyEventRules(self) @cache_in_self - def get_room_member_handler(self): + def get_room_member_handler(self) -> RoomMemberHandler: if self.config.worker_app: return RoomMemberWorkerHandler(self) return RoomMemberMasterHandler(self) @@ -640,13 +656,13 @@ class HomeServer(metaclass=abc.ABCMeta): return FederationHandlerRegistry(self) @cache_in_self - def get_server_notices_manager(self): + def get_server_notices_manager(self) -> ServerNoticesManager: if self.config.worker_app: raise Exception("Workers cannot send server notices") return ServerNoticesManager(self) @cache_in_self - def get_server_notices_sender(self): + def get_server_notices_sender(self) -> WorkerServerNoticesSender: if self.config.worker_app: return WorkerServerNoticesSender(self) return ServerNoticesSender(self) @@ -724,6 +740,10 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountDataHandler(self) @cache_in_self + def get_space_summary_handler(self) -> SpaceSummaryHandler: + return SpaceSummaryHandler(self) + + @cache_in_self def get_external_cache(self) -> ExternalCache: return ExternalCache(self) diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 9137c4edb1..a9349bf9a1 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py
@@ -13,13 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any +from typing import TYPE_CHECKING, Any, Set from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError from synapse.types import get_localpart_from_id +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -28,16 +31,11 @@ class ConsentServerNotices: privacy policy consent, and sends one if we do. """ - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastore() - self._users_in_progress = set() + self._users_in_progress = set() # type: Set[str] self._current_consent_version = hs.config.user_consent_version self._server_notice_content = hs.config.user_consent_server_notice_content @@ -73,6 +71,10 @@ class ConsentServerNotices: try: u = await self._store.get_user_by_id(user_id) + # The user doesn't exist. + if u is None: + return + if u["is_guest"] and not self._send_to_guests: # don't send to guests return diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 6652451346..a18a2e76c9 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Tuple from synapse.api.constants import ( EventTypes, @@ -24,6 +24,9 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError, ResourceLimitError, SynapseError from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -32,11 +35,7 @@ class ResourceLimitsServerNotices: ensures that the client is kept up to date. """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastore() self._auth = hs.get_auth() diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index c46b2f047d..144e1da78e 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py
@@ -58,7 +58,7 @@ class ServerNoticesManager: user_id: str, event_content: dict, type: str = EventTypes.Message, - state_key: Optional[bool] = None, + state_key: Optional[str] = None, ) -> EventBase: """Send a notice to the given user diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index 6870b67ca0..965c645889 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py
@@ -12,25 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Union +from typing import TYPE_CHECKING, Iterable, Union from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) +from synapse.server_notices.worker_server_notices_sender import ( + WorkerServerNoticesSender, +) + +if TYPE_CHECKING: + from synapse.server import HomeServer -class ServerNoticesSender: +class ServerNoticesSender(WorkerServerNoticesSender): """A centralised place which sends server notices automatically when Certain Events take place """ - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): + super().__init__(hs) self._server_notices = ( ConsentServerNotices(hs), ResourceLimitsServerNotices(hs), diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
index 9273e61895..c76bd57460 100644 --- a/synapse/server_notices/worker_server_notices_sender.py +++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -12,16 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from synapse.server import HomeServer class WorkerServerNoticesSender: """Stub impl of ServerNoticesSender which does nothing""" - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): + pass async def on_user_syncing(self, user_id: str) -> None: """Called when the user performs a sync operation. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c3d6e80c49..c0f79ffdc8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py
@@ -22,6 +22,7 @@ from typing import ( Callable, DefaultDict, Dict, + FrozenSet, Iterable, List, Optional, @@ -515,7 +516,7 @@ class StateResolutionHandler: expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, - ) + ) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry] # # stuff for tracking time spent on state-res by room @@ -536,7 +537,7 @@ class StateResolutionHandler: state_groups_ids: Dict[int, StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", - ): + ) -> _StateCacheEntry: """Resolves conflicts between a set of state groups Always generates a new state group (unless we hit the cache), so should diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a3c52695e9..0b9007e51f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py
@@ -36,7 +36,7 @@ from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.state import StateGroupStorage if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer __all__ = ["Databases", "DataStore"] diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a25c4093bc..240905329f 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -27,7 +27,7 @@ from synapse.types import Collection, StreamToken, get_domain_from_id from synapse.util import json_decoder if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 329660cf0f..ccb06aab39 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
@@ -23,7 +23,7 @@ from synapse.util import json_encoder from . import engines if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer from synapse.storage.database import DatabasePool, LoggingTransaction logger = logging.getLogger(__name__) diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index f1ba529a2d..94590e7b45 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -670,7 +670,7 @@ class DatabasePool: for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) - except: # noqa: E722, as we reraise the exception this is fine. + except Exception: for after_callback, after_args, after_kwargs in exception_callbacks: after_callback(*after_args, **after_kwargs) raise @@ -1906,6 +1906,7 @@ class DatabasePool: retcols: Iterable[str], filters: Optional[Dict[str, Any]] = None, keyvalues: Optional[Dict[str, Any]] = None, + exclude_keyvalues: Optional[Dict[str, Any]] = None, order_direction: str = "ASC", ) -> List[Dict[str, Any]]: """ @@ -1929,7 +1930,10 @@ class DatabasePool: apply a WHERE ? LIKE ? clause. keyvalues: column names and values to select the rows with, or None to not - apply a WHERE clause. + apply a WHERE key = value clause. + exclude_keyvalues: + column names and values to exclude rows with, or None to not + apply a WHERE key != value clause. order_direction: Whether the results should be ordered "ASC" or "DESC". Returns: @@ -1938,7 +1942,7 @@ class DatabasePool: if order_direction not in ["ASC", "DESC"]: raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") - where_clause = "WHERE " if filters or keyvalues else "" + where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else "" arg_list = [] # type: List[Any] if filters: where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) @@ -1947,6 +1951,9 @@ class DatabasePool: if keyvalues: where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues) arg_list += list(keyvalues.values()) + if exclude_keyvalues: + where_clause += " AND ".join("%s != ?" % (k,) for k in exclude_keyvalues) + arg_list += list(exclude_keyvalues.values()) sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( ", ".join(retcols), diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1d44c3aa2c..b3d16ca7ac 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -21,6 +21,7 @@ from typing import List, Optional, Tuple from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( IdGenerator, @@ -292,6 +293,8 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, + order_by: UserSortOrder = UserSortOrder.USER_ID.value, + direction: str = "f", ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -304,6 +307,8 @@ class DataStore( name: search for local part of user_id or display name guests: whether to in include guest users deactivated: whether to include deactivated users + order_by: the sort order of the returned list + direction: sort ascending or descending Returns: A tuple of a list of mappings from user to information and a count of total users. """ @@ -312,6 +317,14 @@ class DataStore( filters = [] args = [self.hs.config.server_name] + # Set ordering + order_by_column = UserSortOrder(order_by).value + + if direction == "b": + order = "DESC" + else: + order = "ASC" + # `name` is in database already in lower case if name: filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)") @@ -339,10 +352,15 @@ class DataStore( txn.execute(sql, args) count = txn.fetchone()[0] - sql = ( - "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url " - + sql_base - + " ORDER BY u.name LIMIT ? OFFSET ?" + sql = """ + SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url + {sql_base} + ORDER BY {order_by_column} {order}, u.name ASC + LIMIT ? OFFSET ? + """.format( + sql_base=sql_base, + order_by_column=order_by_column, + order=order, ) args += [limit, start] txn.execute(sql, args) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 03a38422a1..85bb853d33 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -32,7 +32,7 @@ from synapse.types import JsonDict from synapse.util import json_encoder if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 45ca6620a8..691080ce74 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Tuple +from typing import List, Optional, Tuple from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.replication.tcp.streams import ToDeviceStream @@ -115,7 +115,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): async def get_new_messages_for_device( self, user_id: str, - device_id: str, + device_id: Optional[str], last_stream_id: int, current_stream_id: int, limit: int = 100, @@ -163,7 +163,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): @trace async def delete_messages_for_device( - self, user_id: str, device_id: str, up_to_stream_id: int + self, user_id: str, device_id: Optional[str], up_to_stream_id: int ) -> int: """ Args: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 332193ad1c..a956be491a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -793,7 +793,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return int(min_depth) if min_depth is not None else None - async def get_forward_extremeties_for_room( + async def get_forward_extremities_for_room_at_stream_ordering( self, room_id: str, stream_ordering: int ) -> List[str]: """For a given room_id and stream_ordering, return the forward diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index cd1ceac50e..98dac19a95 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -1270,8 +1270,10 @@ class PersistEventsStore: logger.exception("") raise + # update the stored internal_metadata to update the "outlier" flag. + # TODO: This is unused as of Synapse 1.31. Remove it once we are happy + # to drop backwards-compatibility with 1.30. metadata_json = json_encoder.encode(event.internal_metadata.get_dict()) - sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -1319,6 +1321,19 @@ class PersistEventsStore: d.pop("redacted_because", None) return d + def get_internal_metadata(event): + im = event.internal_metadata.get_dict() + + # temporary hack for database compatibility with Synapse 1.30 and earlier: + # store the `outlier` flag inside the internal_metadata json as well as in + # the `events` table, so that if anyone rolls back to an older Synapse, + # things keep working. This can be removed once we are happy to drop support + # for that + if event.internal_metadata.is_outlier(): + im["outlier"] = True + + return im + self.db_pool.simple_insert_many_txn( txn, table="event_json", @@ -1327,7 +1342,7 @@ class PersistEventsStore: "event_id": event.event_id, "room_id": event.room_id, "internal_metadata": json_encoder.encode( - event.internal_metadata.get_dict() + get_internal_metadata(event) ), "json": json_encoder.encode(event_dict(event)), "format_version": event.format_version, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c04e162ccc..c00780969f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -16,7 +16,7 @@ import logging import threading from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Tuple, overload +from typing import Container, Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names from typing_extensions import Literal @@ -544,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_types_to_include: List[EventTypes], + state_types_to_include: Container[str], membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -799,6 +799,7 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason=rejected_reason, ) original_ev.internal_metadata.stream_ordering = row["stream_ordering"] + original_ev.internal_metadata.outlier = row["outlier"] event_map[event_id] = original_ev @@ -905,7 +906,8 @@ class EventsWorkerStore(SQLBaseStore): ej.json, ej.format_version, r.room_version, - rej.reason + rej.reason, + e.outlier FROM events AS e JOIN event_json AS ej USING (event_id) LEFT JOIN rooms r ON r.room_id = e.room_id @@ -929,6 +931,7 @@ class EventsWorkerStore(SQLBaseStore): "room_version_id": row[5], "rejected_reason": row[6], "redactions": [], + "outlier": row[7], } # check for redactions diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ac07e0197b..8f462dfc31 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py
@@ -1027,8 +1027,8 @@ class GroupServerStore(GroupServerWorkerStore): user_id: str, is_admin: bool = False, is_public: bool = True, - local_attestation: dict = None, - remote_attestation: dict = None, + local_attestation: Optional[dict] = None, + remote_attestation: Optional[dict] = None, ) -> None: """Add a user to the group server. diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4f3d192562..b7820ac7ff 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -22,6 +22,9 @@ from synapse.storage.database import DatabasePool BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( "media_repository_drop_index_wo_method" ) +BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( + "media_repository_drop_index_wo_method_2" +) class MediaSortOrder(Enum): @@ -85,23 +88,35 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): unique=True, ) + # the original impl of _drop_media_index_without_method was broken (see + # https://github.com/matrix-org/synapse/issues/8649), so we replace the original + # impl with a no-op and run the fixed migration as + # media_repository_drop_index_wo_method_2. + self.db_pool.updates.register_noop_background_update( + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD + ) self.db_pool.updates.register_background_update_handler( - BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD, + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2, self._drop_media_index_without_method, ) async def _drop_media_index_without_method(self, progress, batch_size): + """background update handler which removes the old constraints. + + Note that this is only run on postgres. + """ + def f(txn): txn.execute( "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key" ) txn.execute( - "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key" + "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key" ) await self.db_pool.runInteraction("drop_media_indices_without_method", f) await self.db_pool.updates._end_background_update( - BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 ) return 1 diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index d788dc0fc6..757da3d55d 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List +from typing import Dict, List, Optional from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -109,7 +109,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return users @cached(num_args=1) - async def user_last_seen_monthly_active(self, user_id: str) -> int: + async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]: """ Checks if a given user is part of the monthly active user group diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 29edab34d4..0ff693a310 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import Dict, List, Tuple from synapse.api.presence import UserPresenceState from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause @@ -157,5 +157,63 @@ class PresenceStore(SQLBaseStore): return {row["user_id"]: UserPresenceState(**row) for row in rows} + async def get_presence_for_all_users( + self, + include_offline: bool = True, + ) -> Dict[str, UserPresenceState]: + """Retrieve the current presence state for all users. + + Note that the presence_stream table is culled frequently, so it should only + contain the latest presence state for each user. + + Args: + include_offline: Whether to include offline presence states + + Returns: + A dict of user IDs to their current UserPresenceState. + """ + users_to_state = {} + + exclude_keyvalues = None + if not include_offline: + # Exclude offline presence state + exclude_keyvalues = {"state": "offline"} + + # This may be a very heavy database query. + # We paginate in order to not block a database connection. + limit = 100 + offset = 0 + while True: + rows = await self.db_pool.runInteraction( + "get_presence_for_all_users", + self.db_pool.simple_select_list_paginate_txn, + "presence_stream", + orderby="stream_id", + start=offset, + limit=limit, + exclude_keyvalues=exclude_keyvalues, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + order_direction="ASC", + ) + + for row in rows: + users_to_state[row["user_id"]] = UserPresenceState(**row) + + # We've run out of updates to query + if len(rows) < limit: + break + + offset += limit + + return users_to_state + def get_current_presence_token(self): return self._presence_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 85f1ebac98..c65558c280 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -27,7 +27,7 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index eba66ff352..90a8f664ef 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) @cached() diff --git a/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres new file mode 100644
index 0000000000..54c1bca3b1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres
@@ -0,0 +1,22 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- drop old constraints on remote_media_cache_thumbnails +-- +-- This was originally part of 57.07, but it was done wrong, per +-- https://github.com/matrix-org/synapse/issues/8649, so we do it again. +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5911, 'media_repository_drop_index_wo_method_2', '{}', 'remote_media_repository_thumbnails_method_idx'); + diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 1c99393c65..bce8946c21 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -66,18 +66,37 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class UserSortOrder(Enum): """ Enum to define the sorting method used when returning users - with get_users_media_usage_paginate + with get_users_paginate in __init__.py + and get_users_media_usage_paginate in stats.py - MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest. - MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest. + When moves this to __init__.py gets `builtins.ImportError` with + `most likely due to a circular import` + + MEDIA_LENGTH = ordered by size of uploaded media. + MEDIA_COUNT = ordered by number of uploaded media. USER_ID = ordered alphabetically by `user_id`. + NAME = ordered alphabetically by `user_id`. This is for compatibility reasons, + as the user_id is returned in the name field in the response in list users admin API. DISPLAYNAME = ordered alphabetically by `displayname` + GUEST = ordered by `is_guest` + ADMIN = ordered by `admin` + DEACTIVATED = ordered by `deactivated` + USER_TYPE = ordered alphabetically by `user_type` + AVATAR_URL = ordered alphabetically by `avatar_url` + SHADOW_BANNED = ordered by `shadow_banned` """ MEDIA_LENGTH = "media_length" MEDIA_COUNT = "media_count" USER_ID = "user_id" + NAME = "name" DISPLAYNAME = "displayname" + GUEST = "is_guest" + ADMIN = "admin" + DEACTIVATED = "deactivated" + USER_TYPE = "user_type" + AVATAR_URL = "avatar_url" + SHADOW_BANNED = "shadow_banned" class StatsStore(StateDeltasStore): diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 0309661841..b7072f1f5e 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -22,7 +22,6 @@ from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache @@ -312,49 +311,23 @@ class TransactionStore(TransactionWorkerStore): stream_ordering: the stream_ordering of the event """ - return await self.db_pool.runInteraction( - "store_destination_rooms_entries", - self._store_destination_rooms_entries_txn, - destinations, - room_id, - stream_ordering, + await self.db_pool.simple_upsert_many( + table="destinations", + key_names=("destination",), + key_values=[(d,) for d in destinations], + value_names=[], + value_values=[], + desc="store_destination_rooms_entries_dests", ) - def _store_destination_rooms_entries_txn( - self, - txn: LoggingTransaction, - destinations: Iterable[str], - room_id: str, - stream_ordering: int, - ) -> None: - - # ensure we have a `destinations` row for this destination, as there is - # a foreign key constraint. - if isinstance(self.database_engine, PostgresEngine): - q = """ - INSERT INTO destinations (destination) - VALUES (?) - ON CONFLICT DO NOTHING; - """ - elif isinstance(self.database_engine, Sqlite3Engine): - q = """ - INSERT OR IGNORE INTO destinations (destination) - VALUES (?); - """ - else: - raise RuntimeError("Unknown database engine") - - txn.execute_batch(q, ((destination,) for destination in destinations)) - rows = [(destination, room_id) for destination in destinations] - - self.db_pool.simple_upsert_many_txn( - txn, + await self.db_pool.simple_upsert_many( table="destination_rooms", key_names=("destination", "room_id"), key_values=rows, value_names=["stream_ordering"], value_values=[(stream_ordering,)] * len(rows), + desc="store_destination_rooms_entries_rooms", ) async def get_destination_last_successful_stream_ordering( diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e2240703a7..97ec65f757 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py
@@ -183,12 +183,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): requests state from the cache, if False we need to query the DB for the missing state. """ - is_all, known_absent, state_dict_ids = cache.get(group) + cache_entry = cache.get(group) + state_dict_ids = cache_entry.value - if is_all or state_filter.is_full(): + if cache_entry.full or state_filter.is_full(): # Either we have everything or want everything, either way # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all + return state_filter.filter_state(state_dict_ids), cache_entry.full # tracks whether any of our requested types are missing from the cache missing_types = False @@ -202,7 +203,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # There aren't any wild cards, so `concrete_types()` returns the # complete list of event types we're wanting. for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: + if key not in state_dict_ids and key not in cache_entry.known_absent: missing_types = True break diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6c3c2da520..c7f0b8ccb5 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import imp +import importlib.util import logging import os import re @@ -454,8 +454,13 @@ def _upgrade_existing_database( ) module_name = "synapse.storage.v%d_%s" % (v, root_name) - with open(absolute_path) as python_file: - module = imp.load_source(module_name, absolute_path, python_file) # type: ignore + + spec = importlib.util.spec_from_file_location( + module_name, absolute_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) # type: ignore + logger.info("Running script %s", relative_path) module.run_create(cur, database_engine) # type: ignore if not is_empty: diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 4dcd848c59..ad954990a7 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Set from synapse.storage.databases import Databases if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d179a41884..2e277a21c4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py
@@ -32,7 +32,7 @@ from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer from synapse.storage.databases import Databases logger = logging.getLogger(__name__) @@ -449,7 +449,7 @@ class StateGroupStorage: return self.stores.state._get_state_groups_from_groups(groups, state_filter) async def get_state_for_events( - self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all() ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -485,7 +485,7 @@ class StateGroupStorage: return {event: event_to_state[event] for event in event_ids} async def get_state_ids_for_events( - self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all() ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f33c115844..c3b2d981ea 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -496,7 +496,7 @@ def timeout_deferred( try: deferred.cancel() - except: # noqa: E722, if we throw any exception it'll break time outs + except Exception: # if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") # the cancel() call should have set off a chain of errbacks which diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index e676c2cac4..48f64eeb38 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -25,8 +25,8 @@ from synapse.config.cache import add_resizable_cache logger = logging.getLogger(__name__) -caches_by_name = {} -collectors_by_name = {} # type: Dict +caches_by_name = {} # type: Dict[str, Sized] +collectors_by_name = {} # type: Dict[str, CacheMetric] cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) @@ -116,7 +116,7 @@ def register_cache( """ if resizable: if not resize_callback: - resize_callback = getattr(cache, "set_cache_factor") + resize_callback = cache.set_cache_factor # type: ignore add_resizable_cache(cache_name, resize_callback) metric = CacheMetric(cache, cache_type, cache_name, collect_callback) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1adc92eb90..dd392cf694 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py
@@ -283,7 +283,9 @@ class DeferredCache(Generic[KT, VT]): # we return a new Deferred which will be called before any subsequent observers. return observable.observe() - def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + def prefill( + self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None + ): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 588d2d49f2..b3b413b02c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py
@@ -15,26 +15,38 @@ import enum import logging import threading -from collections import namedtuple -from typing import Any +from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar + +import attr from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) -class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))): +# The type of the cache keys. +KT = TypeVar("KT") +# The type of the dictionary keys. +DKT = TypeVar("DKT") + + +@attr.s(slots=True) +class DictionaryEntry: """Returned when getting an entry from the cache Attributes: - full (bool): Whether the cache has the full or dict or just some keys. + full: Whether the cache has the full or dict or just some keys. If not full then not all requested keys will necessarily be present in `value` - known_absent (set): Keys that were looked up in the dict and were not + known_absent: Keys that were looked up in the dict and were not there. - value (dict): The full or partial dict value + value: The full or partial dict value """ + full = attr.ib(type=bool) + known_absent = attr.ib() + value = attr.ib() + def __len__(self): return len(self.value) @@ -45,21 +57,21 @@ class _Sentinel(enum.Enum): sentinel = object() -class DictionaryCache: +class DictionaryCache(Generic[KT, DKT]): """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ - def __init__(self, name, max_entries=1000): + def __init__(self, name: str, max_entries: int = 1000): self.cache = LruCache( max_size=max_entries, cache_name=name, size_callback=len - ) # type: LruCache[Any, DictionaryEntry] + ) # type: LruCache[KT, DictionaryEntry] self.name = name self.sequence = 0 - self.thread = None + self.thread = None # type: Optional[threading.Thread] - def check_thread(self): + def check_thread(self) -> None: expected_thread = self.thread if expected_thread is None: self.thread = threading.current_thread() @@ -69,12 +81,14 @@ class DictionaryCache: "Cache objects can only be accessed from the main thread" ) - def get(self, key, dict_keys=None): + def get( + self, key: KT, dict_keys: Optional[Iterable[DKT]] = None + ) -> DictionaryEntry: """Fetch an entry out of the cache Args: key - dict_key(list): If given a set of keys then return only those keys + dict_key: If given a set of keys then return only those keys that exist in the cache. Returns: @@ -95,7 +109,7 @@ class DictionaryCache: return DictionaryEntry(False, set(), {}) - def invalidate(self, key): + def invalidate(self, key: KT) -> None: self.check_thread() # Increment the sequence number so that any SELECT statements that @@ -103,19 +117,25 @@ class DictionaryCache: self.sequence += 1 self.cache.pop(key, None) - def invalidate_all(self): + def invalidate_all(self) -> None: self.check_thread() self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, fetched_keys=None): + def update( + self, + sequence: int, + key: KT, + value: Dict[DKT, Any], + fetched_keys: Optional[Set[DKT]] = None, + ) -> None: """Updates the entry in the cache Args: sequence - key (K) - value (dict[X,Y]): The value to update the cache with. - fetched_keys (None|set[X]): All of the dictionary keys which were + key + value: The value to update the cache with. + fetched_keys: All of the dictionary keys which were fetched from the database. If None, this is the complete value for key K. Otherwise, it @@ -131,7 +151,9 @@ class DictionaryCache: else: self._update_or_insert(key, value, fetched_keys) - def _update_or_insert(self, key, value, known_absent): + def _update_or_insert( + self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] + ) -> None: # We pop and reinsert as we need to tell the cache the size may have # changed @@ -140,5 +162,5 @@ class DictionaryCache: entry.known_absent.update(known_absent) self.cache[key] = entry - def _insert(self, key, value, known_absent): + def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index e15f7ee698..4dc3477e89 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@ import logging from collections import OrderedDict +from typing import Any, Generic, Optional, TypeVar, Union, overload + +import attr +from typing_extensions import Literal from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -SENTINEL = object() +SENTINEL = object() # type: Any + +T = TypeVar("T") +KT = TypeVar("KT") +VT = TypeVar("VT") -class ExpiringCache: + +class ExpiringCache(Generic[KT, VT]): def __init__( self, - cache_name, - clock, - max_len=0, - expiry_ms=0, - reset_expiry_on_get=False, - iterable=False, + cache_name: str, + clock: Clock, + max_len: int = 0, + expiry_ms: int = 0, + reset_expiry_on_get: bool = False, + iterable: bool = False, ): """ Args: - cache_name (str): Name of this cache, used for logging. - clock (Clock) - max_len (int): Max size of dict. If the dict grows larger than this + cache_name: Name of this cache, used for logging. + clock + max_len: Max size of dict. If the dict grows larger than this then the oldest items get automatically evicted. Default is 0, which indicates there is no max limit. - expiry_ms (int): How long before an item is evicted from the cache + expiry_ms: How long before an item is evicted from the cache in milliseconds. Default is 0, indicating items never get evicted based on time. - reset_expiry_on_get (bool): If true, will reset the expiry time for + reset_expiry_on_get: If true, will reset the expiry time for an item on access. Defaults to False. - iterable (bool): If true, the size is calculated by summing the + iterable: If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. """ self._cache_name = cache_name @@ -62,7 +72,7 @@ class ExpiringCache: self._expiry_ms = expiry_ms self._reset_expiry_on_get = reset_expiry_on_get - self._cache = OrderedDict() + self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry] self.iterable = iterable @@ -79,12 +89,12 @@ class ExpiringCache: self._clock.looping_call(f, self._expiry_ms / 2) - def __setitem__(self, key, value): + def __setitem__(self, key: KT, value: VT) -> None: now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) self.evict() - def evict(self): + def evict(self) -> None: # Evict if there are now too many items while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) @@ -93,7 +103,7 @@ class ExpiringCache: else: self.metrics.inc_evictions() - def __getitem__(self, key): + def __getitem__(self, key: KT) -> VT: try: entry = self._cache[key] self.metrics.inc_hits() @@ -106,7 +116,7 @@ class ExpiringCache: return entry.value - def pop(self, key, default=SENTINEL): + def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: """Removes and returns the value with the given key from the cache. If the key isn't in the cache then `default` will be returned if @@ -115,29 +125,40 @@ class ExpiringCache: Identical functionality to `dict.pop(..)`. """ - value = self._cache.pop(key, default) + value = self._cache.pop(key, SENTINEL) + # The key was not found. if value is SENTINEL: - raise KeyError(key) + if default is SENTINEL: + raise KeyError(key) + return default - return value + return value.value - def __contains__(self, key): + def __contains__(self, key: KT) -> bool: return key in self._cache - def get(self, key, default=None): + @overload + def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]: + ... + + @overload + def get(self, key: KT, default: T) -> Union[VT, T]: + ... + + def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]: try: return self[key] except KeyError: return default - def setdefault(self, key, value): + def setdefault(self, key: KT, value: VT) -> VT: try: return self[key] except KeyError: self[key] = value return value - def _prune_cache(self): + def _prune_cache(self) -> None: if not self._expiry_ms: # zero expiry time means don't expire. This should never get called # since we have this check in start too. @@ -166,7 +187,7 @@ class ExpiringCache: len(self), ) - def __len__(self): + def __len__(self) -> int: if self.iterable: return sum(len(entry.value) for entry in self._cache.values()) else: @@ -190,9 +211,7 @@ class ExpiringCache: return False +@attr.s(slots=True) class _CacheEntry: - __slots__ = ["time", "value"] - - def __init__(self, time, value): - self.time = time - self.value = value + time = attr.ib(type=int) + value = attr.ib() diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 6ce2a3d12b..96a8274940 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py
@@ -15,6 +15,7 @@ import logging import time +from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union import attr from sortedcontainers import SortedList @@ -23,15 +24,19 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -SENTINEL = object() +SENTINEL = object() # type: Any +T = TypeVar("T") +KT = TypeVar("KT") +VT = TypeVar("VT") -class TTLCache: + +class TTLCache(Generic[KT, VT]): """A key/value cache implementation where each entry has its own TTL""" - def __init__(self, cache_name, timer=time.time): + def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): # map from key to _CacheEntry - self._data = {} + self._data = {} # type: Dict[KT, _CacheEntry] # the _CacheEntries, sorted by expiry time self._expiry_list = SortedList() # type: SortedList[_CacheEntry] @@ -40,26 +45,27 @@ class TTLCache: self._metrics = register_cache("ttl", cache_name, self, resizable=False) - def set(self, key, value, ttl): + def set(self, key: KT, value: VT, ttl: float) -> None: """Add/update an entry in the cache Args: key: key for this entry value: value for this entry - ttl (float): TTL for this entry, in seconds + ttl: TTL for this entry, in seconds """ expiry = self._timer() + ttl self.expire() e = self._data.pop(key, SENTINEL) - if e != SENTINEL: + if e is not SENTINEL: + assert isinstance(e, _CacheEntry) self._expiry_list.remove(e) entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value) self._data[key] = entry self._expiry_list.add(entry) - def get(self, key, default=SENTINEL): + def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: """Get a value from the cache Args: @@ -72,23 +78,23 @@ class TTLCache: """ self.expire() e = self._data.get(key, SENTINEL) - if e == SENTINEL: + if e is SENTINEL: self._metrics.inc_misses() - if default == SENTINEL: + if default is SENTINEL: raise KeyError(key) return default + assert isinstance(e, _CacheEntry) self._metrics.inc_hits() return e.value - def get_with_expiry(self, key): + def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]: """Get a value, and its expiry time, from the cache Args: key: key to look up Returns: - Tuple[Any, float, float]: the value from the cache, the expiry time - and the TTL + A tuple of the value from the cache, the expiry time and the TTL Raises: KeyError if the entry is not found @@ -102,7 +108,7 @@ class TTLCache: self._metrics.inc_hits() return e.value, e.expiry_time, e.ttl - def pop(self, key, default=SENTINEL): + def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore """Remove a value from the cache If key is in the cache, remove it and return its value, else return default. @@ -118,29 +124,30 @@ class TTLCache: """ self.expire() e = self._data.pop(key, SENTINEL) - if e == SENTINEL: + if e is SENTINEL: self._metrics.inc_misses() - if default == SENTINEL: + if default is SENTINEL: raise KeyError(key) return default + assert isinstance(e, _CacheEntry) self._expiry_list.remove(e) self._metrics.inc_hits() return e.value - def __getitem__(self, key): + def __getitem__(self, key: KT) -> VT: return self.get(key) - def __delitem__(self, key): + def __delitem__(self, key: KT) -> None: self.pop(key) - def __contains__(self, key): + def __contains__(self, key: KT) -> bool: return key in self._data - def __len__(self): + def __len__(self) -> int: self.expire() return len(self._data) - def expire(self): + def expire(self) -> None: """Run the expiry on the cache. Any entries whose expiry times are due will be removed """ @@ -158,7 +165,7 @@ class _CacheEntry: """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. - expiry_time = attr.ib() - ttl = attr.ib() + expiry_time = attr.ib(type=float) + ttl = attr.ib(type=float) key = attr.ib() value = attr.ib() diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 5f7a6dd1d3..5ca2e71e60 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py
@@ -36,7 +36,7 @@ def freeze(o): def unfreeze(o): if isinstance(o, (dict, frozendict)): - return dict({k: unfreeze(v) for k, v in o.items()}) + return {k: unfreeze(v) for k, v in o.items()} if isinstance(o, (bytes, str)): return o diff --git a/synapse/visibility.py b/synapse/visibility.py
index e39d02602a..ff53a49b3a 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import operator +from typing import Dict, FrozenSet, List, Optional from synapse.api.constants import ( AccountDataTypes, @@ -21,10 +21,11 @@ from synapse.api.constants import ( HistoryVisibility, Membership, ) +from synapse.events import EventBase from synapse.events.utils import prune_event from synapse.storage import Storage from synapse.storage.state import StateFilter -from synapse.types import get_domain_from_id +from synapse.types import StateMap, get_domain_from_id logger = logging.getLogger(__name__) @@ -48,32 +49,32 @@ MEMBERSHIP_PRIORITY = ( async def filter_events_for_client( storage: Storage, - user_id, - events, - is_peeking=False, - always_include_ids=frozenset(), - filter_send_to_client=True, -): + user_id: str, + events: List[EventBase], + is_peeking: bool = False, + always_include_ids: FrozenSet[str] = frozenset(), + filter_send_to_client: bool = True, +) -> List[EventBase]: """ Check which events a user is allowed to see. If the user can see the event but its sender asked for their data to be erased, prune the content of the event. Args: storage - user_id(str): user id to be checked - events(list[synapse.events.EventBase]): sequence of events to be checked - is_peeking(bool): should be True if: + user_id: user id to be checked + events: sequence of events to be checked + is_peeking: should be True if: * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events - always_include_ids (set(event_id)): set of event ids to specifically + always_include_ids: set of event ids to specifically include (unless sender is ignored) - filter_send_to_client (bool): Whether we're checking an event that's going to be + filter_send_to_client: Whether we're checking an event that's going to be sent to a client. This might not always be the case since this function can also be called to check whether a user can see the state at a given point. Returns: - list[synapse.events.EventBase] + The filtered events. """ # Filter out events that have been soft failed so that we don't relay them # to clients. @@ -90,7 +91,7 @@ async def filter_events_for_client( AccountDataTypes.IGNORED_USER_LIST, user_id ) - ignore_list = frozenset() + ignore_list = frozenset() # type: FrozenSet[str] if ignore_dict_content: ignored_users_dict = ignore_dict_content.get("ignored_users", {}) if isinstance(ignored_users_dict, dict): @@ -107,19 +108,18 @@ async def filter_events_for_client( room_id ] = await storage.main.get_retention_policy_for_room(room_id) - def allowed(event): + def allowed(event: EventBase) -> Optional[EventBase]: """ Args: - event (synapse.events.EventBase): event to check + event: event to check Returns: - None|EventBase: - None if the user cannot see this event at all + None if the user cannot see this event at all - a redacted copy of the event if they can only see a redacted - version + a redacted copy of the event if they can only see a redacted + version - the original event if they can see it as normal. + the original event if they can see it as normal. """ # Only run some checks if these events aren't about to be sent to clients. This is # because, if this is not the case, we're probably only checking if the users can @@ -252,48 +252,46 @@ async def filter_events_for_client( return event - # check each event: gives an iterable[None|EventBase] + # Check each event: gives an iterable of None or (a potentially modified) + # EventBase. filtered_events = map(allowed, events) - # remove the None entries - filtered_events = filter(operator.truth, filtered_events) - - # we turn it into a list before returning it. - return list(filtered_events) + # Turn it into a list and remove None entries before returning. + return [ev for ev in filtered_events if ev] async def filter_events_for_server( storage: Storage, - server_name, - events, - redact=True, - check_history_visibility_only=False, -): + server_name: str, + events: List[EventBase], + redact: bool = True, + check_history_visibility_only: bool = False, +) -> List[EventBase]: """Filter a list of events based on whether given server is allowed to see them. Args: storage - server_name (str) - events (iterable[FrozenEvent]) - redact (bool): Whether to return a redacted version of the event, or + server_name + events + redact: Whether to return a redacted version of the event, or to filter them out entirely. - check_history_visibility_only (bool): Whether to only check the + check_history_visibility_only: Whether to only check the history visibility, rather than things like if the sender has been erased. This is used e.g. during pagination to decide whether to backfill or not. Returns - list[FrozenEvent] + The filtered events. """ - def is_sender_erased(event, erased_senders): + def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool: if erased_senders and erased_senders[event.sender]: logger.info("Sender of %s has been erased, redacting", event.event_id) return True return False - def check_event_is_visible(event, state): + def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool: history = state.get((EventTypes.RoomHistoryVisibility, ""), None) if history: visibility = history.content.get( diff --git a/test_postgresql.sh b/test_postgresql.sh
index 1ffcaabd31..c10828fbbc 100755 --- a/test_postgresql.sh +++ b/test_postgresql.sh
@@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # This script builds the Docker image to run the PostgreSQL tests, and then runs # the tests. diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester from tests import unittest -class TestRatelimiter(unittest.TestCase): +class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) - self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) - self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) - self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) - - def test_allowed_user_via_can_requester_do_action(self): - user_requester = create_requester("@user:example.com") - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) def test_allowed_via_ratelimit(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=0) + self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0)) # Should raise with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key="test_id", _time_now_s=5) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=5) + ) self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=10) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=10) + ) def test_allowed_via_can_do_action_and_overriding_parameters(self): """Test that we can override options of can_do_action that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=0, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=0, + ) ) self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=1, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=1, + ) ) self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, rate_hz=10.0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0) ) self.assertTrue(allowed) self.assertEqual(1.1, time_allowed) # Similarly if we allow a burst of 10 actions - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, burst_count=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10) ) self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) @@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase): fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - limiter.ratelimit(key=("test_id",), _time_now_s=0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=0) + ) # Second attempt, 1s later, will fail with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1) + ) self.assertEqual(context.exception.retry_after_ms, 9000) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0) + ) # Similarly if we allow a burst of 10 actions - limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) + ) def test_pruning(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - limiter.can_do_action(key="test_id_1", _time_now_s=0) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_1", _time_now_s=0) + ) self.assertIn("test_id_1", limiter.actions) - limiter.can_do_action(key="test_id_2", _time_now_s=10) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_2", _time_now_s=10) + ) self.assertNotIn("test_id_1", limiter.actions) + + def test_db_user_override(self): + """Test that users that have ratelimiting disabled in the DB aren't + ratelimited. + """ + store = self.hs.get_datastore() + + user_id = "@user:test" + requester = create_requester(user_id) + + self.get_success( + store.db_pool.simple_insert( + table="ratelimit_override", + values={ + "user_id": user_id, + "messages_per_second": None, + "burst_count": None, + }, + desc="test_db_user_override", + ) + ) + + limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + for _ in range(20): + self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 734a9983e8..c109425671 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py
@@ -20,6 +20,7 @@ from io import StringIO import yaml +from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig from tests import unittest @@ -35,9 +36,9 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_load_fails_if_server_name_missing(self): self.generate_config_and_remove_lines_containing("server_name") - with self.assertRaises(Exception): + with self.assertRaises(ConfigError): HomeServerConfig.load_config("", ["-c", self.file]) - with self.assertRaises(Exception): + with self.assertRaises(ConfigError): HomeServerConfig.load_or_generate_config("", ["-c", self.file]) def test_generates_and_loads_macaroon_secret_key(self): diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 30fcc4c1bf..946482b7e7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -16,6 +16,7 @@ import time from mock import Mock +import attr import canonicaljson import signedjson.key import signedjson.sign @@ -68,6 +69,11 @@ class MockPerspectiveServer: signedjson.sign.sign_json(res, self.server_name, self.key) +@attr.s(slots=True) +class FakeRequest: + id = attr.ib() + + @logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): def check_context(self, val, expected): @@ -89,7 +95,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): first_lookup_deferred = Deferred() async def first_lookup_fetch(keys_to_fetch): - self.assertEquals(current_context().request, "context_11") + self.assertEquals(current_context().request.id, "context_11") self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) await make_deferred_yieldable(first_lookup_deferred) @@ -102,9 +108,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): mock_fetcher.get_keys.side_effect = first_lookup_fetch async def first_lookup(): - with LoggingContext("context_11") as context_11: - context_11.request = "context_11" - + with LoggingContext("context_11", request=FakeRequest("context_11")): res_deferreds = kr.verify_json_objects_for_server( [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] ) @@ -130,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # should block rather than start a second call async def second_lookup_fetch(keys_to_fetch): - self.assertEquals(current_context().request, "context_12") + self.assertEquals(current_context().request.id, "context_12") return { "server10": { get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) @@ -142,9 +146,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): second_lookup_state = [0] async def second_lookup(): - with LoggingContext("context_12") as context_12: - context_12.request = "context_12" - + with LoggingContext("context_12", request=FakeRequest("context_12")): res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1, 0, "test")] ) @@ -589,10 +591,7 @@ def get_key_id(key): @defer.inlineCallbacks def run_in_context(f, *args, **kwargs): - with LoggingContext("testctx") as ctx: - # we set the "request" prop to make it easier to follow what's going on in the - # logs. - ctx.request = "testctx" + with LoggingContext("testctx"): rv = yield f(*args, **kwargs) return rv diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py new file mode 100644
index 0000000000..c6e547f11c --- /dev/null +++ b/tests/events/test_presence_router.py
@@ -0,0 +1,386 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union + +from mock import Mock + +import attr + +from synapse.api.constants import EduTypes +from synapse.events.presence_router import PresenceRouter +from synapse.federation.units import Transaction +from synapse.handlers.presence import UserPresenceState +from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client.v1 import login, presence, room +from synapse.types import JsonDict, StreamToken, create_requester + +from tests.handlers.test_sync import generate_sync_config +from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config + + +@attr.s +class PresenceRouterTestConfig: + users_who_should_receive_all_presence = attr.ib(type=List[str], default=[]) + + +class PresenceRouterTestModule: + def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi): + self._config = config + self._module_api = module_api + + async def get_users_for_states( + self, state_updates: Iterable[UserPresenceState] + ) -> Dict[str, Set[UserPresenceState]]: + users_to_state = { + user_id: set(state_updates) + for user_id in self._config.users_who_should_receive_all_presence + } + return users_to_state + + async def get_interested_users( + self, user_id: str + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + if user_id in self._config.users_who_should_receive_all_presence: + return PresenceRouter.ALL_USERS + + return set() + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterTestConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterTestConfig() + + config.users_who_should_receive_all_presence = config_dict.get( + "users_who_should_receive_all_presence" + ) + + return config + + +class PresenceRouterTestCase(FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + presence.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def prepare(self, reactor, clock, homeserver): + self.sync_handler = self.hs.get_sync_handler() + self.module_api = homeserver.get_module_api() + + @override_config( + { + "presence": { + "presence_router": { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler:test", + ] + }, + } + }, + "send_federation": True, + } + ) + def test_receiving_all_presence(self): + """Test that a user that does not share a room with another other can receive + presence for them, due to presence routing. + """ + # Create a user who should receive all presence of others + self.presence_receiving_user_id = self.register_user( + "presence_gobbler", "monkey" + ) + self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey") + + # And two users who should not have any special routing + self.other_user_one_id = self.register_user("other_user_one", "monkey") + self.other_user_one_tok = self.login("other_user_one", "monkey") + self.other_user_two_id = self.register_user("other_user_two", "monkey") + self.other_user_two_tok = self.login("other_user_two", "monkey") + + # Put the other two users in a room with each other + room_id = self.helper.create_room_as( + self.other_user_one_id, tok=self.other_user_one_tok + ) + + self.helper.invite( + room_id, + self.other_user_one_id, + self.other_user_two_id, + tok=self.other_user_one_tok, + ) + self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok) + # User one sends some presence + send_presence_update( + self, + self.other_user_one_id, + self.other_user_one_tok, + "online", + "boop", + ) + + # Check that the presence receiving user gets user one's presence when syncing + presence_updates, sync_token = sync_presence( + self, self.presence_receiving_user_id + ) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.other_user_one_id) + self.assertEqual(presence_update.state, "online") + self.assertEqual(presence_update.status_msg, "boop") + + # Have all three users send presence + send_presence_update( + self, + self.other_user_one_id, + self.other_user_one_tok, + "online", + "user_one", + ) + send_presence_update( + self, + self.other_user_two_id, + self.other_user_two_tok, + "online", + "user_two", + ) + send_presence_update( + self, + self.presence_receiving_user_id, + self.presence_receiving_user_tok, + "online", + "presence_gobbler", + ) + + # Check that the presence receiving user gets everyone's presence + presence_updates, _ = sync_presence( + self, self.presence_receiving_user_id, sync_token + ) + self.assertEqual(len(presence_updates), 3) + + # But that User One only get itself and User Two's presence + presence_updates, _ = sync_presence(self, self.other_user_one_id) + self.assertEqual(len(presence_updates), 2) + + found = False + for update in presence_updates: + if update.user_id == self.other_user_two_id: + self.assertEqual(update.state, "online") + self.assertEqual(update.status_msg, "user_two") + found = True + + self.assertTrue(found) + + @override_config( + { + "presence": { + "presence_router": { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler1:test", + "@presence_gobbler2:test", + "@far_away_person:island", + ] + }, + } + }, + "send_federation": True, + } + ) + def test_send_local_online_presence_to_with_module(self): + """Tests that send_local_presence_to_users sends local online presence to a set + of specified local and remote users, with a custom PresenceRouter module enabled. + """ + # Create a user who will send presence updates + self.other_user_id = self.register_user("other_user", "monkey") + self.other_user_tok = self.login("other_user", "monkey") + + # And another two users that will also send out presence updates, as well as receive + # theirs and everyone else's + self.presence_receiving_user_one_id = self.register_user( + "presence_gobbler1", "monkey" + ) + self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey") + self.presence_receiving_user_two_id = self.register_user( + "presence_gobbler2", "monkey" + ) + self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey") + + # Have all three users send some presence updates + send_presence_update( + self, + self.other_user_id, + self.other_user_tok, + "online", + "I'm online!", + ) + send_presence_update( + self, + self.presence_receiving_user_one_id, + self.presence_receiving_user_one_tok, + "online", + "I'm also online!", + ) + send_presence_update( + self, + self.presence_receiving_user_two_id, + self.presence_receiving_user_two_tok, + "unavailable", + "I'm in a meeting!", + ) + + # Mark each presence-receiving user for receiving all user presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiving_user_one_id, + self.presence_receiving_user_two_id, + ] + ) + ) + + # Perform a sync for each user + + # The other user should only receive their own presence + presence_updates, _ = sync_presence(self, self.other_user_id) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.other_user_id) + self.assertEqual(presence_update.state, "online") + self.assertEqual(presence_update.status_msg, "I'm online!") + + # Whereas both presence receiving users should receive everyone's presence updates + presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id) + self.assertEqual(len(presence_updates), 3) + presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id) + self.assertEqual(len(presence_updates), 3) + + # Test that sending to a remote user works + remote_user_id = "@far_away_person:island" + + # Note that due to the remote user being in our module's + # users_who_should_receive_all_presence config, they would have + # received user presence updates already. + # + # Thus we reset the mock, and try sending all online local user + # presence again + self.hs.get_federation_transport_client().send_transaction.reset_mock() + + # Broadcast local user online presence + self.get_success( + self.module_api.send_local_online_presence_to([remote_user_id]) + ) + + # Check that the expected presence updates were sent + expected_users = [ + self.other_user_id, + self.presence_receiving_user_one_id, + self.presence_receiving_user_two_id, + ] + + calls = ( + self.hs.get_federation_transport_client().send_transaction.call_args_list + ) + for call in calls: + federation_transaction = call.args[0] # type: Transaction + + # Get the sent EDUs in this transaction + edus = federation_transaction.get_dict()["edus"] + + for edu in edus: + # Make sure we're only checking presence-type EDUs + if edu["edu_type"] != EduTypes.Presence: + continue + + # EDUs can contain multiple presence updates + for presence_update in edu["content"]["push"]: + # Check for presence updates that contain the user IDs we're after + expected_users.remove(presence_update["user_id"]) + + # Ensure that no offline states are being sent out + self.assertNotEqual(presence_update["presence"], "offline") + + self.assertEqual(len(expected_users), 0) + + +def send_presence_update( + testcase: TestCase, + user_id: str, + access_token: str, + presence_state: str, + status_message: Optional[str] = None, +) -> JsonDict: + # Build the presence body + body = {"presence": presence_state} + if status_message: + body["status_msg"] = status_message + + # Update the user's presence state + channel = testcase.make_request( + "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token + ) + testcase.assertEqual(channel.code, 200) + + return channel.json_body + + +def sync_presence( + testcase: TestCase, + user_id: str, + since_token: Optional[StreamToken] = None, +) -> Tuple[List[UserPresenceState], StreamToken]: + """Perform a sync request for the given user and return the user presence updates + they've received, as well as the next_batch token. + + This method assumes testcase.sync_handler points to the homeserver's sync handler. + + Args: + testcase: The testcase that is currently being run. + user_id: The ID of the user to generate a sync response for. + since_token: An optional token to indicate from at what point to sync from. + + Returns: + A tuple containing a list of presence updates, and the sync response's + next_batch token. + """ + requester = create_requester(user_id) + sync_config = generate_sync_config(requester.user.to_string()) + sync_result = testcase.get_success( + testcase.sync_handler.wait_for_sync_for_user( + requester, sync_config, since_token + ) + ) + + return sync_result.presence, sync_result.next_batch diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 6f96cd7940..95eac6a5a3 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py
@@ -2,6 +2,7 @@ from typing import List, Tuple from mock import Mock +from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu @@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.assertNotIn("zzzerver", woken) # - all destinations are woken exactly once; they appear once in woken. self.assertCountEqual(woken, server_names[:-1]) + + @override_config({"send_federation": True}) + def test_not_latest_event(self): + """Test that we send the latest event in the room even if its not ours.""" + + per_dest_queue, sent_pdus = self.make_fake_destination_queue() + + # Make a room with a local user, and two servers. One will go offline + # and one will send some events. + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + event_1 = self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join") + ) + + # First we send something from the local server, so that we notice the + # remote is down and go into catchup mode. + self.helper.send(room_1, "you hear me!!", tok=u1_token) + + # Now simulate us receiving an event from the still online remote. + event_2 = self.get_success( + event_injection.inject_event( + self.hs, + type=EventTypes.Message, + sender="@user:host3", + room_id=room_1, + content={"msgtype": "m.text", "body": "Hello"}, + ) + ) + + self.get_success( + self.hs.get_datastore().set_destination_last_successful_stream_ordering( + "host2", event_1.internal_metadata.stream_ordering + ) + ) + + self.get_success(per_dest_queue._catch_up_transmission_loop()) + + # We expect only the last message from the remote, event_2, to have been + # sent, rather than the last *local* event that was sent. + self.assertEqual(len(sent_pdus), 1) + self.assertEqual(sent_pdus[0].event_id, event_2.event_id) + self.assertFalse(per_dest_queue._catching_up) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5e9c9c2e88..c7796fb837 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -989,6 +989,138 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.assertRenderedError("mapping_error", "localpart is invalid: ") + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements(self): + """The required attributes must be met from the OIDC userinfo response.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # userinfo lacking "test": "foobar" attribute should fail. + userinfo = { + "sub": "tester", + "username": "tester", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": "foobar" attribute should succeed. + userinfo = { + "sub": "tester", + "username": "tester", + "test": "foobar", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@tester:test", "oidc", ANY, ANY, None, new_user=True + ) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements_contains(self): + """Test that auth succeeds if userinfo attribute CONTAINS required value""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. + userinfo = { + "sub": "tester", + "username": "tester", + "test": ["foobar", "foo", "bar"], + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@tester:test", "oidc", ANY, ANY, None, new_user=True + ) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements_mismatch(self): + """ + Test that auth fails if attributes exist but don't match, + or are non-string values. + """ + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # userinfo with "test": "not_foobar" attribute should fail + userinfo = { + "sub": "tester", + "username": "tester", + "test": "not_foobar", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": ["foo", "bar"] attribute should fail + userinfo = { + "sub": "tester", + "username": "tester", + "test": ["foo", "bar"], + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": False attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": False, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": None attribute should fail + # a value of None breaks the OIDC spec, but it's important to not crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": 1 attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": 1, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": 3.14 attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": 3.14, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + def _generate_oidc_session_token( self, state: str, diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 996c614198..77330f59a9 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py
@@ -310,6 +310,26 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) + def test_busy_no_idle(self): + """ + Tests that a user setting their presence to busy but idling doesn't turn their + presence state into unavailable. + """ + user_id = "@foo:bar" + now = 5000000 + + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.BUSY, + last_active_ts=now - IDLE_TIMER - 1, + last_user_sync_ts=now, + ) + + new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + + self.assertIsNotNone(new_state) + self.assertEquals(new_state.state, PresenceState.BUSY) + def test_sync_timeout(self): user_id = "@foo:bar" now = 5000000 diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e62586142e..8e950f25c5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py
@@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" user_id2 = "@user2:test" - sync_config = self._generate_sync_config(user_id1) + sync_config = generate_sync_config(user_id1) requester = create_requester(user_id1) self.reactor.advance(100) # So we get not 0 time @@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = False - sync_config = self._generate_sync_config(user_id2) + sync_config = generate_sync_config(user_id2) requester = create_requester(user_id2) e = self.get_failure( @@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def _generate_sync_config(self, user_id): - return SyncConfig( - user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), - filter_collection=DEFAULT_FILTER_COLLECTION, - is_guest=False, - request_key="request_key", - device_id="device_id", - ) + +def generate_sync_config(user_id: str) -> SyncConfig: + return SyncConfig( + user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), + filter_collection=DEFAULT_FILTER_COLLECTION, + is_guest=False, + request_key="request_key", + device_id="device_id", + ) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 505ffcd300..3ea8b5bec7 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py
@@ -12,8 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging import os +from typing import Optional from unittest.mock import patch import treq @@ -242,6 +244,21 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) def test_https_request_via_proxy(self): + """Tests that TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, + ) + def test_https_request_via_proxy_with_auth(self): + """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") + + def _do_https_request_via_proxy( + self, + auth_credentials: Optional[str] = None, + ): agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -278,6 +295,22 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b"CONNECT") self.assertEqual(request.path, b"test.com:443") + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(b"bob:pinkponies") + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + # tell the proxy server not to close the connection proxy_server.persistent = True @@ -312,6 +345,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/abc") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + + # Check that the destination server DID NOT receive proxy credentials + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + self.assertIsNone(proxy_auth_header_values) + request.write(b"result") request.finish() diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 48a74e2eee..bfe0d11c93 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py
@@ -12,15 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import logging -from io import StringIO +from io import BytesIO, StringIO + +from mock import Mock, patch + +from twisted.web.server import Request +from synapse.http.site import SynapseRequest from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter from synapse.logging.context import LoggingContext, LoggingContextFilter from tests.logging import LoggerCleanupMixin +from tests.server import FakeChannel from tests.unittest import TestCase @@ -120,7 +125,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): handler.addFilter(LoggingContextFilter()) logger = self.get_logger(handler) - with LoggingContext(request="test"): + with LoggingContext("name"): logger.info("Hello there, %s!", "wally") log = self.get_log_line() @@ -134,4 +139,61 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): ] self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - self.assertEqual(log["request"], "test") + self.assertTrue(log["request"].startswith("name@")) + + def test_with_request_context(self): + """ + Information from the logging context request should be added to the JSON response. + """ + handler = logging.StreamHandler(self.output) + handler.setFormatter(JsonFormatter()) + handler.addFilter(LoggingContextFilter()) + logger = self.get_logger(handler) + + # A full request isn't needed here. + site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"]) + site.site_tag = "test-site" + site.server_version_string = "Server v1" + request = SynapseRequest(FakeChannel(site, None)) + # Call requestReceived to finish instantiating the object. + request.content = BytesIO() + # Partially skip some of the internal processing of SynapseRequest. + request._started_processing = Mock() + request.request_metrics = Mock(spec=["name"]) + with patch.object(Request, "render"): + request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") + + # Also set the requester to ensure the processing works. + request.requester = "@foo:test" + + with LoggingContext(parent_context=request.logcontext): + logger.info("Hello there, %s!", "wally") + + log = self.get_log_line() + + # The terse logger includes additional request information, if possible. + expected_log_keys = [ + "log", + "level", + "namespace", + "request", + "ip_address", + "site_tag", + "requester", + "authenticated_entity", + "method", + "url", + "protocol", + "user_agent", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") + self.assertTrue(log["request"].startswith("POST-")) + self.assertEqual(log["ip_address"], "127.0.0.1") + self.assertEqual(log["site_tag"], "test-site") + self.assertEqual(log["requester"], "@foo:test") + self.assertEqual(log["authenticated_entity"], "@foo:test") + self.assertEqual(log["method"], "POST") + self.assertEqual(log["url"], "/_matrix/client/versions") + self.assertEqual(log["protocol"], "1.1") + self.assertEqual(log["user_agent"], "") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index edacd1b566..1d1fceeecf 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -14,25 +14,37 @@ # limitations under the License. from mock import Mock +from synapse.api.constants import EduTypes from synapse.events import EventBase +from synapse.federation.units import Transaction +from synapse.handlers.presence import UserPresenceState from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client.v1 import login, presence, room from synapse.types import create_requester -from tests.unittest import HomeserverTestCase +from tests.events.test_presence_router import send_presence_update, sync_presence +from tests.test_utils.event_injection import inject_member_event +from tests.unittest import FederatingHomeserverTestCase, override_config -class ModuleApiTestCase(HomeserverTestCase): +class ModuleApiTestCase(FederatingHomeserverTestCase): servlets = [ admin.register_servlets, login.register_servlets, room.register_servlets, + presence.register_servlets, ] def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() + self.sync_handler = homeserver.get_sync_handler() + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) def test_can_register_user(self): """Tests that an external module can register a user""" @@ -205,3 +217,160 @@ class ModuleApiTestCase(HomeserverTestCase): ) ) self.assertFalse(is_in_public_rooms) + + # The ability to send federation is required by send_local_online_presence_to. + @override_config({"send_federation": True}) + def test_send_local_online_presence_to(self): + """Tests that send_local_presence_to_users sends local online presence to local users.""" + # Create a user who will send presence updates + self.presence_receiver_id = self.register_user("presence_receiver", "monkey") + self.presence_receiver_tok = self.login("presence_receiver", "monkey") + + # And another user that will send presence updates out + self.presence_sender_id = self.register_user("presence_sender", "monkey") + self.presence_sender_tok = self.login("presence_sender", "monkey") + + # Put them in a room together so they will receive each other's presence updates + room_id = self.helper.create_room_as( + self.presence_receiver_id, + tok=self.presence_receiver_tok, + ) + self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok) + + # Presence sender comes online + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "online", + "I'm online!", + ) + + # Presence receiver should have received it + presence_updates, sync_token = sync_presence(self, self.presence_receiver_id) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.presence_sender_id) + self.assertEqual(presence_update.state, "online") + + # Syncing again should result in no presence updates + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 0) + + # Trigger sending local online presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiver_id, + ] + ) + ) + + # Presence receiver should have received online presence again + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.presence_sender_id) + self.assertEqual(presence_update.state, "online") + + # Presence sender goes offline + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "offline", + "I slink back into the darkness.", + ) + + # Trigger sending local online presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiver_id, + ] + ) + ) + + # Presence receiver should *not* have received offline state + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 0) + + @override_config({"send_federation": True}) + def test_send_local_online_presence_to_federation(self): + """Tests that send_local_presence_to_users sends local online presence to remote users.""" + # Create a user who will send presence updates + self.presence_sender_id = self.register_user("presence_sender", "monkey") + self.presence_sender_tok = self.login("presence_sender", "monkey") + + # And a room they're a part of + room_id = self.helper.create_room_as( + self.presence_sender_id, + tok=self.presence_sender_tok, + ) + + # Mark them as online + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "online", + "I'm online!", + ) + + # Make up a remote user to send presence to + remote_user_id = "@far_away_person:island" + + # Create a join membership event for the remote user into the room. + # This allows presence information to flow from one user to the other. + self.get_success( + inject_member_event( + self.hs, + room_id, + sender=remote_user_id, + target=remote_user_id, + membership="join", + ) + ) + + # The remote user would have received the existing room members' presence + # when they joined the room. + # + # Thus we reset the mock, and try sending online local user + # presence again + self.hs.get_federation_transport_client().send_transaction.reset_mock() + + # Broadcast local user online presence + self.get_success( + self.module_api.send_local_online_presence_to([remote_user_id]) + ) + + # Check that a presence update was sent as part of a federation transaction + found_update = False + calls = ( + self.hs.get_federation_transport_client().send_transaction.call_args_list + ) + for call in calls: + federation_transaction = call.args[0] # type: Transaction + + # Get the sent EDUs in this transaction + edus = federation_transaction.get_dict()["edus"] + + for edu in edus: + # Make sure we're only checking presence-type EDUs + if edu["edu_type"] != EduTypes.Presence: + continue + + # EDUs can contain multiple presence updates + for presence_update in edu["content"]["push"]: + if presence_update["user_id"] == self.presence_sender_id: + found_update = True + + self.assertTrue(found_update) diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 67b7913666..1d4a592862 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -44,7 +44,7 @@ from tests.server import FakeTransport try: import hiredis except ImportError: - hiredis = None + hiredis = None # type: ignore logger = logging.getLogger(__name__) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 5acfb3e53e..ca49d4dd3a 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py
@@ -69,6 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") # The from token should be the token from the last RDATA we got. + assert request.args is not None self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 7ff11cde10..b0800f9840 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py
@@ -15,7 +15,7 @@ import logging import os from binascii import unhexlify -from typing import Tuple +from typing import Optional, Tuple from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory @@ -32,7 +32,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request logger = logging.getLogger(__name__) -test_server_connection_factory = None +test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory] class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e58d5cf0db..0c9ec133c2 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -28,7 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from tests import unittest from tests.server import FakeSite, make_request @@ -467,6 +467,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users" def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -634,6 +636,26 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # unkown order_by + channel = self.make_request( + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + def test_limit(self): """ Testing list of users with limit @@ -759,6 +781,103 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z") + user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y") + + # Modify user + self.get_success(self.store.set_user_deactivated_status(user1, True)) + self.get_success(self.store.set_shadow_banned(UserID.from_string(user1), True)) + + # Set avatar URL to all users, that no user has a NULL value to avoid + # different sort order between SQlite and PostreSQL + self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3")) + self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2")) + self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1")) + + # order by default (name) + self._order_test([self.admin_user, user1, user2], None) + self._order_test([self.admin_user, user1, user2], None, "f") + self._order_test([user2, user1, self.admin_user], None, "b") + + # order by name + self._order_test([self.admin_user, user1, user2], "name") + self._order_test([self.admin_user, user1, user2], "name", "f") + self._order_test([user2, user1, self.admin_user], "name", "b") + + # order by displayname + self._order_test([user2, user1, self.admin_user], "displayname") + self._order_test([user2, user1, self.admin_user], "displayname", "f") + self._order_test([self.admin_user, user1, user2], "displayname", "b") + + # order by is_guest + # like sort by ascending name, as no guest user here + self._order_test([self.admin_user, user1, user2], "is_guest") + self._order_test([self.admin_user, user1, user2], "is_guest", "f") + self._order_test([self.admin_user, user1, user2], "is_guest", "b") + + # order by admin + self._order_test([user1, user2, self.admin_user], "admin") + self._order_test([user1, user2, self.admin_user], "admin", "f") + self._order_test([self.admin_user, user1, user2], "admin", "b") + + # order by deactivated + self._order_test([self.admin_user, user2, user1], "deactivated") + self._order_test([self.admin_user, user2, user1], "deactivated", "f") + self._order_test([user1, self.admin_user, user2], "deactivated", "b") + + # order by user_type + # like sort by ascending name, as no special user type here + self._order_test([self.admin_user, user1, user2], "user_type") + self._order_test([self.admin_user, user1, user2], "user_type", "f") + self._order_test([self.admin_user, user1, user2], "is_guest", "b") + + # order by shadow_banned + self._order_test([self.admin_user, user2, user1], "shadow_banned") + self._order_test([self.admin_user, user2, user1], "shadow_banned", "f") + self._order_test([user1, self.admin_user, user2], "shadow_banned", "b") + + # order by avatar_url + self._order_test([self.admin_user, user2, user1], "avatar_url") + self._order_test([self.admin_user, user2, user1], "avatar_url", "f") + self._order_test([user1, user2, self.admin_user], "avatar_url", "b") + + def _order_test( + self, + expected_user_list: List[str], + order_by: Optional[str], + dir: Optional[str] = None, + ): + """Request the list of users in a certain order. Assert that order is what + we expect + Args: + expected_user_list: The list of user_id in the order we expect to get + back from the server + order_by: The type of ordering to give the server + dir: The direction of ordering to give the server + """ + + url = self.url + "?deactivated=true&" + if order_by is not None: + url += "order_by=%s&" % (order_by,) + if dir is not None and dir in ("b", "f"): + url += "dir=%s" % (dir,) + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], len(expected_user_list)) + + returned_order = [row["name"] for row in channel.json_body["users"]] + self.assertEqual(expected_user_list, returned_order) + self._check_fields(channel.json_body["users"]) + def _check_fields(self, content: JsonDict): """Checks that the expected user attributes are present in content Args: @@ -1003,12 +1122,23 @@ class UserRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() + # create users and get access tokens + # regardless of whether password login or SSO is allowed self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") + self.admin_user_tok = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.admin_user, device_id=None, valid_until_ms=None + ) + ) self.other_user = self.register_user("user", "pass", displayname="User") - self.other_user_token = self.login("user", "pass") + self.other_user_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.other_user, device_id=None, valid_until_ms=None + ) + ) self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.other_user ) @@ -1081,7 +1211,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user @@ -1096,9 +1226,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(True, channel.json_body["admin"]) - self.assertEqual(False, channel.json_body["is_guest"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["admin"]) + self.assertFalse(channel.json_body["is_guest"]) + self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) def test_create_user(self): @@ -1130,7 +1260,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user @@ -1145,10 +1275,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(False, channel.json_body["admin"]) - self.assertEqual(False, channel.json_body["is_guest"]) - self.assertEqual(False, channel.json_body["deactivated"]) - self.assertEqual(False, channel.json_body["shadow_banned"]) + self.assertFalse(channel.json_body["admin"]) + self.assertFalse(channel.json_body["is_guest"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["shadow_banned"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) @override_config( @@ -1197,7 +1327,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} @@ -1237,7 +1367,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Admin user is not blocked by mau anymore self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) @override_config( { @@ -1429,24 +1559,23 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) # Deactivate user - body = json.dumps({"deactivated": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"deactivated": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -1461,7 +1590,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -1478,41 +1608,37 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertTrue(profile["display_name"] == "User") # Deactivate user - body = json.dumps({"deactivated": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"deactivated": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) # is not in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) - self.assertTrue(profile is None) + self.assertIsNone(profile) # Set new displayname user - body = json.dumps({"displayname": "Foobar"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"displayname": "Foobar"}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) # is not in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) - self.assertTrue(profile is None) + self.assertIsNone(profile) def test_reactivate_user(self): """ @@ -1520,48 +1646,92 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Deactivate the user. + self._deactivate_user("@user:test") + + # Attempt to reactivate the user (without a password). channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), + content={"deactivated": False}, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Reactivate the user. + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": False, "password": "foo"}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNotNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) - d = self.store.mark_user_erased("@user:test") - self.assertIsNone(self.get_success(d)) - self._is_erased("@user:test", True) - # Attempt to reactivate the user (without a password). + @override_config({"password_config": {"localdb_enabled": False}}) + def test_reactivate_user_localdb_disabled(self): + """ + Test reactivating another user when using SSO. + """ + + # Deactivate the user. + self._deactivate_user("@user:test") + + # Reactivate the user with a password channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), + content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - # Reactivate the user. + # Reactivate the user without a password. channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": False, "password": "foo"}).encode( - encoding="utf_8" - ), + content={"deactivated": False}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) + self._is_erased("@user:test", False) - # Get user + @override_config({"password_config": {"enabled": False}}) + def test_reactivate_user_password_disabled(self): + """ + Test reactivating another user when using SSO. + """ + + # Deactivate the user. + self._deactivate_user("@user:test") + + # Reactivate the user with a password channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"deactivated": False, "password": "foo"}, ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + # Reactivate the user without a password. + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": False}, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) def test_set_user_as_admin(self): @@ -1570,18 +1740,16 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Set a user as an admin - body = json.dumps({"admin": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"admin": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) # Get user channel = self.make_request( @@ -1592,7 +1760,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) def test_accidental_deactivation_prevention(self): """ @@ -1602,13 +1770,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123"}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123"}, ) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) @@ -1628,13 +1794,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["deactivated"]) # Change password (and use a str for deactivate instead of a bool) - body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops! - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "deactivated": "false"}, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1653,7 +1817,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) - def _is_erased(self, user_id, expect): + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) if expect: @@ -1661,6 +1825,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): else: self.assertFalse(self.get_success(d)) + def _deactivate_user(self, user_id: str) -> None: + """Deactivate user and set as erased""" + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id), + access_token=self.admin_user_tok, + content={"deactivated": True}, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) + self._is_erased(user_id, False) + d = self.store.mark_user_erased(user_id) + self.assertIsNone(self.get_success(d)) + self._is_erased(user_id, True) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 227fffab58..bf39014277 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -161,6 +161,68 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") + def test_message_edit(self): + """Ensure that the module doesn't cause issues with edited messages.""" + # first patch the event checker so that it will modify the event + async def check(ev: EventBase, state): + d = ev.get_dict() + d["content"] = { + "msgtype": "m.text", + "body": d["content"]["body"].upper(), + } + return d + + current_rules_module().check_event_allowed = check + + # Send an event, then edit it. + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, + { + "msgtype": "m.text", + "body": "Original body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + orig_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/m.room.message/2" % self.room_id, + { + "m.new_content": {"msgtype": "m.text", "body": "Edited body"}, + "m.relates_to": { + "rel_type": "m.replace", + "event_id": orig_event_id, + }, + "msgtype": "m.text", + "body": "Edited body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + edited_event_id = channel.json_body["event_id"] + + # ... and check that they both got modified + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["body"], "ORIGINAL BODY") + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["body"], "EDITED BODY") + def test_send_event(self): """Tests that the module can send an event into a room via the module api""" content = { diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 9734a2159a..ed433d9333 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Optional, Union from twisted.internet.defer import succeed @@ -74,7 +74,10 @@ class FallbackAuthTests(unittest.HomeserverTestCase): return channel def recaptcha( - self, session: str, expected_post_response: int, post_session: str = None + self, + session: str, + expected_post_response: int, + post_session: Optional[str] = None, ) -> None: """Get and respond to a fallback recaptcha. Returns the second request.""" if post_session is None: diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index e808339fb3..287a1a485c 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import capabilities from tests import unittest +from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): @@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver() self.store = hs.get_datastore() self.config = hs.config + self.auth_handler = hs.get_auth_handler() return hs def test_check_auth_required(self): @@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities["m.room_versions"]["default"], ) - def test_get_change_password_capabilities(self): + def test_get_change_password_capabilities_password_login(self): localpart = "user" password = "pass" user = self.register_user(localpart, password) @@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - - # Test case where password is handled outside of Synapse self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.get_success(self.store.user_set_password_hash(user, None)) + + @override_config({"password_config": {"localdb_enabled": False}}) + def test_get_change_password_capabilities_localdb_disabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["m.change_password"]["enabled"]) + + @override_config({"password_config": {"enabled": False}}) + def test_get_change_password_capabilities_password_disabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 7c457754f1..e7bb5583fc 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -39,6 +39,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We need to enable msc1849 support for aggregations config = self.default_config() config["experimental_msc1849_support_enabled"] = True + + # We enable frozen dicts as relations/edits change event contents, so we + # want to test that we don't modify the events in the caches. + config["use_frozen_dicts"] = True + return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): @@ -518,6 +523,63 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + def test_edit_reply(self): + """Test that editing a reply works.""" + + # Create a reply to edit. + channel = self._send_relation( + RelationTypes.REFERENCE, + "m.room.message", + content={"msgtype": "m.text", "body": "A reply!"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + reply = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=reply, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, reply), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to see the new body in the dict, as well as the reference + # metadata sill intact. + self.assertDictContainsSubset(new_body, channel.json_body["content"]) + self.assertDictContainsSubset( + { + "m.relates_to": { + "event_id": self.parent_id, + "key": None, + "rel_type": "m.reference", + } + }, + channel.json_body["content"], + ) + + # We expect that the edit relation appears in the unsigned relations + # section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + def test_relations_redaction_redacts_edits(self): """Test that edits of an event are redacted when the original event is redacted. diff --git a/tests/server.py b/tests/server.py
index 2287d20076..b535a5d886 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -13,8 +13,11 @@ from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( + IHostnameResolver, + IProtocol, + IPullProducer, + IPushProducer, IReactorPluggableNameResolver, - IReactorTCP, IResolverSimple, ITransport, ) @@ -45,11 +48,11 @@ class FakeChannel: wire). """ - site = attr.ib(type=Site) + site = attr.ib(type=Union[Site, "FakeSite"]) _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) _ip = attr.ib(type=str, default="127.0.0.1") - _producer = None + _producer = None # type: Optional[Union[IPullProducer, IPushProducer]] @property def json_body(self): @@ -159,7 +162,11 @@ class FakeChannel: Any cookines found are added to the given dict """ - for h in self.headers.getRawHeaders("Set-Cookie"): + headers = self.headers.getRawHeaders("Set-Cookie") + if not headers: + return + + for h in headers: parts = h.split(";") k, v = parts[0].split("=", maxsplit=1) cookies[k] = v @@ -311,8 +318,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self._tcp_callbacks = {} self._udp = [] - lookups = self.lookups = {} - self._thread_callbacks = deque() # type: Deque[Callable[[], None]]() + lookups = self.lookups = {} # type: Dict[str, str] + self._thread_callbacks = deque() # type: Deque[Callable[[], None]] @implementer(IResolverSimple) class FakeResolver: @@ -324,6 +331,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self.nameResolver = SimpleResolverComplexifier(FakeResolver()) super().__init__() + def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: + raise NotImplementedError() + def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() @@ -593,7 +603,7 @@ class FakeTransport: if self.disconnected: return - if getattr(self.other, "transport") is None: + if not hasattr(self.other, "transport"): # the other has no transport yet; reschedule if self.autoflush: self._reactor.callLater(0.0, self.flush) @@ -621,7 +631,9 @@ class FakeTransport: self.disconnected = True -def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: +def connect_client( + reactor: ThreadedMemoryReactorClock, client_id: int +) -> Tuple[IProtocol, AccumulatingProtocol]: """ Connect a client to a fake TCP transport. diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,32 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import synapse.api.errors -import tests.unittest -import tests.utils - - -class DeviceStoreTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore +from tests.unittest import HomeserverTestCase - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class DeviceStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_store_new_device(self): - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device_id", "display_name") ) - res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) + res = self.get_success(self.store.get_device("user_id", "device_id")) self.assertDictContainsSubset( { "user_id": "user_id", @@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase): res, ) - @defer.inlineCallbacks def test_get_devices_by_user(self): - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device2", "display_name 2") ) - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id2", "device3", "display_name 3") ) - res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) + res = self.get_success(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) self.assertDictContainsSubset( { @@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase): res["device2"], ) - @defer.inlineCallbacks def test_count_devices_by_users(self): - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device2", "display_name 2") ) - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id2", "device3", "display_name 3") ) - res = yield defer.ensureDeferred(self.store.count_devices_by_users()) + res = self.get_success(self.store.count_devices_by_users()) self.assertEqual(0, res) - res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"])) + res = self.get_success(self.store.count_devices_by_users(["unknown"])) self.assertEqual(0, res) - res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"])) + res = self.get_success(self.store.count_devices_by_users(["user_id"])) self.assertEqual(2, res) - res = yield defer.ensureDeferred( + res = self.get_success( self.store.count_devices_by_users(["user_id", "user_id2"]) ) self.assertEqual(3, res) - @defer.inlineCallbacks def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id - yield defer.ensureDeferred( + self.get_success( self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield defer.ensureDeferred( + now_stream_id, device_updates = self.get_success( self.store.get_device_updates_by_remote("somehost", -1, limit=100) ) @@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase): } self.assertEqual(received_device_ids, set(expected_device_ids)) - @defer.inlineCallbacks def test_update_device(self): - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device_id", "display_name 1") ) - res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) + res = self.get_success(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do a no-op first - yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) - res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) + self.get_success(self.store.update_device("user_id", "device_id")) + res = self.get_success(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do the update - yield defer.ensureDeferred( + self.get_success( self.store.update_device( "user_id", "device_id", new_display_name="display_name 2" ) ) # check it worked - res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) + res = self.get_success(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 2", res["display_name"]) - @defer.inlineCallbacks def test_update_unknown_device(self): - with self.assertRaises(synapse.api.errors.StoreError) as cm: - yield defer.ensureDeferred( - self.store.update_device( - "user_id", "unknown_device_id", new_display_name="display_name 2" - ) - ) - self.assertEqual(404, cm.exception.code) + exc = self.get_failure( + self.store.update_device( + "user_id", "unknown_device_id", new_display_name="display_name 2" + ), + synapse.api.errors.StoreError, + ) + self.assertEqual(404, exc.value.code) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,28 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.types import RoomAlias, RoomID -from tests import unittest -from tests.utils import setup_test_homeserver +from tests.unittest import HomeserverTestCase -class DirectoryStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - +class DirectoryStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") - @defer.inlineCallbacks def test_room_to_alias(self): - yield defer.ensureDeferred( + self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) @@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase): self.assertEquals( ["#my-room:test"], - ( - yield defer.ensureDeferred( - self.store.get_aliases_for_room(self.room.to_string()) - ) - ), + (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), ) - @defer.inlineCallbacks def test_alias_to_room(self): - yield defer.ensureDeferred( + self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) @@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase): self.assertObjectHasAttributes( {"room_id": self.room.to_string(), "servers": ["test"]}, - ( - yield defer.ensureDeferred( - self.store.get_association_from_room_alias(self.alias) - ) - ), + (self.get_success(self.store.get_association_from_room_alias(self.alias))), ) - @defer.inlineCallbacks def test_delete_alias(self): - yield defer.ensureDeferred( + self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) ) - room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) + room_id = self.get_success(self.store.delete_room_alias(self.alias)) self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - ( - yield defer.ensureDeferred( - self.store.get_association_from_room_alias(self.alias) - ) - ) + (self.get_success(self.store.get_association_from_room_alias(self.alias))) ) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,30 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from tests.unittest import HomeserverTestCase -import tests.unittest -import tests.utils - -class EndToEndKeyStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EndToEndKeyStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_key_without_device_name(self): now = 1470174257070 json = {"key": "value"} - yield defer.ensureDeferred(self.store.store_device("user", "device", None)) + self.get_success(self.store.store_device("user", "device", None)) - yield defer.ensureDeferred( - self.store.set_e2e_device_keys("user", "device", now, json) - ) + self.get_success(self.store.set_e2e_device_keys("user", "device", now, json)) - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) @@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): dev = res["user"]["device"] self.assertDictContainsSubset(json, dev) - @defer.inlineCallbacks def test_reupload_key(self): now = 1470174257070 json = {"key": "value"} - yield defer.ensureDeferred(self.store.store_device("user", "device", None)) + self.get_success(self.store.store_device("user", "device", None)) - changed = yield defer.ensureDeferred( + changed = self.get_success( self.store.set_e2e_device_keys("user", "device", now, json) ) self.assertTrue(changed) # If we try to upload the same key then we should be told nothing # changed - changed = yield defer.ensureDeferred( + changed = self.get_success( self.store.set_e2e_device_keys("user", "device", now, json) ) self.assertFalse(changed) - @defer.inlineCallbacks def test_get_key_with_device_name(self): now = 1470174257070 json = {"key": "value"} - yield defer.ensureDeferred( - self.store.set_e2e_device_keys("user", "device", now, json) - ) - yield defer.ensureDeferred( - self.store.store_device("user", "device", "display_name") - ) + self.get_success(self.store.set_e2e_device_keys("user", "device", now, json)) + self.get_success(self.store.store_device("user", "device", "display_name")) - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) @@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) - @defer.inlineCallbacks def test_multiple_devices(self): now = 1470174257070 - yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) - yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) - yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) - yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) + self.get_success(self.store.store_device("user1", "device1", None)) + self.get_success(self.store.store_device("user1", "device2", None)) + self.get_success(self.store.store_device("user2", "device1", None)) + self.get_success(self.store.store_device("user2", "device2", None)) - yield defer.ensureDeferred( + self.get_success( self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) ) - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_device_keys_for_cs_api( (("user1", "device1"), ("user2", "device2")) ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..239f7c9faf 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,7 @@ from mock import Mock -from twisted.internet import defer - -import tests.unittest -import tests.utils +from tests.unittest import HomeserverTestCase USER_ID = "@user:example.com" @@ -30,37 +27,31 @@ HIGHLIGHT = [ ] -class EventPushActionsStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventPushActionsStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.persist_events_store = hs.get_datastores().persist_events - @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield defer.ensureDeferred( + self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_http( USER_ID, 0, 1000, 20 ) ) - @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield defer.ensureDeferred( + self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_email( USER_ID, 0, 1000, 20 ) ) - @defer.inlineCallbacks def test_count_aggregation(self): room_id = "!foo:example.com" user_id = "@user1235:example.com" - @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield defer.ensureDeferred( + counts = self.get_success( self.store.db_pool.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) @@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): }, ) - @defer.inlineCallbacks def _inject_actions(stream, action): event = Mock() event.room_id = room_id @@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - yield defer.ensureDeferred( + self.get_success( self.store.add_push_actions_to_staging( event.event_id, {user_id: action}, False, ) ) - yield defer.ensureDeferred( + self.get_success( self.store.db_pool.runInteraction( "", self.persist_events_store._set_push_actions_for_event_and_users_txn, @@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) def _rotate(stream): - return defer.ensureDeferred( + self.get_success( self.store.db_pool.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) ) def _mark_read(stream, depth): - return defer.ensureDeferred( + self.get_success( self.store.db_pool.runInteraction( "", self.store._remove_old_push_actions_before_txn, @@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) ) - yield _assert_counts(0, 0) - yield _inject_actions(1, PlAIN_NOTIF) - yield _assert_counts(1, 0) - yield _rotate(2) - yield _assert_counts(1, 0) + _assert_counts(0, 0) + _inject_actions(1, PlAIN_NOTIF) + _assert_counts(1, 0) + _rotate(2) + _assert_counts(1, 0) - yield _inject_actions(3, PlAIN_NOTIF) - yield _assert_counts(2, 0) - yield _rotate(4) - yield _assert_counts(2, 0) + _inject_actions(3, PlAIN_NOTIF) + _assert_counts(2, 0) + _rotate(4) + _assert_counts(2, 0) - yield _inject_actions(5, PlAIN_NOTIF) - yield _mark_read(3, 3) - yield _assert_counts(1, 0) + _inject_actions(5, PlAIN_NOTIF) + _mark_read(3, 3) + _assert_counts(1, 0) - yield _mark_read(5, 5) - yield _assert_counts(0, 0) + _mark_read(5, 5) + _assert_counts(0, 0) - yield _inject_actions(6, PlAIN_NOTIF) - yield _rotate(7) + _inject_actions(6, PlAIN_NOTIF) + _rotate(7) - yield defer.ensureDeferred( + self.get_success( self.store.db_pool.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) ) - yield _assert_counts(1, 0) + _assert_counts(1, 0) - yield _mark_read(7, 7) - yield _assert_counts(0, 0) + _mark_read(7, 7) + _assert_counts(0, 0) - yield _inject_actions(8, HIGHLIGHT) - yield _assert_counts(1, 1) - yield _rotate(9) - yield _assert_counts(1, 1) - yield _rotate(10) - yield _assert_counts(1, 1) + _inject_actions(8, HIGHLIGHT) + _assert_counts(1, 1) + _rotate(9) + _assert_counts(1, 1) + _rotate(10) + _assert_counts(1, 1) - @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return defer.ensureDeferred( + self.get_success( self.store.db_pool.simple_insert( "events", { @@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) # start with the base case where there are no events in the table - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(11) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(11)) self.assertEqual(r, 0) # now with one event - yield add_event(2, 10) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(9) - ) + add_event(2, 10) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(9)) self.assertEqual(r, 2) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(10) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(10)) self.assertEqual(r, 2) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(11) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(11)) self.assertEqual(r, 3) # add a bunch of dummy events to the events table @@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): (10, 130), (20, 140), ): - yield add_event(stream_ordering, ts) + add_event(stream_ordering, ts) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(110) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(110)) self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) # 4 and 5 are both after 120: we want 4 rather than 5 - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(120) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(120)) self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(129) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(129)) self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) # check we can get the last event - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(140) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(140)) self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) # off the end - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(160) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(160)) self.assertEqual(r, 21) # check we can find an event at ordering zero - yield add_event(0, 5) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(1) - ) + add_event(0, 5) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(1)) self.assertEqual(r, 0) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index ea63bd56b4..d18ceb41a9 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,59 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.types import UserID from tests import unittest -from tests.utils import setup_test_homeserver - -class ProfileStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) +class ProfileStoreTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") - @defer.inlineCallbacks def test_displayname(self): - yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) + self.get_success(self.store.create_profile(self.u_frank.localpart)) - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_displayname(self.u_frank.localpart, "Frank") ) self.assertEquals( "Frank", ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.u_frank.localpart) ) ), ) # test set to None - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_displayname(self.u_frank.localpart, None) ) self.assertIsNone( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.u_frank.localpart) ) ) ) - @defer.inlineCallbacks def test_avatar_url(self): - yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) + self.get_success(self.store.create_profile(self.u_frank.localpart)) - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_avatar_url( self.u_frank.localpart, "http://my.site/here" ) @@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase): self.assertEquals( "http://my.site/here", ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_avatar_url(self.u_frank.localpart) ) ), ) # test set to None - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_avatar_url(self.u_frank.localpart, None) ) self.assertIsNone( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_avatar_url(self.u_frank.localpart) ) ) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2622207639 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,8 +15,6 @@ from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID @@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): self._base_builder = base_builder self._event_id = event_id - @defer.inlineCallbacks - def build(self, prev_event_ids, auth_event_ids): - built_event = yield defer.ensureDeferred( - self._base_builder.build(prev_event_ids, auth_event_ids) + async def build(self, prev_event_ids, auth_event_ids): + built_event = await self._base_builder.build( + prev_event_ids, auth_event_ids ) built_event._event_id = self._event_id diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,21 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError -from tests import unittest -from tests.utils import setup_test_homeserver - +from tests.unittest import HomeserverTestCase -class RegistrationStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) +class RegistrationStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.user_id = "@my-user:test" @@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase): self.pwhash = "{xx1}123456789" self.device_id = "akgjhdjklgshg" - @defer.inlineCallbacks def test_register(self): - yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) + self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.assertEquals( { @@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase): "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, - "creation_ts": 1000, + "creation_ts": 0, "user_type": None, "deactivated": 0, "shadow_banned": 0, }, - (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), + (self.get_success(self.store.get_user_by_id(self.user_id))), ) - @defer.inlineCallbacks def test_add_tokens(self): - yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) - yield defer.ensureDeferred( + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + self.get_success( self.store.add_access_token_to_user( self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) ) - result = yield defer.ensureDeferred( - self.store.get_user_by_access_token(self.tokens[1]) - ) + result = self.get_success(self.store.get_user_by_access_token(self.tokens[1])) self.assertEqual(result.user_id, self.user_id) self.assertEqual(result.device_id, self.device_id) self.assertIsNotNone(result.token_id) - @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) - yield defer.ensureDeferred( + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + self.get_success( self.store.add_access_token_to_user( self.user_id, self.tokens[0], device_id=None, valid_until_ms=None ) ) - yield defer.ensureDeferred( + self.get_success( self.store.add_access_token_to_user( self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) ) # now delete some - yield defer.ensureDeferred( + self.get_success( self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id) ) # check they were deleted - user = yield defer.ensureDeferred( - self.store.get_user_by_access_token(self.tokens[1]) - ) + user = self.get_success(self.store.get_user_by_access_token(self.tokens[1])) self.assertIsNone(user, "access token was not deleted by device_id") # check the one not associated with the device was not deleted - user = yield defer.ensureDeferred( - self.store.get_user_by_access_token(self.tokens[0]) - ) + user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) self.assertEqual(self.user_id, user.user_id) # now delete the rest - yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) + self.get_success(self.store.user_delete_access_tokens(self.user_id)) - user = yield defer.ensureDeferred( - self.store.get_user_by_access_token(self.tokens[0]) - ) + user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) self.assertIsNone(user, "access token was not deleted without device_id") - @defer.inlineCallbacks def test_is_support_user(self): TEST_USER = "@test:test" SUPPORT_USER = "@support:test" - res = yield defer.ensureDeferred(self.store.is_support_user(None)) + res = self.get_success(self.store.is_support_user(None)) self.assertFalse(res) - yield defer.ensureDeferred( + self.get_success( self.store.register_user(user_id=TEST_USER, password_hash=None) ) - res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER)) + res = self.get_success(self.store.is_support_user(TEST_USER)) self.assertFalse(res) - yield defer.ensureDeferred( + self.get_success( self.store.register_user( user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT ) ) - res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER)) + res = self.get_success(self.store.is_support_user(SUPPORT_USER)) self.assertTrue(res) - @defer.inlineCallbacks def test_3pid_inhibit_invalid_validation_session_error(self): """Tests that enabling the configuration option to inhibit 3PID errors on /requestToken also inhibits validation errors caused by an unknown session ID. @@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase): # Check that, with the config setting set to false (the default value), a # validation error is caused by the unknown session ID. - try: - yield defer.ensureDeferred( - self.store.validate_threepid_session( - "fake_sid", - "fake_client_secret", - "fake_token", - 0, - ) - ) - except ThreepidValidationError as e: - self.assertEquals(e.msg, "Unknown session_id", e) + e = self.get_failure( + self.store.validate_threepid_session( + "fake_sid", + "fake_client_secret", + "fake_token", + 0, + ), + ThreepidValidationError, + ) + self.assertEquals(e.value.msg, "Unknown session_id", e) # Set the config setting to true. self.store._ignore_unknown_session_error = True # Check that now the validation error is caused by the token not matching. - try: - yield defer.ensureDeferred( - self.store.validate_threepid_session( - "fake_sid", - "fake_client_secret", - "fake_token", - 0, - ) - ) - except ThreepidValidationError as e: - self.assertEquals(e.msg, "Validation token not found or has expired", e) + e = self.get_failure( + self.store.validate_threepid_session( + "fake_sid", + "fake_client_secret", + "fake_token", + 0, + ), + ThreepidValidationError, + ) + self.assertEquals(e.value.msg, "Validation token not found or has expired", e) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,22 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.types import RoomAlias, RoomID, UserID -from tests import unittest -from tests.utils import setup_test_homeserver - +from tests.unittest import HomeserverTestCase -class RoomStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) +class RoomStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table self.store = hs.get_datastore() @@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield defer.ensureDeferred( + self.get_success( self.store.store_room( self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), @@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase): ) ) - @defer.inlineCallbacks def test_get_room(self): self.assertDictContainsSubset( { @@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "is_public": True, }, - (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), + (self.get_success(self.store.get_room(self.room.to_string()))), ) - @defer.inlineCallbacks def test_get_room_unknown_room(self): - self.assertIsNone( - (yield defer.ensureDeferred(self.store.get_room("!uknown:test"))) - ) + self.assertIsNone((self.get_success(self.store.get_room("!uknown:test")))) - @defer.inlineCallbacks def test_get_room_with_stats(self): self.assertDictContainsSubset( { @@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "public": True, }, - ( - yield defer.ensureDeferred( - self.store.get_room_with_stats(self.room.to_string()) - ) - ), + (self.get_success(self.store.get_room_with_stats(self.room.to_string()))), ) - @defer.inlineCallbacks def test_get_room_with_stats_unknown_room(self): self.assertIsNone( - ( - yield defer.ensureDeferred( - self.store.get_room_with_stats("!uknown:test") - ) - ), + (self.get_success(self.store.get_room_with_stats("!uknown:test"))), ) -class RoomEventsStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = setup_test_homeserver(self.addCleanup) - +class RoomEventsStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() @@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield defer.ensureDeferred( + self.get_success( self.store.store_room( self.room.to_string(), room_creator_user_id="@creator:text", @@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase): ) ) - @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield defer.ensureDeferred( + self.get_success( self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) ) - @defer.inlineCallbacks def STALE_test_room_name(self): name = "A-Room-Name" - yield self.inject_room_event( + self.inject_room_event( etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield defer.ensureDeferred( + state = self.get_success( self.store.get_current_state(room_id=self.room.to_string()) ) @@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase): state[0], ) - @defer.inlineCallbacks def STALE_test_room_topic(self): topic = "A place for things" - yield self.inject_room_event( + self.inject_room_event( etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield defer.ensureDeferred( + state = self.get_success( self.store.get_current_state(room_id=self.room.to_string()) ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8bd12fa847..f06b452fa9 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,24 +15,18 @@ import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID -import tests.unittest -import tests.utils +from tests.unittest import HomeserverTestCase logger = logging.getLogger(__name__) -class StateStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) - +class StateStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_datastore = self.storage.state.stores.state @@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield defer.ensureDeferred( + self.get_success( self.store.store_room( self.room.to_string(), room_creator_user_id="@creator:text", @@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase): ) ) - @defer.inlineCallbacks def inject_state_event(self, room, sender, typ, state_key, content): builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(len(s1), len(s2)) - @defer.inlineCallbacks def test_get_state_groups_ids(self): - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield defer.ensureDeferred( + state_group_map = self.get_success( self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) @@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase): {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) - @defer.inlineCallbacks def test_get_state_groups(self): - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield defer.ensureDeferred( + state_group_map = self.get_success( self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) @@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) - @defer.inlineCallbacks def test_get_state_for_event(self): # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - e3 = yield self.inject_state_event( + e3 = self.inject_state_event( self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {"membership": Membership.JOIN}, ) - e4 = yield self.inject_state_event( + e4 = self.inject_state_event( self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {"membership": Membership.JOIN}, ) - e5 = yield self.inject_state_event( + e5 = self.inject_state_event( self.room, self.u_bob, EventTypes.Member, @@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield defer.ensureDeferred( - self.storage.state.get_state_for_event(e5.event_id) - ) + state = self.get_success(self.storage.state.get_state_for_event(e5.event_id)) self.assertIsNotNone(e4) @@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) @@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) @@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( @@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( @@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield defer.ensureDeferred( + group_ids = self.get_success( self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -377,14 +336,11 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - ( - is_all, - known_absent, - state_dict_ids, - ) = self.state_datastore._state_group_cache.get(group) + cache_entry = self.state_datastore._state_group_cache.get(group) + state_dict_ids = cache_entry.value - self.assertEqual(is_all, True) - self.assertEqual(known_absent, set()) + self.assertEqual(cache_entry.full, True) + self.assertEqual(cache_entry.known_absent, set()) self.assertDictEqual( state_dict_ids, { @@ -403,14 +359,11 @@ class StateStoreTestCase(tests.unittest.TestCase): fetched_keys=((e1.type, e1.state_key),), ) - ( - is_all, - known_absent, - state_dict_ids, - ) = self.state_datastore._state_group_cache.get(group) + cache_entry = self.state_datastore._state_group_cache.get(group) + state_dict_ids = cache_entry.value - self.assertEqual(is_all, False) - self.assertEqual(known_absent, {(e1.type, e1.state_key)}) + self.assertEqual(cache_entry.full, False) + self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)}) self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ @@ -419,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -434,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -450,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -464,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -486,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -500,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -516,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -530,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from tests import unittest -from tests.utils import setup_test_homeserver +from tests.unittest import HomeserverTestCase, override_config ALICE = "@alice:a" BOB = "@bob:b" @@ -25,73 +22,52 @@ BOBBY = "@bobby:a" BELA = "@somenickname:a" -class UserDirectoryStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = self.hs.get_datastore() +class UserDirectoryStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. - yield defer.ensureDeferred( - self.store.update_profile_in_user_dir(ALICE, "alice", None) - ) - yield defer.ensureDeferred( - self.store.update_profile_in_user_dir(BOB, "bob", None) - ) - yield defer.ensureDeferred( - self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - ) - yield defer.ensureDeferred( - self.store.update_profile_in_user_dir(BELA, "Bela", None) - ) - yield defer.ensureDeferred( - self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) - ) + self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None)) + self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None)) + self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None)) + self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None)) + self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))) - @defer.inlineCallbacks def test_search_user_dir(self): # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. - r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) + r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(1, len(r["results"])) self.assertDictEqual( r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None} ) - @defer.inlineCallbacks + @override_config({"user_directory": {"search_all_users": True}}) def test_search_user_dir_all_users(self): - self.hs.config.user_directory_search_all_users = True - try: - r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) - self.assertFalse(r["limited"]) - self.assertEqual(2, len(r["results"])) - self.assertDictEqual( - r["results"][0], - {"user_id": BOB, "display_name": "bob", "avatar_url": None}, - ) - self.assertDictEqual( - r["results"][1], - {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None}, - ) - finally: - self.hs.config.user_directory_search_all_users = False + r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(2, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BOB, "display_name": "bob", "avatar_url": None}, + ) + self.assertDictEqual( + r["results"][1], + {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None}, + ) - @defer.inlineCallbacks + @override_config({"user_directory": {"search_all_users": True}}) def test_search_user_dir_stop_words(self): """Tests that a user can look up another user by searching for the start if its display name even if that name happens to be a common English word that would usually be ignored in full text searches. """ - self.hs.config.user_directory_search_all_users = True - try: - r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10)) - self.assertFalse(r["limited"]) - self.assertEqual(1, len(r["results"])) - self.assertDictEqual( - r["results"][0], - {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, - ) - finally: - self.hs.config.user_directory_search_all_users = False + r = self.get_success(self.store.search_user_dir(ALICE, "be", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, + ) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 3f2691ee6b..b5f18344dc 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py
@@ -207,6 +207,226 @@ class EventAuthTestCase(unittest.TestCase): do_sig_check=False, ) + def test_join_rules_public(self): + """ + Test joining a public room. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "public"), + } + + # Check join. + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left can re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + def test_join_rules_invite(self): + """ + Test joining an invite only room. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "invite"), + } + + # A join without an invite is rejected. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left cannot re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + def test_join_rules_msc3083_restricted(self): + """ + Test joining a restricted room from MSC3083. + + This is pretty much the same test as public. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"), + } + + # Older room versions don't understand this join rule + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # Check join. + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC3083, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left can re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + # helpers for making events @@ -225,19 +445,24 @@ def _create_event(user_id): ) -def _join_event(user_id): +def _member_event(user_id, membership, sender=None): return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), "type": "m.room.member", - "sender": user_id, + "sender": sender or user_id, "state_key": user_id, - "content": {"membership": "join"}, + "content": {"membership": membership}, + "prev_events": [], } ) +def _join_event(user_id): + return _member_event(user_id, "join") + + def _power_levels_event(sender, content): return make_event_from_dict( { @@ -277,6 +502,21 @@ def _random_state_event(sender): ) +def _join_rules_event(sender, join_rule): + return make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.join_rules", + "sender": sender, + "state_key": "", + "content": { + "join_rule": join_rule, + }, + } + ) + + event_count = 0 diff --git a/tests/unittest.py b/tests/unittest.py
index ca7031c724..57b6a395c7 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -32,6 +32,7 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from twisted.web.resource import Resource +from synapse import events from synapse.api.constants import EventTypes, Membership from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig @@ -140,7 +141,7 @@ class TestCase(unittest.TestCase): try: self.assertEquals(attrs[key], getattr(obj, key)) except AssertionError as e: - raise (type(e))(e.message + " for '.%s'" % key) + raise (type(e))("Assert error for '.{}':".format(key)) from e def assert_dict(self, required, actual): """Does a partial assert of a dict. @@ -229,6 +230,11 @@ class HomeserverTestCase(TestCase): self._hs_args = {"clock": self.clock, "reactor": self.reactor} self.hs = self.make_homeserver(self.reactor, self.clock) + # Honour the `use_frozen_dicts` config option. We have to do this + # manually because this is taken care of in the app `start` code, which + # we don't run. Plus we want to reset it on tearDown. + events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts + if self.hs is None: raise Exception("No homeserver returned from make_homeserver.") @@ -292,6 +298,10 @@ class HomeserverTestCase(TestCase): if hasattr(self, "prepare"): self.prepare(self.reactor, self.clock, self.hs) + def tearDown(self): + # Reset to not use frozen dicts. + events.USE_FROZEN_DICTS = False + def wait_on_thread(self, deferred, timeout=10): """ Wait until a Deferred is done, where it's waiting on a real thread. @@ -461,7 +471,7 @@ class HomeserverTestCase(TestCase): kwargs["config"] = config_obj async def run_bg_updates(): - with LoggingContext("run_bg_updates", request="run_bg_updates-1"): + with LoggingContext("run_bg_updates"): while not await stor.db_pool.updates.has_completed_background_updates(): await stor.db_pool.updates.do_next_background_update(1) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index afb11b9caf..e434e21aee 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py
@@ -661,14 +661,13 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1") async def list_fn(self, args1, arg2): - assert current_context().request == "c1" + assert current_context().name == "c1" # we want this to behave like an asynchronous function await run_on_reactor() - assert current_context().request == "c1" + assert current_context().name == "c1" return self.mock(args1, arg2) - with LoggingContext() as c1: - c1.request = "c1" + with LoggingContext("c1") as c1: obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index 34fdc9a43a..2f41333f4c 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py
@@ -27,7 +27,9 @@ class DictCacheTestCase(unittest.TestCase): key = "test_simple_cache_hit_full" v = self.cache.get(key) - self.assertEqual((False, set(), {}), v) + self.assertIs(v.full, False) + self.assertEqual(v.known_absent, set()) + self.assertEqual({}, v.value) seq = self.cache.sequence test_value = {"test": "test_simple_cache_hit_full"} diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 58ee918f65..5d9c4665aa 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py
@@ -17,11 +17,10 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(current_context().request, value) + self.assertEquals(current_context().name, value) def test_with_context(self): - with LoggingContext() as context_one: - context_one.request = "test" + with LoggingContext("test"): self._check_test_key("test") @defer.inlineCallbacks @@ -30,15 +29,13 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def competing_callback(): - with LoggingContext() as competing_context: - competing_context.request = "competing" + with LoggingContext("competing"): yield clock.sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) - with LoggingContext() as context_one: - context_one.request = "one" + with LoggingContext("one"): yield clock.sleep(0) self._check_test_key("one") @@ -47,9 +44,7 @@ class LoggingContextTestCase(unittest.TestCase): callback_completed = [False] - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): # fire off function, but don't wait on it. d2 = run_in_background(function) @@ -133,9 +128,7 @@ class LoggingContextTestCase(unittest.TestCase): sentinel_context = current_context() - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable self.assertIs(current_context(), sentinel_context) @@ -149,9 +142,7 @@ class LoggingContextTestCase(unittest.TestCase): def test_make_deferred_yieldable_with_chained_deferreds(self): sentinel_context = current_context() - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): d1 = make_deferred_yieldable(_chained_deferred_function()) # make sure that the context was reset by make_deferred_yieldable self.assertIs(current_context(), sentinel_context) @@ -166,9 +157,7 @@ class LoggingContextTestCase(unittest.TestCase): """Check that make_deferred_yieldable does the right thing when its argument isn't actually a deferred""" - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): d1 = make_deferred_yieldable("bum") self._check_test_key("one") @@ -177,9 +166,9 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") def test_nested_logging_context(self): - with LoggingContext(request="foo"): + with LoggingContext("foo"): nested_context = nested_logging_context(suffix="bar") - self.assertEqual(nested_context.request, "foo-bar") + self.assertEqual(nested_context.name, "foo-bar") @defer.inlineCallbacks def test_make_deferred_yieldable_with_await(self): @@ -193,9 +182,7 @@ class LoggingContextTestCase(unittest.TestCase): sentinel_context = current_context() - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable self.assertIs(current_context(), sentinel_context) diff --git a/tests/utils.py b/tests/utils.py
index be80b13760..a141ee6496 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -122,7 +122,6 @@ def default_config(name, parse=False): "enable_registration_captcha": False, "macaroon_secret_key": "not even a little secret", "trusted_third_party_id_servers": [], - "room_invite_state_types": [], "password_providers": [], "worker_replication_url": "", "worker_app": None,