-rw-r--r--  CHANGES.md | 64
-rw-r--r--  INSTALL.md | 12
-rw-r--r--  README.rst | 10
-rw-r--r--  changelog.d/6455.feature | 1
-rw-r--r--  changelog.d/7021.bugfix | 1
-rw-r--r--  changelog.d/7613.feature | 1
-rw-r--r--  changelog.d/7732.bugfix | 1
-rw-r--r--  changelog.d/7740.misc | 1
-rw-r--r--  changelog.d/7760.bugfix | 1
-rw-r--r--  changelog.d/7765.misc | 1
-rw-r--r--  changelog.d/7766.bugfix | 1
-rw-r--r--  changelog.d/7768.misc | 1
-rw-r--r--  changelog.d/7769.misc | 1
-rw-r--r--  changelog.d/7770.misc | 1
-rw-r--r--  changelog.d/7775.misc | 1
-rw-r--r--  changelog.d/7776.doc | 1
-rw-r--r--  changelog.d/7779.bugfix | 1
-rw-r--r--  changelog.d/7780.misc | 1
-rw-r--r--  changelog.d/7786.misc | 1
-rw-r--r--  changelog.d/7789.doc | 1
-rw-r--r--  changelog.d/7791.docker | 1
-rw-r--r--  changelog.d/7793.misc | 1
-rw-r--r--  changelog.d/7797.bugfix | 1
-rw-r--r--  changelog.d/7798.feature | 1
-rw-r--r--  changelog.d/7799.misc | 1
-rw-r--r--  changelog.d/7800.misc | 1
-rw-r--r--  changelog.d/7802.misc | 1
-rw-r--r--  changelog.d/7804.bugfix | 1
-rw-r--r--  changelog.d/7805.misc | 1
-rw-r--r--  changelog.d/7810.bugfix | 1
-rw-r--r--  changelog.d/7813.misc | 1
-rw-r--r--  changelog.d/7815.bugfix | 1
-rw-r--r--  changelog.d/7817.bugfix | 1
-rw-r--r--  changelog.d/7820.misc | 1
-rw-r--r--  changelog.d/7822.bugfix | 1
-rw-r--r--  changelog.d/7827.feature | 1
-rw-r--r--  changelog.d/7829.bugfix | 1
-rw-r--r--  changelog.d/7830.feature | 1
-rw-r--r--  changelog.d/7836.misc | 1
-rw-r--r--  changelog.d/7839.docker | 1
-rw-r--r--  changelog.d/7842.feature | 1
-rw-r--r--  changelog.d/7844.bugfix | 1
-rw-r--r--  changelog.d/7846.feature | 1
-rw-r--r--  changelog.d/7847.feature | 1
-rw-r--r--  changelog.d/7848.misc | 1
-rw-r--r--  changelog.d/7849.misc | 1
-rw-r--r--  changelog.d/7850.bugfix | 1
-rw-r--r--  changelog.d/7851.misc | 1
-rw-r--r--  changelog.d/7853.misc | 1
-rw-r--r--  changelog.d/7854.bugfix | 1
-rw-r--r--  changelog.d/7855.feature | 1
-rw-r--r--  changelog.d/7856.misc | 1
-rw-r--r--  changelog.d/7858.misc | 1
-rw-r--r--  changelog.d/7859.bugfix | 1
-rw-r--r--  changelog.d/7860.misc | 1
-rw-r--r--  changelog.d/7861.misc | 1
-rw-r--r--  changelog.d/7866.bugfix | 1
-rw-r--r--  changelog.d/7868.misc | 1
-rw-r--r--  changelog.d/7869.feature | 1
-rw-r--r--  changelog.d/7870.misc | 1
-rw-r--r--  changelog.d/7871.misc | 1
-rw-r--r--  changelog.d/7872.bugfix | 1
-rw-r--r--  changelog.d/7874.misc | 1
-rw-r--r--  changelog.d/7877.misc | 1
-rw-r--r--  changelog.d/7878.removal | 1
-rw-r--r--  changelog.d/7879.feature | 1
-rw-r--r--  changelog.d/7880.bugfix | 1
-rw-r--r--  changelog.d/7881.misc | 1
-rw-r--r--  changelog.d/7882.misc | 1
-rw-r--r--  changelog.d/7884.misc | 1
-rw-r--r--  changelog.d/7885.doc | 1
-rw-r--r--  changelog.d/7888.misc | 1
-rw-r--r--  changelog.d/7889.doc | 1
-rw-r--r--  changelog.d/7890.misc | 1
-rw-r--r--  changelog.d/7892.misc | 1
-rw-r--r--  changelog.d/7895.bugfix | 1
-rw-r--r--  changelog.d/7897.misc | 2
-rw-r--r--  changelog.d/7908.feature | 1
-rw-r--r--  changelog.d/7912.misc | 1
-rw-r--r--  changelog.d/7914.misc | 1
-rw-r--r--  changelog.d/7919.misc | 1
-rw-r--r--  changelog.d/7927.misc | 1
-rw-r--r--  changelog.d/7928.misc | 1
-rw-r--r--  changelog.d/7929.misc | 1
-rw-r--r--  changelog.d/7930.feature | 1
-rw-r--r--  changelog.d/7931.feature | 1
-rw-r--r--  changelog.d/7933.doc | 1
-rw-r--r--  changelog.d/7934.doc | 1
-rw-r--r--  changelog.d/7935.misc | 1
-rw-r--r--  changelog.d/7939.misc | 1
-rwxr-xr-x  contrib/cmdclient/console.py | 21
-rw-r--r--  contrib/cmdclient/http.py | 10
-rw-r--r--  contrib/experiments/test_messaging.py | 55
-rw-r--r--  contrib/grafana/synapse.json | 299
-rw-r--r--  contrib/graph/graph.py | 21
-rw-r--r--  contrib/graph/graph2.py | 11
-rw-r--r--  contrib/graph/graph3.py | 22
-rw-r--r--  contrib/jitsimeetbridge/jitsimeetbridge.py | 10
-rwxr-xr-x  contrib/scripts/kick_users.py | 6
-rw-r--r--  debian/changelog | 18
-rw-r--r--  docker/Dockerfile | 57
-rw-r--r--  docker/README.md | 15
-rwxr-xr-x  docker/start.py | 12
-rw-r--r--  docs/ACME.md | 5
-rw-r--r--  docs/admin_api/purge_room.md | 2
-rw-r--r--  docs/admin_api/rooms.md | 126
-rw-r--r--  docs/admin_api/shutdown_room.md | 2
-rw-r--r--  docs/admin_api/user_admin_api.rst | 6
-rw-r--r--  docs/jwt.md | 19
-rw-r--r--  docs/password_auth_providers.md | 187
-rw-r--r--  docs/reverse_proxy.md | 16
-rw-r--r--  docs/sample_config.yaml | 165
-rwxr-xr-x  scripts-dev/build_debian_packages | 1
-rwxr-xr-x  scripts-dev/lint.sh | 2
-rwxr-xr-x  scripts/synapse_port_db | 12
-rw-r--r--  stubs/txredisapi.pyi | 1
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 2
-rw-r--r--  synapse/api/errors.py | 130
-rw-r--r--  synapse/app/generic_worker.py | 107
-rw-r--r--  synapse/app/homeserver.py | 25
-rw-r--r--  synapse/appservice/api.py | 20
-rw-r--r--  synapse/config/_base.py | 38
-rw-r--r--  synapse/config/_base.pyi | 5
-rw-r--r--  synapse/config/database.py | 2
-rw-r--r--  synapse/config/emailconfig.py | 120
-rw-r--r--  synapse/config/federation.py | 98
-rw-r--r--  synapse/config/homeserver.py | 3
-rw-r--r--  synapse/config/jwt_config.py | 28
-rw-r--r--  synapse/config/push.py | 5
-rw-r--r--  synapse/config/room.py | 7
-rw-r--r--  synapse/config/server.py | 74
-rw-r--r--  synapse/config/workers.py | 19
-rw-r--r--  synapse/event_auth.py | 10
-rw-r--r--  synapse/events/utils.py | 6
-rw-r--r--  synapse/federation/federation_client.py | 44
-rw-r--r--  synapse/federation/federation_server.py | 148
-rw-r--r--  synapse/federation/send_queue.py | 16
-rw-r--r--  synapse/federation/sender/__init__.py | 52
-rw-r--r--  synapse/federation/sender/per_destination_queue.py | 26
-rw-r--r--  synapse/federation/sender/transaction_manager.py | 2
-rw-r--r--  synapse/federation/transport/client.py | 2
-rw-r--r--  synapse/federation/transport/server.py | 20
-rw-r--r--  synapse/handlers/_base.py | 17
-rw-r--r--  synapse/handlers/auth.py | 7
-rw-r--r--  synapse/handlers/cas_handler.py | 2
-rw-r--r--  synapse/handlers/deactivate_account.py | 65
-rw-r--r--  synapse/handlers/device.py | 249
-rw-r--r--  synapse/handlers/e2e_keys.py | 147
-rw-r--r--  synapse/handlers/e2e_room_keys.py | 75
-rw-r--r--  synapse/handlers/federation.py | 93
-rw-r--r--  synapse/handlers/message.py | 284
-rw-r--r--  synapse/handlers/presence.py | 43
-rw-r--r--  synapse/handlers/profile.py | 63
-rw-r--r--  synapse/handlers/receipts.py | 16
-rw-r--r--  synapse/handlers/register.py | 22
-rw-r--r--  synapse/handlers/room.py | 208
-rw-r--r--  synapse/handlers/room_list.py | 62
-rw-r--r--  synapse/handlers/room_member.py | 8
-rw-r--r--  synapse/handlers/room_member_worker.py | 2
-rw-r--r--  synapse/handlers/sync.py | 17
-rw-r--r--  synapse/handlers/typing.py | 243
-rw-r--r--  synapse/handlers/ui_auth/checkers.py | 38
-rw-r--r--  synapse/http/client.py | 34
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py | 16
-rw-r--r--  synapse/http/federation/srv_resolver.py | 10
-rw-r--r--  synapse/http/server.py | 75
-rw-r--r--  synapse/http/servlet.py | 10
-rw-r--r--  synapse/http/site.py | 4
-rw-r--r--  synapse/logging/context.py | 29
-rw-r--r--  synapse/logging/opentracing.py | 63
-rw-r--r--  synapse/logging/utils.py | 126
-rw-r--r--  synapse/metrics/background_process_metrics.py | 10
-rw-r--r--  synapse/notifier.py | 2
-rw-r--r--  synapse/push/mailer.py | 61
-rw-r--r--  synapse/push/pusherpool.py | 78
-rw-r--r--  synapse/replication/http/_base.py | 4
-rw-r--r--  synapse/replication/slave/storage/deviceinbox.py | 2
-rw-r--r--  synapse/replication/tcp/__init__.py | 2
-rw-r--r--  synapse/replication/tcp/commands.py | 12
-rw-r--r--  synapse/replication/tcp/handler.py | 348
-rw-r--r--  synapse/replication/tcp/protocol.py | 21
-rw-r--r--  synapse/replication/tcp/redis.py | 24
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 7
-rw-r--r--  synapse/replication/tcp/streams/events.py | 4
-rw-r--r--  synapse/res/templates/mail-Element.css | 7
-rw-r--r--  synapse/res/templates/notice_expiry.html | 2
-rw-r--r--  synapse/res/templates/notif_mail.html | 2
-rw-r--r--  synapse/rest/admin/__init__.py | 4
-rw-r--r--  synapse/rest/admin/rooms.py | 182
-rw-r--r--  synapse/rest/admin/users.py | 10
-rw-r--r--  synapse/rest/client/v1/login.py | 27
-rw-r--r--  synapse/rest/client/v1/room.py | 28
-rw-r--r--  synapse/rest/client/v2_alpha/_base.py | 11
-rw-r--r--  synapse/rest/client/v2_alpha/sync.py | 9
-rw-r--r--  synapse/rest/key/v2/remote_key_resource.py | 4
-rw-r--r--  synapse/server.py | 27
-rw-r--r--  synapse/server.pyi | 9
-rw-r--r--  synapse/storage/_base.py | 4
-rw-r--r--  synapse/storage/background_updates.py | 5
-rw-r--r--  synapse/storage/data_stores/main/__init__.py | 2
-rw-r--r--  synapse/storage/data_stores/main/account_data.py | 16
-rw-r--r--  synapse/storage/data_stores/main/appservice.py | 4
-rw-r--r--  synapse/storage/data_stores/main/deviceinbox.py | 9
-rw-r--r--  synapse/storage/data_stores/main/devices.py | 2
-rw-r--r--  synapse/storage/data_stores/main/e2e_room_keys.py | 10
-rw-r--r--  synapse/storage/data_stores/main/end_to_end_keys.py | 2
-rw-r--r--  synapse/storage/data_stores/main/event_push_actions.py | 4
-rw-r--r--  synapse/storage/data_stores/main/events.py | 76
-rw-r--r--  synapse/storage/data_stores/main/events_bg_updates.py | 14
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py | 9
-rw-r--r--  synapse/storage/data_stores/main/group_server.py | 22
-rw-r--r--  synapse/storage/data_stores/main/push_rule.py | 6
-rw-r--r--  synapse/storage/data_stores/main/pusher.py | 6
-rw-r--r--  synapse/storage/data_stores/main/receipts.py | 8
-rw-r--r--  synapse/storage/data_stores/main/registration.py | 65
-rw-r--r--  synapse/storage/data_stores/main/room.py | 17
-rw-r--r--  synapse/storage/data_stores/main/roommember.py | 5
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql | 22
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql | 22
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py | 34
-rw-r--r--  synapse/storage/data_stores/main/search.py | 6
-rw-r--r--  synapse/storage/data_stores/main/state.py | 65
-rw-r--r--  synapse/storage/data_stores/main/stream.py | 97
-rw-r--r--  synapse/storage/data_stores/main/tags.py | 5
-rw-r--r--  synapse/storage/data_stores/main/ui_auth.py | 12
-rw-r--r--  synapse/storage/data_stores/main/user_erasure_store.py | 26
-rw-r--r--  synapse/storage/data_stores/state/store.py | 12
-rw-r--r--  synapse/storage/engines/_base.py | 6
-rw-r--r--  synapse/storage/engines/postgres.py | 6
-rw-r--r--  synapse/storage/engines/sqlite.py | 13
-rw-r--r--  synapse/storage/util/id_generators.py | 8
-rw-r--r--  synapse/storage/util/sequence.py | 98
-rw-r--r--  synapse/streams/config.py | 4
-rw-r--r--  synapse/streams/events.py | 2
-rw-r--r--  synapse/util/__init__.py | 2
-rw-r--r--  synapse/util/async_helpers.py | 2
-rw-r--r--  synapse/util/caches/descriptors.py | 2
-rw-r--r--  synapse/util/distributor.py | 30
-rw-r--r--  synapse/util/patch_inline_callbacks.py | 2
-rw-r--r--  synapse/util/retryutils.py | 4
-rw-r--r--  synapse/util/stringutils.py | 2
-rw-r--r--  synapse/visibility.py | 4
-rw-r--r--  tests/crypto/test_keyring.py | 2
-rw-r--r--  tests/events/test_snapshot.py | 36
-rw-r--r--  tests/handlers/test_device.py | 13
-rw-r--r--  tests/handlers/test_e2e_keys.py | 296
-rw-r--r--  tests/handlers/test_e2e_room_keys.py | 373
-rw-r--r--  tests/handlers/test_profile.py | 17
-rw-r--r--  tests/handlers/test_typing.py | 4
-rw-r--r--  tests/http/federation/test_matrix_federation_agent.py | 51
-rw-r--r--  tests/http/federation/test_srv_resolver.py | 26
-rw-r--r--  tests/replication/_base.py | 168
-rw-r--r--  tests/replication/tcp/streams/test_events.py | 76
-rw-r--r--  tests/replication/test_client_reader_shard.py | 96
-rw-r--r--  tests/replication/test_federation_ack.py | 1
-rw-r--r--  tests/replication/test_federation_sender_shard.py | 235
-rw-r--r--  tests/replication/test_pusher_shard.py | 193
-rw-r--r--  tests/rest/admin/test_room.py | 441
-rw-r--r--  tests/rest/admin/test_user.py | 47
-rw-r--r--  tests/rest/client/test_retention.py | 2
-rw-r--r--  tests/rest/client/v1/test_login.py | 133
-rw-r--r--  tests/rest/client/v1/test_presence.py | 2
-rw-r--r--  tests/rest/client/v2_alpha/test_relations.py | 2
-rw-r--r--  tests/server.py | 26
-rw-r--r--  tests/storage/test_room.py | 8
-rw-r--r--  tests/storage/test_roommember.py | 56
-rw-r--r--  tests/storage/test_state.py | 4
-rw-r--r--  tests/test_federation.py | 35
-rw-r--r--  tests/test_mau.py | 2
-rw-r--r--  tests/test_server.py | 59
-rw-r--r--  tests/test_utils/event_injection.py | 28
-rw-r--r--  tests/test_visibility.py | 14
-rw-r--r--  tests/unittest.py | 4
-rw-r--r--  tests/util/test_logcontext.py | 4
-rw-r--r--  tests/utils.py | 4
-rw-r--r--  tox.ini | 2
277 files changed, 5682 insertions, 2761 deletions
diff --git a/CHANGES.md b/CHANGES.md

index 3a0fe606f8..6d4bd23e4e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,67 @@
+Synapse 1.17.0 (2020-07-13)
+===========================
+
+Synapse 1.17.0 is identical to 1.17.0rc1, with the addition of the fix that was included in 1.16.1.
+
+
+Synapse 1.16.1 (2020-07-10)
+===========================
+
+In some distributions of Synapse 1.16.0, we incorrectly included a database migration which added a new, unused table. This release removes the redundant table.
+
+Bugfixes
+--------
+
+- Drop table `local_rejections_stream` which was incorrectly added in Synapse 1.16.0. ([\#7816](https://github.com/matrix-org/synapse/issues/7816), [b1beb3ff5](https://github.com/matrix-org/synapse/commit/b1beb3ff5))
+
+
+Synapse 1.17.0rc1 (2020-07-09)
+==============================
+
+Bugfixes
+--------
+
+- Fix inconsistent handling of upper and lower case in email addresses when used as identifiers for login, etc. Contributed by @dklimpel. ([\#7021](https://github.com/matrix-org/synapse/issues/7021))
+- Fix "Tried to close a non-active scope!" error messages when opentracing is enabled. ([\#7732](https://github.com/matrix-org/synapse/issues/7732))
+- Fix incorrect error message when database CTYPE was set incorrectly. ([\#7760](https://github.com/matrix-org/synapse/issues/7760))
+- Fix to not ignore `set_tweak` actions in Push Rules that have no `value`, as permitted by the specification. ([\#7766](https://github.com/matrix-org/synapse/issues/7766))
+- Fix synctl to handle empty config files correctly. Contributed by @kotovalexarian. ([\#7779](https://github.com/matrix-org/synapse/issues/7779))
+- Fixes a long standing bug in worker mode where worker information was saved in the devices table instead of the original IP address and user agent. ([\#7797](https://github.com/matrix-org/synapse/issues/7797))
+- Fix 'stuck invites' which happen when we are unable to reject a room invite received over federation. ([\#7804](https://github.com/matrix-org/synapse/issues/7804), [\#7809](https://github.com/matrix-org/synapse/issues/7809), [\#7810](https://github.com/matrix-org/synapse/issues/7810))
+
+
+Updates to the Docker image
+---------------------------
+
+- Include libwebp in the Docker file to properly handle webp image uploads. ([\#7791](https://github.com/matrix-org/synapse/issues/7791))
+
+
+Improved Documentation
+----------------------
+
+- Improve the documentation of the non-standard JSON web token login type. ([\#7776](https://github.com/matrix-org/synapse/issues/7776))
+- Update doc links for caddy. Contributed by Nicolai Søborg. ([\#7789](https://github.com/matrix-org/synapse/issues/7789))
+
+
+Internal Changes
+----------------
+
+- Refactor getting replication updates from database. ([\#7740](https://github.com/matrix-org/synapse/issues/7740))
+- Send push notifications with a high or low priority depending upon whether they may generate user-observable effects. ([\#7765](https://github.com/matrix-org/synapse/issues/7765))
+- Use symbolic names for replication stream names. ([\#7768](https://github.com/matrix-org/synapse/issues/7768))
+- Add early returns to `_check_for_soft_fail`. ([\#7769](https://github.com/matrix-org/synapse/issues/7769))
+- Fix up `synapse.handlers.federation` to pass mypy. ([\#7770](https://github.com/matrix-org/synapse/issues/7770))
+- Convert the appserver handler to async/await. ([\#7775](https://github.com/matrix-org/synapse/issues/7775))
+- Allow to use higher versions of prometheus_client <0.9.0 which are expected to introduce no breaking changes. Contributed by Oliver Kurz. ([\#7780](https://github.com/matrix-org/synapse/issues/7780))
+- Update linting scripts and codebase to be compatible with `isort` v5. ([\#7786](https://github.com/matrix-org/synapse/issues/7786))
+- Stop populating unused table `local_invites`. ([\#7793](https://github.com/matrix-org/synapse/issues/7793))
+- Ensure that strings (not bytes) are passed into JSON serialization. ([\#7799](https://github.com/matrix-org/synapse/issues/7799))
+- Switch from simplejson to the standard library json. ([\#7800](https://github.com/matrix-org/synapse/issues/7800))
+- Add `signing_key` property to `HomeServer` to save code duplication. ([\#7805](https://github.com/matrix-org/synapse/issues/7805))
+- Improve stacktraces from exceptions in background processes. ([\#7808](https://github.com/matrix-org/synapse/issues/7808))
+- Fix various spelling errors in comments and log lines. ([\#7811](https://github.com/matrix-org/synapse/issues/7811))
+
+
 Synapse 1.16.0 (2020-07-08)
 ===========================
 
diff --git a/INSTALL.md b/INSTALL.md
index ef80a26c3f..b507de7442 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -405,13 +405,11 @@ so, you will need to edit `homeserver.yaml`, as follows:
   ```
 
 * You will also need to uncomment the `tls_certificate_path` and
-  `tls_private_key_path` lines under the `TLS` section. You can either
-  point these settings at an existing certificate and key, or you can
-  enable Synapse's built-in ACME (Let's Encrypt) support. Instructions
-  for having Synapse automatically provision and renew federation
-  certificates through ACME can be found at [ACME.md](docs/ACME.md).
-  Note that, as pointed out in that document, this feature will not
-  work with installs set up after November 2019.
+  `tls_private_key_path` lines under the `TLS` section. You will need to manage
+  provisioning of these certificates yourself — Synapse had built-in ACME
+  support, but the ACMEv1 protocol Synapse implements is deprecated, not
+  allowed by LetsEncrypt for new sites, and will break for existing sites in
+  late 2020. See [ACME.md](docs/ACME.md).
 
   If you are using your own certificate, be sure to use a `.pem` file that
   includes the full certificate chain including any intermediate certificates
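For reference, a minimal sketch of the two settings the new INSTALL.md text refers to, as they would appear once uncommented in `homeserver.yaml`; the paths are placeholders for wherever your externally provisioned certificate and key live:

```yaml
# Assumed example paths; the certificate file should contain the full
# chain, including any intermediate certificates.
tls_certificate_path: "/etc/matrix-synapse/fullchain.pem"
tls_private_key_path: "/etc/matrix-synapse/privkey.pem"
```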
diff --git a/README.rst b/README.rst
index 38376e23c2..f7116b3480 100644
--- a/README.rst
+++ b/README.rst
@@ -188,12 +188,8 @@ Using PostgreSQL
 ================
 
 Synapse offers two database engines:
- * `SQLite <https://sqlite.org/>`_
 * `PostgreSQL <https://www.postgresql.org>`_
-
-By default Synapse uses SQLite in and doing so trades performance for convenience.
-SQLite is only recommended in Synapse for testing purposes or for servers with
-light workloads.
+ * `SQLite <https://sqlite.org/>`_
 
 Almost all installations should opt to use PostgreSQL. Advantages include:
 
@@ -207,6 +203,10 @@ Almost all installations should opt to use PostgreSQL. Advantages include:
 For information on how to install and use PostgreSQL, please see
 `docs/postgres.md <docs/postgres.md>`_.
 
+By default Synapse uses SQLite and in doing so trades performance for convenience.
+SQLite is only recommended in Synapse for testing purposes or for servers with
+light workloads.
+
 .. _reverse-proxy:
 
 Using a reverse proxy with Synapse
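As a companion to the README text above, switching from the SQLite default to PostgreSQL happens in the `database` section of `homeserver.yaml`; a minimal sketch with placeholder credentials (docs/postgres.md, referenced in the diff, has the authoritative procedure):

```yaml
database:
  name: psycopg2            # PostgreSQL engine; the default is sqlite3
  args:
    user: synapse_user
    password: secretpassword
    database: synapse
    host: localhost
    cp_min: 5               # connection pool bounds
    cp_max: 10
```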
diff --git a/changelog.d/6455.feature b/changelog.d/6455.feature
new file mode 100644
index 0000000000..eb286cb70f
--- /dev/null
+++ b/changelog.d/6455.feature
@@ -0,0 +1 @@
+Include room states on invite events that are sent to application services. Contributed by @Sorunome.
diff --git a/changelog.d/7021.bugfix b/changelog.d/7021.bugfix
deleted file mode 100644
index 140fe37b2d..0000000000
--- a/changelog.d/7021.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix inconsistent handling of upper and lower case in email addresses when used as identifiers for login, etc. Contributed by @dklimpel.
diff --git a/changelog.d/7613.feature b/changelog.d/7613.feature
new file mode 100644
index 0000000000..b671dc2fcc
--- /dev/null
+++ b/changelog.d/7613.feature
@@ -0,0 +1 @@
+Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel.
diff --git a/changelog.d/7732.bugfix b/changelog.d/7732.bugfix
deleted file mode 100644
index d5e352e141..0000000000
--- a/changelog.d/7732.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix "Tried to close a non-active scope!" error messages when opentracing is enabled.
diff --git a/changelog.d/7740.misc b/changelog.d/7740.misc
deleted file mode 100644
index f93149502e..0000000000
--- a/changelog.d/7740.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor getting replication updates from database.
diff --git a/changelog.d/7760.bugfix b/changelog.d/7760.bugfix
deleted file mode 100644
index f6081f3d30..0000000000
--- a/changelog.d/7760.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix incorrect error message when database CTYPE was set incorrectly.
diff --git a/changelog.d/7765.misc b/changelog.d/7765.misc
deleted file mode 100644
index fa9cfd24cb..0000000000
--- a/changelog.d/7765.misc
+++ /dev/null
@@ -1 +0,0 @@
-Send push notifications with a high or low priority depending upon whether they may generate user-observable effects.
diff --git a/changelog.d/7766.bugfix b/changelog.d/7766.bugfix
deleted file mode 100644
index ec5ecd8055..0000000000
--- a/changelog.d/7766.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix to not ignore `set_tweak` actions in Push Rules that have no `value`, as permitted by the specification.
diff --git a/changelog.d/7768.misc b/changelog.d/7768.misc
deleted file mode 100644
index dfb3d24c7d..0000000000
--- a/changelog.d/7768.misc
+++ /dev/null
@@ -1 +0,0 @@
-Use symbolic names for replication stream names.
diff --git a/changelog.d/7769.misc b/changelog.d/7769.misc
deleted file mode 100644
index 2e200286ce..0000000000
--- a/changelog.d/7769.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add early returns to `_check_for_soft_fail`.
diff --git a/changelog.d/7770.misc b/changelog.d/7770.misc
deleted file mode 100644
index 5b864084be..0000000000
--- a/changelog.d/7770.misc
+++ /dev/null
@@ -1 +0,0 @@
-Fix up `synapse.handlers.federation` to pass mypy.
diff --git a/changelog.d/7775.misc b/changelog.d/7775.misc
deleted file mode 100644
index af6fdb782f..0000000000
--- a/changelog.d/7775.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert the appserver handler to async/await.
diff --git a/changelog.d/7776.doc b/changelog.d/7776.doc
deleted file mode 100644
index e686215688..0000000000
--- a/changelog.d/7776.doc
+++ /dev/null
@@ -1 +0,0 @@
-Improve the documentation of the non-standard JSON web token login type.
diff --git a/changelog.d/7779.bugfix b/changelog.d/7779.bugfix
deleted file mode 100644
index 61de45d570..0000000000
--- a/changelog.d/7779.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix synctl to handle empty config files correctly. Contributed by @kotovalexarian.
diff --git a/changelog.d/7780.misc b/changelog.d/7780.misc
deleted file mode 100644
index a627bea458..0000000000
--- a/changelog.d/7780.misc
+++ /dev/null
@@ -1 +0,0 @@
-Allow to use higher versions of prometheus_client <0.9.0 which are expected to introduce no breaking changes. Contributed by Oliver Kurz.
diff --git a/changelog.d/7786.misc b/changelog.d/7786.misc
deleted file mode 100644
index 27af2681dc..0000000000
--- a/changelog.d/7786.misc
+++ /dev/null
@@ -1 +0,0 @@
-Update linting scripts and codebase to be compatible with `isort` v5.
diff --git a/changelog.d/7789.doc b/changelog.d/7789.doc
deleted file mode 100644
index 254411c769..0000000000
--- a/changelog.d/7789.doc
+++ /dev/null
@@ -1 +0,0 @@
-Update doc links for caddy. Contributed by Nicolai Søborg.
diff --git a/changelog.d/7791.docker b/changelog.d/7791.docker
deleted file mode 100644
index a114159d4e..0000000000
--- a/changelog.d/7791.docker
+++ /dev/null
@@ -1 +0,0 @@
-Include libwebp in the Docker file to properly handle webp image uploads.
diff --git a/changelog.d/7793.misc b/changelog.d/7793.misc
deleted file mode 100644
index 2b6cfbe274..0000000000
--- a/changelog.d/7793.misc
+++ /dev/null
@@ -1 +0,0 @@
-Stop populating unused table `local_invites`.
diff --git a/changelog.d/7797.bugfix b/changelog.d/7797.bugfix
deleted file mode 100644
index c1259871da..0000000000
--- a/changelog.d/7797.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fixes a long standing bug in worker mode where worker information was saved in the devices table instead of the original IP address and user agent.
diff --git a/changelog.d/7798.feature b/changelog.d/7798.feature
new file mode 100644
index 0000000000..56ffaf0d4a
--- /dev/null
+++ b/changelog.d/7798.feature
@@ -0,0 +1 @@
+Add experimental support for running multiple federation sender processes.
diff --git a/changelog.d/7799.misc b/changelog.d/7799.misc
deleted file mode 100644
index 448b286df4..0000000000
--- a/changelog.d/7799.misc
+++ /dev/null
@@ -1 +0,0 @@
-Ensure that strings (not bytes) are passed into JSON serialization.
diff --git a/changelog.d/7800.misc b/changelog.d/7800.misc
deleted file mode 100644
index ce2346b3d4..0000000000
--- a/changelog.d/7800.misc
+++ /dev/null
@@ -1 +0,0 @@
-Switch from simplejson to the standard library json.
diff --git a/changelog.d/7802.misc b/changelog.d/7802.misc
new file mode 100644
index 0000000000..d81f8875c5
--- /dev/null
+++ b/changelog.d/7802.misc
@@ -0,0 +1 @@
+ Switch from simplejson to the standard library json.
diff --git a/changelog.d/7804.bugfix b/changelog.d/7804.bugfix
deleted file mode 100644
index 2772eeb0db..0000000000
--- a/changelog.d/7804.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix 'stuck invites' which happen when we are unable to reject a room invite received over federation.
diff --git a/changelog.d/7805.misc b/changelog.d/7805.misc
deleted file mode 100644
index cbae08774a..0000000000
--- a/changelog.d/7805.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add `signing_key` property to `HomeServer` to save code duplication.
diff --git a/changelog.d/7810.bugfix b/changelog.d/7810.bugfix
deleted file mode 100644
index 2772eeb0db..0000000000
--- a/changelog.d/7810.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix 'stuck invites' which happen when we are unable to reject a room invite received over federation.
diff --git a/changelog.d/7813.misc b/changelog.d/7813.misc
new file mode 100644
index 0000000000..f3005cfd27
--- /dev/null
+++ b/changelog.d/7813.misc
@@ -0,0 +1 @@
+Add type hints to the http server code and remove an unused parameter.
diff --git a/changelog.d/7815.bugfix b/changelog.d/7815.bugfix
new file mode 100644
index 0000000000..3e7c7d412e
--- /dev/null
+++ b/changelog.d/7815.bugfix
@@ -0,0 +1 @@
+Fix detection of out of sync remote device lists when receiving events from remote users.
diff --git a/changelog.d/7817.bugfix b/changelog.d/7817.bugfix
new file mode 100644
index 0000000000..1c001070d5
--- /dev/null
+++ b/changelog.d/7817.bugfix
@@ -0,0 +1 @@
+Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain.
diff --git a/changelog.d/7820.misc b/changelog.d/7820.misc
new file mode 100644
index 0000000000..b77b5672e3
--- /dev/null
+++ b/changelog.d/7820.misc
@@ -0,0 +1 @@
+Add type hints to synapse.api.errors module.
diff --git a/changelog.d/7822.bugfix b/changelog.d/7822.bugfix
new file mode 100644
index 0000000000..faf249a678
--- /dev/null
+++ b/changelog.d/7822.bugfix
@@ -0,0 +1 @@
+Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes. This bug was introduced in v1.16.0.
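The quoting bug fixed above is the usual YAML 1.1 pitfall: a bare `off` is parsed by YAML loaders as the boolean `false` rather than the string `"off"`. A minimal sketch of the affected setting:

```yaml
# Bare `off` is read as the boolean false, which pre-1.17.0 Synapse
# could misinterpret:
encryption_enabled_by_default_for_room_type: off

# Quoting the value makes it the literal string "off":
encryption_enabled_by_default_for_room_type: "off"
```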
diff --git a/changelog.d/7827.feature b/changelog.d/7827.feature
new file mode 100644
index 0000000000..0fd116e198
--- /dev/null
+++ b/changelog.d/7827.feature
@@ -0,0 +1 @@
+Add the option to validate the `iss` and `aud` claims for JWT logins.
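For context, a hedged sketch of the `jwt_config` section this feature extends (option names as documented in docs/jwt.md of this release; all values are placeholders):

```yaml
jwt_config:
  enabled: true
  secret: "shared-secret-from-your-issuer"  # placeholder
  algorithm: "HS256"
  # Added by this change: tokens whose claims do not match are rejected.
  issuer: "example-issuer"                  # validated against the "iss" claim
  audiences:                                # validated against the "aud" claim
    - "example-audience"
```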
diff --git a/changelog.d/7829.bugfix b/changelog.d/7829.bugfix
new file mode 100644
index 0000000000..dcbf385de6
--- /dev/null
+++ b/changelog.d/7829.bugfix
@@ -0,0 +1 @@
+Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails.
diff --git a/changelog.d/7830.feature b/changelog.d/7830.feature
new file mode 100644
index 0000000000..b4f614084d
--- /dev/null
+++ b/changelog.d/7830.feature
@@ -0,0 +1 @@
+Add support for handling registration requests across multiple client reader workers.
diff --git a/changelog.d/7836.misc b/changelog.d/7836.misc
new file mode 100644
index 0000000000..a3a97c7590
--- /dev/null
+++ b/changelog.d/7836.misc
@@ -0,0 +1 @@
+Ensure that calls to `json.dumps` are compatible with the standard library json.
diff --git a/changelog.d/7839.docker b/changelog.d/7839.docker
new file mode 100644
index 0000000000..cdf3c9631c
--- /dev/null
+++ b/changelog.d/7839.docker
@@ -0,0 +1 @@
+Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196.
diff --git a/changelog.d/7842.feature b/changelog.d/7842.feature
new file mode 100644
index 0000000000..727deb01c9
--- /dev/null
+++ b/changelog.d/7842.feature
@@ -0,0 +1 @@
+Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH.
diff --git a/changelog.d/7844.bugfix b/changelog.d/7844.bugfix
new file mode 100644
index 0000000000..ad296f1b3c
--- /dev/null
+++ b/changelog.d/7844.bugfix
@@ -0,0 +1 @@
+Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.
diff --git a/changelog.d/7846.feature b/changelog.d/7846.feature
new file mode 100644
index 0000000000..997376fe42
--- /dev/null
+++ b/changelog.d/7846.feature
@@ -0,0 +1 @@
+Allow email subjects to be customised through Synapse's configuration.
diff --git a/changelog.d/7847.feature b/changelog.d/7847.feature
new file mode 100644
index 0000000000..4b9a8d8569
--- /dev/null
+++ b/changelog.d/7847.feature
@@ -0,0 +1 @@
+Add the ability to re-activate an account from the admin API.
diff --git a/changelog.d/7848.misc b/changelog.d/7848.misc
new file mode 100644
index 0000000000..d9db1d8357
--- /dev/null
+++ b/changelog.d/7848.misc
@@ -0,0 +1 @@
+Remove redundant `retry_on_integrity_error` wrapper for event persistence code.
diff --git a/changelog.d/7849.misc b/changelog.d/7849.misc
new file mode 100644
index 0000000000..e3296418c1
--- /dev/null
+++ b/changelog.d/7849.misc
@@ -0,0 +1 @@
+Consistently use `db_to_json` to convert from database values to JSON objects.
diff --git a/changelog.d/7850.bugfix b/changelog.d/7850.bugfix
new file mode 100644
index 0000000000..5f19a89043
--- /dev/null
+++ b/changelog.d/7850.bugfix
@@ -0,0 +1 @@
+Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0.
diff --git a/changelog.d/7851.misc b/changelog.d/7851.misc
new file mode 100644
index 0000000000..e5cf540edf
--- /dev/null
+++ b/changelog.d/7851.misc
@@ -0,0 +1 @@
+Convert E2E keys and room keys handlers to async/await.
diff --git a/changelog.d/7853.misc b/changelog.d/7853.misc
new file mode 100644
index 0000000000..b4f614084d
--- /dev/null
+++ b/changelog.d/7853.misc
@@ -0,0 +1 @@
+Add support for handling registration requests across multiple client reader workers.
diff --git a/changelog.d/7854.bugfix b/changelog.d/7854.bugfix
new file mode 100644
index 0000000000..b11f9dedfe
--- /dev/null
+++ b/changelog.d/7854.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation.
diff --git a/changelog.d/7855.feature b/changelog.d/7855.feature
new file mode 100644
index 0000000000..2b6a9f0e71
--- /dev/null
+++ b/changelog.d/7855.feature
@@ -0,0 +1 @@
+Add experimental support for running multiple pusher workers.
diff --git a/changelog.d/7856.misc b/changelog.d/7856.misc
new file mode 100644
index 0000000000..7d99fb67be
--- /dev/null
+++ b/changelog.d/7856.misc
@@ -0,0 +1 @@
+Small performance improvement in typing processing.
diff --git a/changelog.d/7858.misc b/changelog.d/7858.misc
new file mode 100644
index 0000000000..8f0fc2de74
--- /dev/null
+++ b/changelog.d/7858.misc
@@ -0,0 +1 @@
+The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100.
diff --git a/changelog.d/7859.bugfix b/changelog.d/7859.bugfix
new file mode 100644
index 0000000000..19cff4b061
--- /dev/null
+++ b/changelog.d/7859.bugfix
@@ -0,0 +1 @@
+Fix a bug which allowed empty rooms to be rejoined over federation.
diff --git a/changelog.d/7860.misc b/changelog.d/7860.misc
new file mode 100644
index 0000000000..fdd48b955c
--- /dev/null
+++ b/changelog.d/7860.misc
@@ -0,0 +1 @@
+Convert _base, profile, and _receipts handlers to async/await.
diff --git a/changelog.d/7861.misc b/changelog.d/7861.misc
new file mode 100644
index 0000000000..ada616c62f
--- /dev/null
+++ b/changelog.d/7861.misc
@@ -0,0 +1 @@
+Optimise queueing of inbound replication commands.
diff --git a/changelog.d/7866.bugfix b/changelog.d/7866.bugfix
new file mode 100644
index 0000000000..6b5c3c4eca
--- /dev/null
+++ b/changelog.d/7866.bugfix
@@ -0,0 +1 @@
+Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers.
diff --git a/changelog.d/7868.misc b/changelog.d/7868.misc
new file mode 100644
index 0000000000..eadef5e4c2
--- /dev/null
+++ b/changelog.d/7868.misc
@@ -0,0 +1 @@
+Convert synapse.app and federation client to async/await.
diff --git a/changelog.d/7869.feature b/changelog.d/7869.feature
new file mode 100644
index 0000000000..1982049a52
--- /dev/null
+++ b/changelog.d/7869.feature
@@ -0,0 +1 @@
+Add experimental support for moving typing off master.
diff --git a/changelog.d/7870.misc b/changelog.d/7870.misc
new file mode 100644
index 0000000000..27cce2f2f9
--- /dev/null
+++ b/changelog.d/7870.misc
@@ -0,0 +1 @@
+Add some type annotations to `HomeServer` and `BaseHandler`.
diff --git a/changelog.d/7871.misc b/changelog.d/7871.misc
new file mode 100644
index 0000000000..4d398a9f3a
--- /dev/null
+++ b/changelog.d/7871.misc
@@ -0,0 +1 @@
+Convert device handler to async/await.
diff --git a/changelog.d/7872.bugfix b/changelog.d/7872.bugfix
new file mode 100644
index 0000000000..b21f8e1f14
--- /dev/null
+++ b/changelog.d/7872.bugfix
@@ -0,0 +1 @@
+Fix a long standing bug where the tracing of async functions with opentracing was broken.
diff --git a/changelog.d/7874.misc b/changelog.d/7874.misc
new file mode 100644
index 0000000000..f75c8d1843
--- /dev/null
+++ b/changelog.d/7874.misc
@@ -0,0 +1 @@
+Convert the federation agent and related code to async/await.
diff --git a/changelog.d/7877.misc b/changelog.d/7877.misc
new file mode 100644
index 0000000000..a62aa0329c
--- /dev/null
+++ b/changelog.d/7877.misc
@@ -0,0 +1 @@
+Clean up `PreserveLoggingContext`.
diff --git a/changelog.d/7878.removal b/changelog.d/7878.removal
new file mode 100644
index 0000000000..d5a4066624
--- /dev/null
+++ b/changelog.d/7878.removal
@@ -0,0 +1 @@
+Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric.
diff --git a/changelog.d/7879.feature b/changelog.d/7879.feature
new file mode 100644
index 0000000000..c89655f000
--- /dev/null
+++ b/changelog.d/7879.feature
@@ -0,0 +1 @@
+Report CPU metrics to prometheus for time spent processing replication commands.
diff --git a/changelog.d/7880.bugfix b/changelog.d/7880.bugfix
new file mode 100644
index 0000000000..356add0996
--- /dev/null
+++ b/changelog.d/7880.bugfix
@@ -0,0 +1 @@
+Fix "TypeError in `synapse.notifier`" exceptions.
diff --git a/changelog.d/7881.misc b/changelog.d/7881.misc
new file mode 100644
index 0000000000..6799117099
--- /dev/null
+++ b/changelog.d/7881.misc
@@ -0,0 +1 @@
+Change "unknown room version" logging from 'error' to 'warning'.
diff --git a/changelog.d/7882.misc b/changelog.d/7882.misc
new file mode 100644
index 0000000000..9002749335
--- /dev/null
+++ b/changelog.d/7882.misc
@@ -0,0 +1 @@
+Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`.
diff --git a/changelog.d/7884.misc b/changelog.d/7884.misc
new file mode 100644
index 0000000000..36c7d4de67
--- /dev/null
+++ b/changelog.d/7884.misc
@@ -0,0 +1 @@
+Convert the message handler to async/await.
diff --git a/changelog.d/7885.doc b/changelog.d/7885.doc
new file mode 100644
index 0000000000..cbe9de4082
--- /dev/null
+++ b/changelog.d/7885.doc
@@ -0,0 +1 @@
+Provide instructions on using `register_new_matrix_user` via docker.
diff --git a/changelog.d/7888.misc b/changelog.d/7888.misc
new file mode 100644
index 0000000000..5328d2dcca
--- /dev/null
+++ b/changelog.d/7888.misc
@@ -0,0 +1 @@
+Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim.
diff --git a/changelog.d/7889.doc b/changelog.d/7889.doc
new file mode 100644
index 0000000000..d91f62fd39
--- /dev/null
+++ b/changelog.d/7889.doc
@@ -0,0 +1 @@
+Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation.
\ No newline at end of file
diff --git a/changelog.d/7890.misc b/changelog.d/7890.misc
new file mode 100644
index 0000000000..8c127084bc
--- /dev/null
+++ b/changelog.d/7890.misc
@@ -0,0 +1 @@
+Fix typo in generated config file. Contributed by @ThiefMaster.
diff --git a/changelog.d/7892.misc b/changelog.d/7892.misc
new file mode 100644
index 0000000000..ef4cfa04fd
--- /dev/null
+++ b/changelog.d/7892.misc
@@ -0,0 +1 @@
+Import ABC from `collections.abc` for Python 3.10 compatibility.
diff --git a/changelog.d/7895.bugfix b/changelog.d/7895.bugfix
new file mode 100644
index 0000000000..1ae7f8ca7c
--- /dev/null
+++ b/changelog.d/7895.bugfix
@@ -0,0 +1 @@
+Fix deprecation warning due to invalid escape sequences.
\ No newline at end of file
diff --git a/changelog.d/7897.misc b/changelog.d/7897.misc
new file mode 100644
index 0000000000..77772533fd
--- /dev/null
+++ b/changelog.d/7897.misc
@@ -0,0 +1,2 @@
+Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
+and `get_previous_frame` from `synapse.logging.utils` module.
\ No newline at end of file
diff --git a/changelog.d/7908.feature b/changelog.d/7908.feature
new file mode 100644
index 0000000000..4b9a8d8569
--- /dev/null
+++ b/changelog.d/7908.feature
@@ -0,0 +1 @@
+Add the ability to re-activate an account from the admin API.
diff --git a/changelog.d/7912.misc b/changelog.d/7912.misc
new file mode 100644
index 0000000000..d619590070
--- /dev/null
+++ b/changelog.d/7912.misc
@@ -0,0 +1 @@
+Convert `RoomListHandler` to async/await.
diff --git a/changelog.d/7914.misc b/changelog.d/7914.misc
new file mode 100644
index 0000000000..710553249c
--- /dev/null
+++ b/changelog.d/7914.misc
@@ -0,0 +1 @@
+Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI.
diff --git a/changelog.d/7919.misc b/changelog.d/7919.misc
new file mode 100644
index 0000000000..addaa35183
--- /dev/null
+++ b/changelog.d/7919.misc
@@ -0,0 +1 @@
+Use Element CSS and logo in notification emails when app name is Element.
diff --git a/changelog.d/7927.misc b/changelog.d/7927.misc
new file mode 100644
index 0000000000..3b864da03d
--- /dev/null
+++ b/changelog.d/7927.misc
@@ -0,0 +1 @@
+Optimisation to /sync handling: skip serializing the response if the client has already disconnected.
diff --git a/changelog.d/7928.misc b/changelog.d/7928.misc
new file mode 100644
index 0000000000..5f3aa5de0a
--- /dev/null
+++ b/changelog.d/7928.misc
@@ -0,0 +1 @@
+When a client disconnects, don't log it as 'Error processing request'.
diff --git a/changelog.d/7929.misc b/changelog.d/7929.misc
new file mode 100644
index 0000000000..d72856fe03
--- /dev/null
+++ b/changelog.d/7929.misc
@@ -0,0 +1 @@
+Add debugging to `/sync` response generation (disabled by default).
diff --git a/changelog.d/7930.feature b/changelog.d/7930.feature
new file mode 100644
index 0000000000..a27e4812da
--- /dev/null
+++ b/changelog.d/7930.feature
@@ -0,0 +1 @@
+Abort federation requests where the client disconnects before the ratelimiter expires.
diff --git a/changelog.d/7931.feature b/changelog.d/7931.feature
new file mode 100644
index 0000000000..30eb33048b
--- /dev/null
+++ b/changelog.d/7931.feature
@@ -0,0 +1 @@
+Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work.
diff --git a/changelog.d/7933.doc b/changelog.d/7933.doc
new file mode 100644
index 0000000000..7022fd578b
--- /dev/null
+++ b/changelog.d/7933.doc
@@ -0,0 +1 @@
+Reorder database paragraphs to promote postgres over sqlite.
diff --git a/changelog.d/7934.doc b/changelog.d/7934.doc
new file mode 100644
index 0000000000..992d5358a7
--- /dev/null
+++ b/changelog.d/7934.doc
@@ -0,0 +1 @@
+Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md).
diff --git a/changelog.d/7935.misc b/changelog.d/7935.misc
new file mode 100644
index 0000000000..3771f99bf2
--- /dev/null
+++ b/changelog.d/7935.misc
@@ -0,0 +1 @@
+Convert the auth providers to be async/await.
diff --git a/changelog.d/7939.misc b/changelog.d/7939.misc
new file mode 100644
index 0000000000..798833b3af
--- /dev/null
+++ b/changelog.d/7939.misc
@@ -0,0 +1 @@
+Convert presence handler helpers to async/await.
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index 48da410d94..77422f5e5d 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -17,9 +17,6 @@
 """ Starts a synapse client console. """
 from __future__ import print_function
 
-from twisted.internet import reactor, defer, threads
-from http import TwistedHttpClient
-
 import argparse
 import cmd
 import getpass
@@ -28,12 +25,14 @@ import shlex
 import sys
 import time
 import urllib
-import urlparse
+from http import TwistedHttpClient
 
-import nacl.signing
 import nacl.encoding
+import nacl.signing
+import urlparse
+from signedjson.sign import SignatureVerifyException, verify_signed_json
 
-from signedjson.sign import verify_signed_json, SignatureVerifyException
+from twisted.internet import defer, reactor, threads
 
 CONFIG_JSON = "cmdclient_config.json"
@@ -493,7 +492,7 @@
         "list messages <roomid> from=END&to=START&limit=3"
         """
         args = self._parse(line, ["type", "roomid", "qp"])
-        if not "type" in args or not "roomid" in args:
+        if "type" not in args or "roomid" not in args:
             print("Must specify type and room ID.")
             return
         if args["type"] not in ["members", "messages"]:
@@ -508,7 +507,7 @@
             try:
                 key_value = key_value_str.split("=")
                 qp[key_value[0]] = key_value[1]
-            except:
+            except Exception:
                 print("Bad query param: %s" % key_value)
                 return
 
@@ -585,7 +584,7 @@
             parsed_url = urlparse.urlparse(args["path"])
             qp.update(urlparse.parse_qs(parsed_url.query))
             args["path"] = parsed_url.path
-        except:
+        except Exception:
             pass
 
         reactor.callFromThread(
@@ -772,10 +771,10 @@ def main(server_url, identity_server_url, username, token, config_path):
             syn_cmd.config = json.load(config)
             try:
                 http_client.verbose = "on" == syn_cmd.config["verbose"]
-            except:
+            except Exception:
                 pass
             print("Loaded config from %s" % config_path)
-        except:
+        except Exception:
             pass
 
     # Twisted-specific: Runs the command processor in Twisted's event loop
diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py
index 0e101d2be5..e2534ee584 100644
--- a/contrib/cmdclient/http.py
+++ b/contrib/cmdclient/http.py
@@ -14,14 +14,14 @@
 # limitations under the License.
 
 from __future__ import print_function
 
-from twisted.web.client import Agent, readBody
-from twisted.web.http_headers import Headers
-from twisted.internet import defer, reactor
-
-from pprint import pformat
 import json
 import urllib
+from pprint import pformat
+
+from twisted.internet import defer, reactor
+from twisted.web.client import Agent, readBody
+from twisted.web.http_headers import Headers
 
 
 class HttpClient(object):
index 3bbbcfa1b4..a84ec4ecae 100644
--- a/contrib/experiments/test_messaging.py
+++ b/contrib/experiments/test_messaging.py
@@ -28,27 +28,24 @@ Currently assumes the local address is localhost:<port>
 
 """
 
-from synapse.federation import ReplicationHandler
-
-from synapse.federation.units import Pdu
-
-from synapse.util import origin_from_ucid
-
-from synapse.app.homeserver import SynapseHomeServer
-
-# from synapse.logging.utils import log_function
-
-from twisted.internet import reactor, defer
-from twisted.python import log
-
 import argparse
+import curses.wrapper
 import json
 import logging
 import os
 import re
 
 import cursesio
-import curses.wrapper
+
+from twisted.internet import defer, reactor
+from twisted.python import log
+
+from synapse.app.homeserver import SynapseHomeServer
+from synapse.federation import ReplicationHandler
+from synapse.federation.units import Pdu
+from synapse.util import origin_from_ucid
+
+# from synapse.logging.utils import log_function
 
 
 logger = logging.getLogger("example")
@@ -75,7 +72,7 @@
         """
 
         try:
-            m = re.match("^join (\S+)$", line)
+            m = re.match(r"^join (\S+)$", line)
             if m:
                 # The `sender` wants to join a room.
                 (room_name,) = m.groups()
@@ -84,7 +81,7 @@
                 # self.print_line("OK.")
                 return
 
-            m = re.match("^invite (\S+) (\S+)$", line)
+            m = re.match(r"^invite (\S+) (\S+)$", line)
             if m:
                 # `sender` wants to invite someone to a room
                 room_name, invitee = m.groups()
@@ -93,7 +90,7 @@
                 # self.print_line("OK.")
                 return
 
-            m = re.match("^send (\S+) (.*)$", line)
+            m = re.match(r"^send (\S+) (.*)$", line)
             if m:
                 # `sender` wants to message a room
                 room_name, body = m.groups()
@@ -102,7 +99,7 @@
                 # self.print_line("OK.")
                 return
 
-            m = re.match("^backfill (\S+)$", line)
+            m = re.match(r"^backfill (\S+)$", line)
             if m:
                 # we want to backfill a room
                 (room_name,) = m.groups()
@@ -201,16 +198,6 @@
             % (pdu.context, pdu.pdu_type, json.dumps(pdu.content))
         )
 
-    # def on_state_change(self, pdu):
-    ##self.output.print_line("#%s (state) %s *** %s" %
-    ##(pdu.context, pdu.state_key, pdu.pdu_type)
-    ##)
-
-    # if "joinee" in pdu.content:
-    # self._on_join(pdu.context, pdu.content["joinee"])
-    # elif "invitee" in pdu.content:
-    # self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"])
-
     def _on_message(self, pdu):
         """ We received a message
         """
@@ -314,7 +301,7 @@
         return self.replication_layer.backfill(dest, room_name, limit)
 
     def _get_room_remote_servers(self, room_name):
-        return [i for i in self.joined_rooms.setdefault(room_name).servers]
+        return list(self.joined_rooms.setdefault(room_name).servers)
 
     def _get_or_create_room(self, room_name):
         return self.joined_rooms.setdefault(room_name, Room(room_name))
@@ -334,7 +321,7 @@ def main(stdscr):
     user = args.user
     server_name = origin_from_ucid(user)
 
-    ## Set up logging ##
+    # Set up logging
 
     root_logger = logging.getLogger()
 
@@ -354,7 +341,7 @@ def main(stdscr):
     observer = log.PythonLoggingObserver()
     observer.start()
 
-    ## Set up synapse server
+    # Set up synapse server
 
     curses_stdio = cursesio.CursesStdIO(stdscr)
     input_output = InputOutput(curses_stdio, user)
@@ -368,16 +355,16 @@ def main(stdscr):
 
     input_output.set_home_server(hs)
 
-    ## Add input_output logger
+    # Add input_output logger
     io_logger = IOLoggerHandler(input_output)
     io_logger.setFormatter(formatter)
     root_logger.addHandler(io_logger)
 
-    ## Start! ##
+    # Start!
 
     try:
         port = int(server_name.split(":")[1])
-    except:
+    except Exception:
        port = 12345
     app_hs.get_http_server().start_listening(port)
 
diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json
index 30a8681f5a..539569b5b1 100644
--- a/contrib/grafana/synapse.json
+++ b/contrib/grafana/synapse.json
@@ -1,7 +1,44 @@
 {
+  "__inputs": [
+    {
+      "name": "DS_PROMETHEUS",
+      "label": "Prometheus",
+      "description": "",
+      "type": "datasource",
+      "pluginId": "prometheus",
+      "pluginName": "Prometheus"
+    }
+  ],
+  "__requires": [
+    {
+      "type": "grafana",
+      "id": "grafana",
+      "name": "Grafana",
+      "version": "6.7.4"
+    },
+    {
+      "type": "panel",
+      "id": "graph",
+      "name": "Graph",
+      "version": ""
+    },
+    {
+      "type": "panel",
+      "id": "heatmap",
+      "name": "Heatmap",
+      "version": ""
+    },
+    {
+      "type": "datasource",
+      "id": "prometheus",
+      "name": "Prometheus",
+      "version": "1.0.0"
+    }
+  ],
   "annotations": {
     "list": [
       {
+        "$$hashKey": "object:76",
         "builtIn": 1,
         "datasource": "$datasource",
         "enable": false,
@@ -17,8 +54,8 @@
   "editable": true,
   "gnetId": null,
   "graphTooltip": 0,
-  "id": 1,
-  "iteration": 1591098104645,
+  "id": null,
+  "iteration": 1594646317221,
   "links": [
     {
       "asDropdown": true,
@@ -34,7 +71,7 @@
   "panels": [
    {
      "collapsed": false,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -269,7 +306,6 @@
        "show": false
      },
      "links": [],
-     "options": {},
      "reverseYBuckets": false,
      "targets": [
        {
@@ -559,7 +595,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -1423,7 +1459,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -1795,7 +1831,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -2531,7 +2567,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -2823,7 +2859,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -2844,7 +2880,7 @@
        "h": 9,
        "w": 12,
        "x": 0,
-       "y": 33
+       "y": 6
      },
      "hiddenSeries": false,
      "id": 79,
@@ -2940,7 +2976,7 @@
        "h": 9,
        "w": 12,
        "x": 12,
-       "y": 33
+       "y": 6
      },
      "hiddenSeries": false,
      "id": 83,
@@ -3038,7 +3074,7 @@
        "h": 9,
        "w": 12,
        "x": 0,
-       "y": 42
+       "y": 15
      },
      "hiddenSeries": false,
      "id": 109,
@@ -3137,7 +3173,7 @@
        "h": 9,
        "w": 12,
        "x": 12,
-       "y": 42
+       "y": 15
      },
      "hiddenSeries": false,
      "id": 111,
@@ -3223,14 +3259,14 @@
      "dashLength": 10,
      "dashes": false,
      "datasource": "$datasource",
-     "description": "",
+     "description": "Number of events queued up on the master process for processing by the federation sender",
      "fill": 1,
      "fillGradient": 0,
      "gridPos": {
        "h": 9,
        "w": 12,
        "x": 0,
-       "y": 51
+       "y": 24
      },
      "hiddenSeries": false,
      "id": 140,
@@ -3354,6 +3390,103 @@
        "align": false,
        "alignLevel": null
      }
+   },
+   {
+     "aliasColors": {},
+     "bars": false,
+     "dashLength": 10,
+     "dashes": false,
+     "datasource": "${DS_PROMETHEUS}",
+     "description": "The number of events in the in-memory queues ",
+     "fill": 1,
+     "fillGradient": 0,
+     "gridPos": {
+       "h": 8,
+       "w": 12,
+       "x": 12,
+       "y": 24
+     },
+     "hiddenSeries": false,
+     "id": 142,
+     "legend": {
+       "avg": false,
+       "current": false,
+       "max": false,
+       "min": false,
+       "show": true,
+       "total": false,
+       "values": false
+     },
+     "lines": true,
+     "linewidth": 1,
+     "nullPointMode": "null",
+     "options": {
+       "dataLinks": []
+     },
+     "percentage": false,
+     "pointradius": 2,
+     "points": false,
+     "renderer": "flot",
+     "seriesOverrides": [],
+     "spaceLength": 10,
+     "stack": false,
+     "steppedLine": false,
+     "targets": [
+       {
+         "expr": "synapse_federation_transaction_queue_pending_pdus{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+         "interval": "",
+         "legendFormat": "pending PDUs {{job}}-{{index}}",
+         "refId": "A"
+       },
+       {
+         "expr": "synapse_federation_transaction_queue_pending_edus{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+         "interval": "",
+         "legendFormat": "pending EDUs {{job}}-{{index}}",
+         "refId": "B"
+       }
+     ],
+     "thresholds": [],
+     "timeFrom": null,
+     "timeRegions": [],
+     "timeShift": null,
+     "title": "In-memory federation transmission queues",
+     "tooltip": {
+       "shared": true,
+       "sort": 0,
+       "value_type": "individual"
+     },
+     "type": "graph",
+     "xaxis": {
+       "buckets": null,
+       "mode": "time",
+       "name": null,
+       "show": true,
+       "values": []
+     },
+     "yaxes": [
+       {
+         "$$hashKey": "object:317",
+         "format": "short",
+         "label": "events",
+         "logBase": 1,
+         "max": null,
+         "min": "0",
+         "show": true
+       },
+       {
+         "$$hashKey": "object:318",
+         "format": "short",
+         "label": "",
+         "logBase": 1,
+         "max": null,
+         "min": null,
+         "show": true
+       }
+     ],
+     "yaxis": {
+       "align": false,
+       "alignLevel": null
+     }
    }
  ],
  "title": "Federation",
@@ -3361,7 +3494,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -3567,7 +3700,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -3588,7 +3721,7 @@
        "h": 7,
        "w": 12,
        "x": 0,
-       "y": 52
+       "y": 79
      },
      "hiddenSeries": false,
      "id": 48,
@@ -3682,7 +3815,7 @@
        "h": 7,
        "w": 12,
        "x": 12,
-       "y": 52
+       "y": 79
      },
      "hiddenSeries": false,
      "id": 104,
@@ -3802,7 +3935,7 @@
        "h": 7,
        "w": 12,
        "x": 0,
-       "y": 59
+       "y": 86
      },
      "hiddenSeries": false,
      "id": 10,
@@ -3898,7 +4031,7 @@
        "h": 7,
        "w": 12,
        "x": 12,
-       "y": 59
+       "y": 86
      },
      "hiddenSeries": false,
      "id": 11,
@@ -3987,7 +4120,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -4011,7 +4144,7 @@
        "h": 13,
        "w": 12,
        "x": 0,
-       "y": 67
+       "y": 80
      },
      "hiddenSeries": false,
      "id": 12,
@@ -4106,7 +4239,7 @@
        "h": 13,
        "w": 12,
        "x": 12,
-       "y": 67
+       "y": 80
      },
      "hiddenSeries": false,
      "id": 26,
@@ -4201,7 +4334,7 @@
        "h": 13,
        "w": 12,
        "x": 0,
-       "y": 80
+       "y": 93
      },
      "hiddenSeries": false,
      "id": 13,
@@ -4297,7 +4430,7 @@
        "h": 13,
        "w": 12,
        "x": 12,
-       "y": 80
+       "y": 93
      },
      "hiddenSeries": false,
      "id": 27,
@@ -4392,7 +4525,7 @@
        "h": 13,
        "w": 12,
        "x": 0,
-       "y": 93
+       "y": 106
      },
      "hiddenSeries": false,
      "id": 28,
@@ -4486,7 +4619,7 @@
        "h": 13,
        "w": 12,
        "x": 12,
-       "y": 93
+       "y": 106
      },
      "hiddenSeries": false,
      "id": 25,
@@ -4572,7 +4705,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -5062,7 +5195,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -5083,7 +5216,7 @@
        "h": 9,
        "w": 12,
        "x": 0,
-       "y": 66
+       "y": 121
      },
      "hiddenSeries": false,
      "id": 91,
@@ -5179,7 +5312,7 @@
        "h": 9,
        "w": 12,
        "x": 12,
-       "y": 66
+       "y": 121
      },
      "hiddenSeries": false,
      "id": 21,
@@ -5271,7 +5404,7 @@
        "h": 9,
        "w": 12,
        "x": 0,
-       "y": 75
+       "y": 130
      },
      "hiddenSeries": false,
      "id": 89,
@@ -5369,7 +5502,7 @@
        "h": 9,
        "w": 12,
        "x": 12,
-       "y": 75
+       "y": 130
      },
      "hiddenSeries": false,
      "id": 93,
@@ -5459,7 +5592,7 @@
        "h": 9,
        "w": 12,
        "x": 0,
-       "y": 84
+       "y": 139
      },
      "hiddenSeries": false,
      "id": 95,
@@ -5552,12 +5685,12 @@
        "mode": "spectrum"
      },
      "dataFormat": "tsbuckets",
-     "datasource": "Prometheus",
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 9,
        "w": 12,
        "x": 12,
-       "y": 84
+       "y": 139
      },
      "heatmap": {},
      "hideZeroBuckets": true,
@@ -5567,7 +5700,6 @@
        "show": true
      },
      "links": [],
-     "options": {},
      "reverseYBuckets": false,
      "targets": [
        {
@@ -5609,7 +5741,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -5630,7 +5762,7 @@
        "h": 7,
        "w": 12,
        "x": 0,
-       "y": 39
+       "y": 66
      },
      "hiddenSeries": false,
      "id": 2,
@@ -5754,7 +5886,7 @@
        "h": 7,
        "w": 12,
        "x": 12,
-       "y": 39
+       "y": 66
      },
      "hiddenSeries": false,
      "id": 41,
@@ -5847,7 +5979,7 @@
        "h": 7,
        "w": 12,
        "x": 0,
-       "y": 46
+       "y": 73
      },
      "hiddenSeries": false,
      "id": 42,
@@ -5939,7 +6071,7 @@
        "h": 7,
        "w": 12,
        "x": 12,
-       "y": 46
+       "y": 73
      },
      "hiddenSeries": false,
      "id": 43,
@@ -6031,7 +6163,7 @@
        "h": 7,
        "w": 12,
        "x": 0,
-       "y": 53
+       "y": 80
      },
      "hiddenSeries": false,
      "id": 113,
@@ -6129,7 +6261,7 @@
        "h": 7,
        "w": 12,
        "x": 12,
-       "y": 53
+       "y": 80
      },
      "hiddenSeries": false,
      "id": 115,
@@ -6215,7 +6347,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -6236,7 +6368,7 @@
        "h": 9,
        "w": 12,
        "x": 0,
-       "y": 58
+       "y": 40
      },
      "hiddenSeries": false,
      "id": 67,
@@ -6267,7 +6399,7 @@
      "steppedLine": false,
      "targets": [
        {
-         "expr": " synapse_event_persisted_position{instance=\"$instance\",job=\"synapse\"} - ignoring(index, job, name) group_right() synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+         "expr": "max(synapse_event_persisted_position{instance=\"$instance\"}) - ignoring(instance,index, job, name) group_right() synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
          "format": "time_series",
          "interval": "",
          "intervalFactor": 1,
@@ -6328,7 +6460,7 @@
        "h": 9,
        "w": 12,
        "x": 12,
-       "y": 58
+       "y": 40
      },
      "hiddenSeries": false,
      "id": 71,
@@ -6362,6 +6494,7 @@
          "expr": "time()*1000-synapse_event_processing_last_ts{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
          "format": "time_series",
          "hide": false,
+         "interval": "",
          "intervalFactor": 1,
          "legendFormat": "{{job}}-{{index}} {{name}}",
          "refId": "B"
@@ -6420,7 +6553,7 @@
        "h": 9,
        "w": 12,
        "x": 0,
-       "y": 67
+       "y": 49
      },
      "hiddenSeries": false,
      "id": 121,
@@ -6509,7 +6642,7 @@
    },
    {
      "collapsed": true,
-     "datasource": null,
+     "datasource": "${DS_PROMETHEUS}",
      "gridPos": {
        "h": 1,
        "w": 24,
@@ -6539,7 +6672,7 @@
        "h": 8,
        "w": 12,
        "x": 0,
-       "y": 41
+       "y": 86
      },
      "heatmap": {},
      "hideZeroBuckets": true,
@@ -6549,7 +6682,6 @@
        "show": true
      },
      "links": [],
-     "options": {},
      "reverseYBuckets": false,
      "targets": [
        {
@@ -6599,7 +6731,7 @@
        "h": 8,
        "w": 12,
        "x": 12,
-       "y": 41
+       "y": 86
      },
      "hiddenSeries": false,
      "id": 124,
@@ -6700,7 +6832,7 @@
        "h": 8,
        "w": 12,
        "x": 0,
-       "y": 49
+       "y": 94
      },
      "heatmap": {},
      "hideZeroBuckets": true,
@@ -6710,7 +6842,6 @@
        "show": true
      },
      "links": [],
-     "options": {},
      "reverseYBuckets": false,
      "targets": [
        {
@@ -6760,7 +6891,7 @@
        "h": 8,
        "w": 12,
        "x": 12,
-       "y": 49
+       "y": 94
      },
      "hiddenSeries": false,
      "id": 128,
@@ -6879,7 +7010,7 @@
        "h": 8,
        "w": 12,
        "x": 0,
-       "y": 57
+       "y": 102
      },
      "heatmap": {},
      "hideZeroBuckets": true,
@@ -6889,7 +7020,6 @@
        "show": true
      },
      "links": [],
-     "options": {},
      "reverseYBuckets": false,
      "targets": [
        {
@@ -6939,7 +7069,7 @@
        "h": 8,
        "w": 12,
        "x": 12,
-       "y": 57
+       "y": 102
      },
      "hiddenSeries": false,
      "id": 130,
@@ -7058,7 +7188,7 @@
        "h": 8,
        "w": 12,
        "x": 0,
-       "y": 65
+       "y": 110
      },
      "heatmap": {},
      "hideZeroBuckets": true,
@@ -7068,12 +7198,12 @@
        "show": true
      },
      "links": [],
-     "options": {},
      "reverseYBuckets": false,
      "targets": [
        {
-         "expr": "rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0)",
+         "expr": "rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
          "format": "heatmap",
+         "interval": "",
          "intervalFactor": 1,
          "legendFormat": "{{le}}",
          "refId": "A"
@@ -7118,7 +7248,7 @@
        "h": 8,
        "w": 12,
        "x": 12,
-       "y": 65
+       "y": 110
      },
      "hiddenSeries": false,
      "id": 132,
@@ -7149,29 +7279,33 @@
      "steppedLine": false,
      "targets": [
        {
-         "expr": "histogram_quantile(0.5, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0)) ",
+         "expr": "histogram_quantile(0.5, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))",
          "format": "time_series",
+         "interval": "",
          "intervalFactor": 1,
          "legendFormat": "50%",
          "refId": "A"
        },
        {
-         "expr": "histogram_quantile(0.75, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))",
+         "expr": "histogram_quantile(0.75, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))",
          "format": "time_series",
+         "interval": "",
          "intervalFactor": 1,
          "legendFormat": "75%",
          "refId": "B"
        },
        {
-         "expr": "histogram_quantile(0.90, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))",
+         "expr": "histogram_quantile(0.90, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))",
          "format": "time_series",
+         "interval": "",
          "intervalFactor": 1,
          "legendFormat": "90%",
          "refId": "C"
        },
        {
-         "expr": "histogram_quantile(0.99, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))",
+         "expr": "histogram_quantile(0.99, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))",
          "format": "time_series",
+         "interval": "",
          "intervalFactor": 1,
          "legendFormat": "99%",
          "refId": "D"
@@ -7181,7 +7315,7 @@
      "timeFrom": null,
      "timeRegions": [],
      "timeShift": null,
-     "title": "Number of state resolution performed, by number of state groups involved (quantiles)",
+     "title": "Number of state resolutions performed, by number of state groups involved (quantiles)",
      "tooltip": {
        "shared": true,
        "sort": 0,
@@ -7233,6 +7367,7 @@
    "list": [
      {
        "current": {
+         "selected": false,
          "text": "Prometheus",
          "value": "Prometheus"
        },
@@ -7309,14 +7444,12 @@
    },
    {
      "allValue": null,
-     "current": {
-       "text": "matrix.org",
-       "value": "matrix.org"
-     },
+     "current": {},
      "datasource": "$datasource",
      "definition": "",
      "hide": 0,
      "includeAll": false,
+     "index": -1,
      "label": null,
      "multi": false,
      "name": "instance",
@@ -7335,17 +7468,13 @@
    {
      "allFormat": "regex wildcard",
      "allValue": "",
-     "current": {
-       "text": "synapse",
-       "value": [
-         "synapse"
-       ]
-     },
+     "current": {},
      "datasource": "$datasource",
      "definition": "",
      "hide": 0,
      "hideLabel": false,
      "includeAll": true,
+     "index": -1,
      "label": "Job",
      "multi": true,
      "multiFormat": "regex values",
@@ -7366,16 +7495,13 @@
    {
      "allFormat": "regex wildcard",
      "allValue": ".*",
-     "current": {
-       "selected": false,
-       "text": "All",
-       "value": "$__all"
-
}, + "current": {}, "datasource": "$datasource", "definition": "", "hide": 0, "hideLabel": false, "includeAll": true, + "index": -1, "label": "", "multi": true, "multiFormat": "regex values", @@ -7428,5 +7554,8 @@ "timezone": "", "title": "Synapse", "uid": "000000012", - "version": 29 + "variables": { + "list": [] + }, + "version": 32 } \ No newline at end of file diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py
index 92736480eb..de33fac1c7 100644 --- a/contrib/graph/graph.py +++ b/contrib/graph/graph.py
@@ -1,5 +1,13 @@ from __future__ import print_function +import argparse +import cgi +import datetime +import json + +import pydot +import urllib2 + # Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,15 +23,6 @@ from __future__ import print_function # limitations under the License. -import sqlite3 -import pydot -import cgi -import json -import datetime -import argparse -import urllib2 - - def make_name(pdu_id, origin): return "%s@%s" % (pdu_id, origin) @@ -33,7 +32,7 @@ def make_graph(pdus, room, filename_prefix): node_map = {} origins = set() - colors = set(("red", "green", "blue", "yellow", "purple")) + colors = {"red", "green", "blue", "yellow", "purple"} for pdu in pdus: origins.add(pdu.get("origin")) @@ -49,7 +48,7 @@ def make_graph(pdus, room, filename_prefix): try: c = colors.pop() color_map[o] = c - except: + except Exception: print("Run out of colours!") color_map[o] = "black" diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py
index 4619f0e3c1..0980231e4a 100644 --- a/contrib/graph/graph2.py +++ b/contrib/graph/graph2.py
@@ -13,12 +13,13 @@ # limitations under the License. -import sqlite3 -import pydot +import argparse import cgi -import json import datetime -import argparse +import json +import sqlite3 + +import pydot from synapse.events import FrozenEvent from synapse.util.frozenutils import unfreeze @@ -98,7 +99,7 @@ def make_graph(db_name, room_id, file_prefix, limit): for prev_id, _ in event.prev_events: try: end_node = node_map[prev_id] - except: + except Exception: end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,)) node_map[prev_id] = end_node diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py
index 3154638520..91db98e7ef 100644 --- a/contrib/graph/graph3.py +++ b/contrib/graph/graph3.py
@@ -1,5 +1,15 @@ from __future__ import print_function +import argparse +import cgi +import datetime + +import pydot +import simplejson as json + +from synapse.events import FrozenEvent +from synapse.util.frozenutils import unfreeze + # Copyright 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,16 +25,6 @@ from __future__ import print_function # limitations under the License. -import pydot -import cgi -import simplejson as json -import datetime -import argparse - -from synapse.events import FrozenEvent -from synapse.util.frozenutils import unfreeze - - def make_graph(file_name, room_id, file_prefix, limit): print("Reading lines") with open(file_name) as f: @@ -106,7 +106,7 @@ def make_graph(file_name, room_id, file_prefix, limit): for prev_id, _ in event.prev_events: try: end_node = node_map[prev_id] - except: + except Exception: end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,)) node_map[prev_id] = end_node diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py
index 67fb2cd1a7..69aa74bd34 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py
@@ -12,15 +12,15 @@ npm install jquery jsdom """ from __future__ import print_function -import gevent -import grequests -from BeautifulSoup import BeautifulSoup import json -import urllib import subprocess import time -# ACCESS_TOKEN="" # +import gevent +import grequests +from BeautifulSoup import BeautifulSoup + +ACCESS_TOKEN = "" MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/" MYUSERNAME = "@davetest:matrix.org" diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py
index f57e6e7d25..372dbd9e4f 100755 --- a/contrib/scripts/kick_users.py +++ b/contrib/scripts/kick_users.py
@@ -1,10 +1,12 @@ #!/usr/bin/env python from __future__ import print_function -from argparse import ArgumentParser + import json -import requests import sys import urllib +from argparse import ArgumentParser + +import requests try: raw_input diff --git a/debian/changelog b/debian/changelog
index 1e7d7191ad..3825603ae4 100644 --- a/debian/changelog +++ b/debian/changelog
@@ -1,3 +1,21 @@ +matrix-synapse-py3 (1.17.0) stable; urgency=medium + + * New synapse release 1.17.0. + + -- Synapse Packaging team <packages@matrix.org> Mon, 13 Jul 2020 10:20:31 +0100 + +matrix-synapse-py3 (1.16.1) stable; urgency=medium + + * New synapse release 1.16.1. + + -- Synapse Packaging team <packages@matrix.org> Fri, 10 Jul 2020 12:09:24 +0100 + +matrix-synapse-py3 (1.17.0rc1) stable; urgency=medium + + * New synapse release 1.17.0rc1. + + -- Synapse Packaging team <packages@matrix.org> Thu, 09 Jul 2020 16:53:12 +0100 + matrix-synapse-py3 (1.16.0) stable; urgency=medium * New synapse release 1.16.0. diff --git a/docker/Dockerfile b/docker/Dockerfile
index 093e89af6c..8b3a4246a5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile
@@ -16,35 +16,31 @@ ARG PYTHON_VERSION=3.7 ### ### Stage 0: builder ### -FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 as builder +FROM docker.io/python:${PYTHON_VERSION}-slim as builder # install the OS build deps -RUN apk add \ - build-base \ - libffi-dev \ - libjpeg-turbo-dev \ - libwebp-dev \ - libressl-dev \ - libxslt-dev \ - linux-headers \ - postgresql-dev \ - zlib-dev -# build things which have slow build steps, before we copy synapse, so that -# the layer can be cached. -# -# (we really just care about caching a wheel here, as the "pip install" below -# will install them again.) +RUN apt-get update && apt-get install -y \ + build-essential \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* +# Build dependencies that are not available as wheels, to speed up rebuilds RUN pip install --prefix="/install" --no-warn-script-location \ - cryptography \ - msgpack-python \ - pillow \ - pynacl + frozendict \ + jaeger-client \ + opentracing \ + prometheus-client \ + psycopg2 \ + pycparser \ + pyrsistent \ + pyyaml \ + simplejson \ + threadloop \ + thrift # now install synapse and all of the python deps to /install. - COPY synapse /synapse/synapse/ COPY scripts /synapse/scripts/ COPY MANIFEST.in README.rst setup.py synctl /synapse/ @@ -56,20 +52,13 @@ RUN pip install --prefix="/install" --no-warn-script-location \ ### Stage 1: runtime ### -FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 +FROM docker.io/python:${PYTHON_VERSION}-slim -# xmlsec is required for saml support -RUN apk add --no-cache --virtual .runtime_deps \ - libffi \ - libjpeg-turbo \ - libwebp \ - libressl \ - libxslt \ - libpq \ - zlib \ - su-exec \ - tzdata \ - xmlsec +RUN apt-get update && apt-get install -y \ + libpq5 \ + xmlsec1 \ + gosu \ + && rm -rf /var/lib/apt/lists/* COPY --from=builder /install /usr/local COPY ./docker/start.py /start.py diff --git a/docker/README.md b/docker/README.md
index 8c337149ca..008a9ff708 100644 --- a/docker/README.md +++ b/docker/README.md
@@ -94,6 +94,21 @@ The following environment variables are supported in run mode:
 * `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`.
 * `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`.
 
+## Generating an (admin) user
+
+After Synapse is running, you may wish to create a user via `register_new_matrix_user`.
+
+This requires a `registration_shared_secret` to be set in your config file. Synapse
+must be restarted to pick up this change.
+
+You can then call the script:
+
+```
+docker exec -it synapse register_new_matrix_user http://localhost:8008 -c /data/homeserver.yaml --help
+```
+
+Remember to remove the `registration_shared_secret` and restart if you no longer need it.
+
 ## TLS support
 
 The default configuration exposes a single HTTP port: http://localhost:8008. It
diff --git a/docker/start.py b/docker/start.py
index 2a25c9380e..9f08134158 100755 --- a/docker/start.py +++ b/docker/start.py
@@ -120,7 +120,7 @@ def generate_config_from_template(config_dir, config_path, environ, ownership): if ownership is not None: subprocess.check_output(["chown", "-R", ownership, "/data"]) - args = ["su-exec", ownership] + args + args = ["gosu", ownership] + args subprocess.check_output(args) @@ -172,8 +172,8 @@ def run_generate_config(environ, ownership): # make sure that synapse has perms to write to the data dir. subprocess.check_output(["chown", ownership, data_dir]) - args = ["su-exec", ownership] + args - os.execv("/sbin/su-exec", args) + args = ["gosu", ownership] + args + os.execv("/usr/sbin/gosu", args) else: os.execv("/usr/local/bin/python", args) @@ -189,7 +189,7 @@ def main(args, environ): ownership = "{}:{}".format(desired_uid, desired_gid) if ownership is None: - log("Will not perform chmod/su-exec as UserID already matches request") + log("Will not perform chmod/gosu as UserID already matches request") # In generate mode, generate a configuration and missing keys, then exit if mode == "generate": @@ -236,8 +236,8 @@ running with 'migrate_config'. See the README for more details. args = ["python", "-m", synapse_worker, "--config-path", config_path] if ownership is not None: - args = ["su-exec", ownership] + args - os.execv("/sbin/su-exec", args) + args = ["gosu", ownership] + args + os.execv("/usr/sbin/gosu", args) else: os.execv("/usr/local/bin/python", args) diff --git a/docs/ACME.md b/docs/ACME.md
index f4c4740476..a7a498f575 100644 --- a/docs/ACME.md +++ b/docs/ACME.md
@@ -12,13 +12,14 @@ introduced support for automatically provisioning certificates through In [March 2019](https://community.letsencrypt.org/t/end-of-life-plan-for-acmev1/88430), Let's Encrypt announced that they were deprecating version 1 of the ACME protocol, with the plan to disable the use of it for new accounts in -November 2019, and for existing accounts in June 2020. +November 2019, for new domains in June 2020, and for existing accounts and +domains in June 2021. Synapse doesn't currently support version 2 of the ACME protocol, which means that: * for existing installs, Synapse's built-in ACME support will continue - to work until June 2020. + to work until June 2021. * for new installs, this feature will not work at all. Either way, it is recommended to move from Synapse's ACME support diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md
index 64ea7b6a64..ae01a543c6 100644 --- a/docs/admin_api/purge_room.md +++ b/docs/admin_api/purge_room.md
@@ -5,6 +5,8 @@ This API will remove all trace of a room from your database. All local users must have left the room before it can be removed. +See also: [Delete Room API](rooms.md#delete-room-api) + The API is: ``` diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 624e7745ba..15b83e9824 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md
@@ -318,3 +318,129 @@ Response:
   "state_events": 93534
 }
 ```
+
+# Room Members API
+
+The Room Members admin API allows server admins to get a list of all members of a room.
+
+The response includes the following fields:
+
+* `members` - A list of all the members that are present in the room, represented by their IDs.
+* `total` - Total number of members in the room.
+
+## Usage
+
+A standard request:
+
+```
+GET /_synapse/admin/v1/rooms/<room_id>/members
+
+{}
+```
+
+Response:
+
+```
+{
+  "members": [
+    "@foo:matrix.org",
+    "@bar:matrix.org",
+    "@foobar:matrix.org"
+  ],
+  "total": 3
+}
+```
+
+# Delete Room API
+
+The Delete Room admin API allows server admins to remove rooms from the server
+and block these rooms.
+It is a combination and improvement of the "[Shutdown room](shutdown_room.md)"
+and "[Purge room](purge_room.md)" APIs.
+
+Shuts down a room. Moves all local users and room aliases automatically to a
+new room if `new_room_user_id` is set. Otherwise local users simply
+leave the room without further information.
+
+The new room will be created with the user specified by the `new_room_user_id` parameter
+as room administrator and will contain a message explaining what happened. Users invited
+to the new room will have power level `-10` by default, and thus be unable to speak.
+
+If `block` is `true`, it prevents new joins to the old room.
+
+This API will remove all trace of the old room from your database after removing
+all local users.
+Depending on the amount of history being purged, a call to the API may take
+several minutes or longer.
+
+The local server will only have the power to move local users and room aliases to
+the new room. Users on other servers will be unaffected.
+
+The API is:
+
+```
+POST /_synapse/admin/v1/rooms/<room_id>/delete
+```
+
+with a body of:
+```json
+{
+    "new_room_user_id": "@someuser:example.com",
+    "room_name": "Content Violation Notification",
+    "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.",
+    "block": true
+}
+```
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see [README.rst](README.rst).
+
+A response body like the following is returned:
+
+```json
+{
+    "kicked_users": [
+        "@foobar:example.com"
+    ],
+    "failed_to_kick_users": [],
+    "local_aliases": [
+        "#badroom:example.com",
+        "#evilsaloon:example.com"
+    ],
+    "new_room_id": "!newroomid:example.com"
+}
+```
+
+## Parameters
+
+The following parameters should be set in the URL:
+
+* `room_id` - The ID of the room.
+
+The following JSON body parameters are available:
+
+* `new_room_user_id` - Optional. If set, a new room will be created with this user ID
+  as the creator and admin, and all users in the old room will be moved into that
+  room. If not set, no new room will be created and the users will just be removed
+  from the old room. The user ID must be on the local server, but does not necessarily
+  have to belong to a registered user.
+* `room_name` - Optional. A string representing the name of the room that new users will be
+  invited to. Defaults to `Content Violation Notification`.
+* `message` - Optional. A string containing the first message that will be sent as
+  `new_room_user_id` in the new room. Ideally this will clearly convey why the
+  original room was shut down. Defaults to `Sharing illegal content on this server
+  is not permitted and rooms in violation will be blocked.`
+* `block` - Optional. If set to `true`, this room will be added to a blocking list,
+  preventing future attempts to join the room. Defaults to `false`.
+
+The JSON body must not be empty; at a minimum it must be `{}`.
+
+## Response
+
+The following fields are returned in the JSON response body:
+
+* `kicked_users` - An array of users (`user_id`) that were kicked.
+* `failed_to_kick_users` - An array of users (`user_id`) that were not kicked.
+* `local_aliases` - An array of strings representing the local aliases that were migrated from
+  the old room to the new room.
+* `new_room_id` - A string representing the room ID of the new room.
diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md
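As a quick illustration of driving the Delete Room admin API documented above, here is a minimal Python sketch using the `requests` library. The homeserver URL, room ID, token and message values are placeholders, not values from the patch:

```python
import requests

HOMESERVER = "https://matrix.example.com"  # placeholder: your homeserver's base URL
ROOM_ID = "!badroom:example.com"           # placeholder: the room to shut down
ADMIN_TOKEN = "<admin_access_token>"       # placeholder: a server admin's access token

# POST /_synapse/admin/v1/rooms/<room_id>/delete with the JSON body described
# above; "block": True also prevents future attempts to join the old room.
resp = requests.post(
    "%s/_synapse/admin/v1/rooms/%s/delete" % (HOMESERVER, ROOM_ID),
    headers={"Authorization": "Bearer %s" % ADMIN_TOKEN},
    json={
        "new_room_user_id": "@abuse:example.com",
        "room_name": "Content Violation Notification",
        "message": "This room has been shut down for violating our Terms of Service.",
        "block": True,
    },
)
resp.raise_for_status()
print(resp.json()["new_room_id"])  # e.g. "!newroomid:example.com"
```

Note that, as the documentation above warns, the call may take several minutes to return while history is purged.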
index 54ce1cd234..808caeec79 100644 --- a/docs/admin_api/shutdown_room.md +++ b/docs/admin_api/shutdown_room.md
@@ -10,6 +10,8 @@ disallow any further invites or joins. The local server will only have the power to move local user and room aliases to the new room. Users on other servers will be unaffected. +See also: [Delete Room API](rooms.md#delete-room-api) + ## API You will need to authenticate with an access token for an admin user. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 7b030a6285..be05128b3e 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst
@@ -91,10 +91,14 @@ Body parameters:
 
 - ``admin``, optional, defaults to ``false``.
 
-- ``deactivated``, optional, defaults to ``false``.
+- ``deactivated``, optional. If unspecified, deactivation state will be left
+  unchanged on existing accounts and set to ``false`` for new accounts.
 
 If the user already exists then optional parameters default to the current value.
 
+In order to re-activate an account, ``deactivated`` must be set to ``false``. If
+users do not log in via single sign-on, a new ``password`` must be provided.
+
 List Accounts
 =============
 
diff --git a/docs/jwt.md b/docs/jwt.md
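To make the re-activation behaviour described above concrete, here is a minimal Python sketch using `requests`, assuming the usual `PUT /_synapse/admin/v2/users/<user_id>` endpoint that this file documents; the server URL, user ID, token and password are all placeholders:

```python
import requests

HOMESERVER = "https://matrix.example.com"  # placeholder
USER_ID = "@alice:example.com"             # placeholder: the account to re-activate
ADMIN_TOKEN = "<admin_access_token>"       # placeholder

# Setting "deactivated" back to false re-activates the account; per the doc
# change above, a new password is required unless the user logs in via SSO.
resp = requests.put(
    "%s/_synapse/admin/v2/users/%s" % (HOMESERVER, USER_ID),
    headers={"Authorization": "Bearer %s" % ADMIN_TOKEN},
    json={"deactivated": False, "password": "a-new-password"},
)
resp.raise_for_status()
```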
index 289d66b365..5be9fd26e3 100644 --- a/docs/jwt.md +++ b/docs/jwt.md
@@ -20,12 +20,18 @@ follows:
 Note that the login type of `m.login.jwt` is supported, but is deprecated. This
 will be removed in a future version of Synapse.
 
-The `jwt` should encode the local part of the user ID as the standard `sub`
-claim. In the case that the token is not valid, the homeserver must respond with
-`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
+The `token` field should include the JSON web token with the following claims:
 
-(Note that this differs from the token based logins which return a
-`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
+* The `sub` (subject) claim is required and should encode the local part of the
+  user ID.
+* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
+  claims are optional, but validated if present.
+* The issuer (`iss`) claim is optional, but required and validated if configured.
+* The audience (`aud`) claim is optional, but required and validated if configured.
+  Providing the audience claim when not configured will cause validation to fail.
+
+In the case that the token is not valid, the homeserver must respond with
+`403 Forbidden` and an error code of `M_FORBIDDEN`.
 
 As with other login types, there are additional fields (e.g. `device_id` and
 `initial_device_display_name`) which can be included in the above request.
@@ -55,7 +61,8 @@ sample settings.
 Although JSON Web Tokens are typically generated from an external server, the
 examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
 
-1. Configure Synapse with JWT logins:
+1. Configure Synapse with JWT logins; note that this example uses a pre-shared
+   secret and an algorithm of HS256:
 
    ```yaml
    jwt_config:
diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md
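To make the claim-validation rules above concrete, here is a minimal sketch of a JWT login using PyJWT and `requests`. It assumes a pre-shared secret with HS256 (as in the example config above) and the `org.matrix.login.jwt` login type, i.e. the non-deprecated counterpart of `m.login.jwt`; the secret, homeserver URL and localpart are placeholders:

```python
import time

import jwt  # PyJWT
import requests

SECRET = "my-shared-secret"                # placeholder: must match jwt_config's secret
HOMESERVER = "https://matrix.example.com"  # placeholder

# "sub" carries the localpart; "iat"/"exp" are optional but are validated
# by the homeserver if present.
now = int(time.time())
token = jwt.encode(
    {"sub": "alice", "iat": now, "exp": now + 300}, SECRET, algorithm="HS256"
)
if isinstance(token, bytes):  # PyJWT < 2.0 returns bytes rather than str
    token = token.decode("ascii")

resp = requests.post(
    "%s/_matrix/client/r0/login" % HOMESERVER,
    json={"type": "org.matrix.login.jwt", "token": token},
)
resp.raise_for_status()  # an invalid token yields 403 / M_FORBIDDEN, per the above
print(resp.json()["access_token"])
```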
index 5d9ae67041..fef1d47e85 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md
@@ -19,102 +19,103 @@ password auth provider module implementations:
 
 Password auth provider classes must provide the following methods:
 
-*class* `SomeProvider.parse_config`(*config*)
+* `parse_config(config)`
+
+  This method is passed the `config` object for this module from the
+  homeserver configuration file.
 
-> This method is passed the `config` object for this module from the
-> homeserver configuration file.
->
-> It should perform any appropriate sanity checks on the provided
-> configuration, and return an object which is then passed into
-> `__init__`.
+  It should perform any appropriate sanity checks on the provided
+  configuration, and return an object which is then passed into
+  `__init__`.
 
-*class* `SomeProvider`(*config*, *account_handler*)
+  This method should have the `@staticmethod` decorator.
 
-> The constructor is passed the config object returned by
-> `parse_config`, and a `synapse.module_api.ModuleApi` object which
-> allows the password provider to check if accounts exist and/or create
-> new ones.
+* `__init__(self, config, account_handler)`
+
+  The constructor is passed the config object returned by
+  `parse_config`, and a `synapse.module_api.ModuleApi` object which
+  allows the password provider to check if accounts exist and/or create
+  new ones.
 
 ## Optional methods
 
-Password auth provider classes may optionally provide the following
-methods.
-
-*class* `SomeProvider.get_db_schema_files`()
-
-> This method, if implemented, should return an Iterable of
-> `(name, stream)` pairs of database schema files. Each file is applied
-> in turn at initialisation, and a record is then made in the database
-> so that it is not re-applied on the next start.
-
-`someprovider.get_supported_login_types`()
-
-> This method, if implemented, should return a `dict` mapping from a
-> login type identifier (such as `m.login.password`) to an iterable
-> giving the fields which must be provided by the user in the submission
-> to the `/login` api. These fields are passed in the `login_dict`
-> dictionary to `check_auth`.
->
-> For example, if a password auth provider wants to implement a custom
-> login type of `com.example.custom_login`, where the client is expected
-> to pass the fields `secret1` and `secret2`, the provider should
-> implement this method and return the following dict:
->
->     {"com.example.custom_login": ("secret1", "secret2")}
-
-`someprovider.check_auth`(*username*, *login_type*, *login_dict*)
-
-> This method is the one that does the real work. If implemented, it
-> will be called for each login attempt where the login type matches one
-> of the keys returned by `get_supported_login_types`.
->
-> It is passed the (possibly UNqualified) `user` provided by the client,
-> the login type, and a dictionary of login secrets passed by the
-> client.
->
-> The method should return a Twisted `Deferred` object, which resolves
-> to the canonical `@localpart:domain` user id if authentication is
-> successful, and `None` if not.
->
-> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in
-> which case the second field is a callback which will be called with
-> the result from the `/login` call (including `access_token`,
-> `device_id`, etc.)
-
-`someprovider.check_3pid_auth`(*medium*, *address*, *password*)
-
-> This method, if implemented, is called when a user attempts to
-> register or log in with a third party identifier, such as email. It is
-> passed the medium (ex. "email"), an address (ex.
-> "<jdoe@example.com>") and the user's password.
-> -> The method should return a Twisted `Deferred` object, which resolves -> to a `str` containing the user's (canonical) User ID if -> authentication was successful, and `None` if not. -> -> As with `check_auth`, the `Deferred` may alternatively resolve to a -> `(user_id, callback)` tuple. - -`someprovider.check_password`(*user_id*, *password*) - -> This method provides a simpler interface than -> `get_supported_login_types` and `check_auth` for password auth -> providers that just want to provide a mechanism for validating -> `m.login.password` logins. -> -> Iif implemented, it will be called to check logins with an -> `m.login.password` login type. It is passed a qualified -> `@localpart:domain` user id, and the password provided by the user. -> -> The method should return a Twisted `Deferred` object, which resolves -> to `True` if authentication is successful, and `False` if not. - -`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*) - -> This method, if implemented, is called when a user logs out. It is -> passed the qualified user ID, the ID of the deactivated device (if -> any: access tokens are occasionally created without an associated -> device ID), and the (now deactivated) access token. -> -> It may return a Twisted `Deferred` object; the logout request will -> wait for the deferred to complete but the result is ignored. +Password auth provider classes may optionally provide the following methods: + +* `get_db_schema_files(self)` + + This method, if implemented, should return an Iterable of + `(name, stream)` pairs of database schema files. Each file is applied + in turn at initialisation, and a record is then made in the database + so that it is not re-applied on the next start. + +* `get_supported_login_types(self)` + + This method, if implemented, should return a `dict` mapping from a + login type identifier (such as `m.login.password`) to an iterable + giving the fields which must be provided by the user in the submission + to [the `/login` API](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login). + These fields are passed in the `login_dict` dictionary to `check_auth`. + + For example, if a password auth provider wants to implement a custom + login type of `com.example.custom_login`, where the client is expected + to pass the fields `secret1` and `secret2`, the provider should + implement this method and return the following dict: + + ```python + {"com.example.custom_login": ("secret1", "secret2")} + ``` + +* `check_auth(self, username, login_type, login_dict)` + + This method does the real work. If implemented, it + will be called for each login attempt where the login type matches one + of the keys returned by `get_supported_login_types`. + + It is passed the (possibly unqualified) `user` field provided by the client, + the login type, and a dictionary of login secrets passed by the + client. + + The method should return an `Awaitable` object, which resolves + to the canonical `@localpart:domain` user ID if authentication is + successful, and `None` if not. + + Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in + which case the second field is a callback which will be called with + the result from the `/login` call (including `access_token`, + `device_id`, etc.) + +* `check_3pid_auth(self, medium, address, password)` + + This method, if implemented, is called when a user attempts to + register or log in with a third party identifier, such as email. It is + passed the medium (ex. "email"), an address (ex. 
+ "<jdoe@example.com>") and the user's password. + + The method should return an `Awaitable` object, which resolves + to a `str` containing the user's (canonical) User id if + authentication was successful, and `None` if not. + + As with `check_auth`, the `Awaitable` may alternatively resolve to a + `(user_id, callback)` tuple. + +* `check_password(self, user_id, password)` + + This method provides a simpler interface than + `get_supported_login_types` and `check_auth` for password auth + providers that just want to provide a mechanism for validating + `m.login.password` logins. + + If implemented, it will be called to check logins with an + `m.login.password` login type. It is passed a qualified + `@localpart:domain` user id, and the password provided by the user. + + The method should return an `Awaitable` object, which resolves + to `True` if authentication is successful, and `False` if not. + +* `on_logged_out(self, user_id, device_id, access_token)` + + This method, if implemented, is called when a user logs out. It is + passed the qualified user ID, the ID of the deactivated device (if + any: access tokens are occasionally created without an associated + device ID), and the (now deactivated) access token. + + It may return an `Awaitable` object; the logout request will + wait for the `Awaitable` to complete, but the result is ignored. diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 131990001a..7bfb96eff6 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md
@@ -38,6 +38,11 @@ the reverse proxy and the homeserver. server { listen 443 ssl; listen [::]:443 ssl; + + # For the federation port + listen 8448 ssl default_server; + listen [::]:8448 ssl default_server; + server_name matrix.example.com; location /_matrix { @@ -48,17 +53,6 @@ server { client_max_body_size 10M; } } - -server { - listen 8448 ssl default_server; - listen [::]:8448 ssl default_server; - server_name example.com; - - location / { - proxy_pass http://localhost:8008; - proxy_set_header X-Forwarded-For $remote_addr; - } -} ``` **NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 164a104045..3227294e0b 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml
@@ -102,7 +102,9 @@ pid_file: DATADIR/homeserver.pid #gc_thresholds: [700, 10, 10] # Set the limit on the returned events in the timeline in the get -# and sync operations. The default value is -1, means no upper limit. +# and sync operations. The default value is 100. -1 means no upper limit. +# +# Uncomment the following to increase the limit to 5000. # #filter_timeline_limit: 5000 @@ -118,38 +120,6 @@ pid_file: DATADIR/homeserver.pid # #enable_search: false -# Restrict federation to the following whitelist of domains. -# N.B. we recommend also firewalling your federation listener to limit -# inbound federation traffic as early as possible, rather than relying -# purely on this application-layer restriction. If not specified, the -# default is to whitelist everything. -# -#federation_domain_whitelist: -# - lon.example.com -# - nyc.example.com -# - syd.example.com - -# Prevent federation requests from being sent to the following -# blacklist IP address CIDR ranges. If this option is not specified, or -# specified with an empty list, no ip range blacklist will be enforced. -# -# As of Synapse v1.4.0 this option also affects any outbound requests to identity -# servers provided by user input. -# -# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly -# listed here, since they correspond to unroutable addresses.) -# -federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -178,7 +148,7 @@ federation_ip_range_blacklist: # names: a list of names of HTTP resources. See below for a list of # valid resource names. # -# compress: set to true to enable HTTP comression for this resource. +# compress: set to true to enable HTTP compression for this resource. # # additional_resources: Only valid for an 'http' listener. A map of # additional endpoints which should be loaded via dynamic modules. @@ -608,6 +578,39 @@ acme: +# Restrict federation to the following whitelist of domains. +# N.B. we recommend also firewalling your federation listener to limit +# inbound federation traffic as early as possible, rather than relying +# purely on this application-layer restriction. If not specified, the +# default is to whitelist everything. +# +#federation_domain_whitelist: +# - lon.example.com +# - nyc.example.com +# - syd.example.com + +# Prevent federation requests from being sent to the following +# blacklist IP address CIDR ranges. If this option is not specified, or +# specified with an empty list, no ip range blacklist will be enforced. +# +# As of Synapse v1.4.0 this option also affects any outbound requests to identity +# servers provided by user input. +# +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + + ## Caching ## # Caching can be configured through the following options. 
@@ -682,7 +685,7 @@ caches: #database: # name: psycopg2 # args: -# user: synapse +# user: synapse_user # password: secretpassword # database: synapse # host: localhost @@ -1811,6 +1814,9 @@ sso: # Each JSON Web Token needs to contain a "sub" (subject) claim, which is # used as the localpart of the mxid. # +# Additionally, the expiration time ("exp"), not before time ("nbf"), +# and issued at ("iat") claims are validated if present. +# # Note that this is a non-standard login type and client support is # expected to be non-existant. # @@ -1838,6 +1844,24 @@ sso: # #algorithm: "provided-by-your-issuer" + # The issuer to validate the "iss" claim against. + # + # Optional, if provided the "iss" claim will be required and + # validated for all JSON web tokens. + # + #issuer: "provided-by-your-issuer" + + # A list of audiences to validate the "aud" claim against. + # + # Optional, if provided the "aud" claim will be required and + # validated for all JSON web tokens. + # + # Note that if the "aud" claim is included in a JSON web token then + # validation will fail without configuring audiences. + # + #audiences: + # - "provided-by-your-issuer" + password_config: # Uncomment to disable password login @@ -1927,8 +1951,8 @@ email: # #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" - # app_name defines the default value for '%(app)s' in notif_from. It - # defaults to 'Matrix'. + # app_name defines the default value for '%(app)s' in notif_from and email + # subjects. It defaults to 'Matrix'. # #app_name: my_branded_matrix_server @@ -1997,6 +2021,73 @@ email: # #template_dir: "res/templates" + # Subjects to use when sending emails from Synapse. + # + # The placeholder '%(app)s' will be replaced with the value of the 'app_name' + # setting above, or by a value dictated by the Matrix client application. + # + # If a subject isn't overridden in this configuration file, the value used as + # its example will be used. + # + #subjects: + + # Subjects for notification emails. + # + # On top of the '%(app)s' placeholder, these can use the following + # placeholders: + # + # * '%(person)s', which will be replaced by the display name of the user(s) + # that sent the message(s), e.g. "Alice and Bob". + # * '%(room)s', which will be replaced by the name of the room the + # message(s) have been sent to, e.g. "My super room". + # + # See the example provided for each setting to see which placeholder can be + # used and how to use them. + # + # Subject to use to notify about one message from one or more user(s) in a + # room which has a name. + #message_from_person_in_room: "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room..." + # + # Subject to use to notify about one message from one or more user(s) in a + # room which doesn't have a name. + #message_from_person: "[%(app)s] You have a message on %(app)s from %(person)s..." + # + # Subject to use to notify about multiple messages from one or more users in + # a room which doesn't have a name. + #messages_from_person: "[%(app)s] You have messages on %(app)s from %(person)s..." + # + # Subject to use to notify about multiple messages in a room which has a + # name. + #messages_in_room: "[%(app)s] You have messages on %(app)s in the %(room)s room..." + # + # Subject to use to notify about multiple messages in multiple rooms. + #messages_in_room_and_others: "[%(app)s] You have messages on %(app)s in the %(room)s room and others..." 
+  #
+  # Subject to use to notify about multiple messages from multiple persons in
+  # multiple rooms. This is similar to the setting above except it's used when
+  # the room in which the notification was triggered has no name.
+  #messages_from_person_and_others: "[%(app)s] You have messages on %(app)s from %(person)s and others..."
+  #
+  # Subject to use to notify about an invite to a room which has a name.
+  #invite_from_person_to_room: "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s..."
+  #
+  # Subject to use to notify about an invite to a room which doesn't have a
+  # name.
+  #invite_from_person: "[%(app)s] %(person)s has invited you to chat on %(app)s..."
+
+  # Subjects for emails related to account administration.
+  #
+  # On top of the '%(app)s' placeholder, these can use the
+  # '%(server_name)s' placeholder, which will be replaced by the value of the
+  # 'server_name' setting in your Synapse configuration.
+  #
+  # Subject to use when sending a password reset email.
+  #password_reset: "[%(server_name)s] Password reset"
+  #
+  # Subject to use when sending a verification email to assert an address's
+  # ownership.
+  #email_validation: "[%(server_name)s] Validate your email"
+
 
 # Password providers allow homeserver administrators to integrate
 # their Synapse installation with existing authentication methods
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index e6f4bd1dca..d055cf3287 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages
@@ -24,7 +24,6 @@ DISTS = ( "debian:sid", "ubuntu:xenial", "ubuntu:bionic", - "ubuntu:eoan", "ubuntu:focal", ) diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 66b0568858..0647993658 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh
@@ -11,7 +11,7 @@ if [ $# -ge 1 ] then files=$* else - files="synapse tests scripts-dev scripts" + files="synapse tests scripts-dev scripts contrib synctl" fi echo "Linting these locations: $files" diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 2eb795192f..22a6abd7d2 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db
@@ -48,6 +48,7 @@ from synapse.storage.data_stores.main.media_repository import ( ) from synapse.storage.data_stores.main.registration import ( RegistrationBackgroundUpdateStore, + find_max_generated_user_id_localpart, ) from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore @@ -622,8 +623,10 @@ class Porter(object): ) ) - # Step 5. Do final post-processing + # Step 5. Set up sequences + self.progress.set_state("Setting up sequence generators") await self._setup_state_group_id_seq() + await self._setup_user_id_seq() self.progress.done() except Exception as e: @@ -793,6 +796,13 @@ class Porter(object): return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) + def _setup_user_id_seq(self): + def r(txn): + next_id = find_max_generated_user_id_localpart(txn) + 1 + txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) + + return self.postgres_store.db.runInteraction("setup_user_id_seq", r) + ############################################## # The following is simply UI stuff diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index cac689d4f3..c66413f003 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi
@@ -22,6 +22,7 @@ class RedisProtocol: def publish(self, channel: str, message: bytes): ... class SubscriberProtocol: + def __init__(self, *args, **kwargs): ... password: Optional[str] def subscribe(self, channels: Union[str, List[str]]): ... def connectionMade(self): ... diff --git a/synapse/__init__.py b/synapse/__init__.py
index de65ce6db8..8592dee179 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try: except ImportError: pass -__version__ = "1.16.0" +__version__ = "1.17.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index cb22508f4d..40dc62ef6c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py
@@ -537,7 +537,7 @@ class Auth(object): # Currently we ignore the `for_verification` flag even though there are # some situations where we can drop particular auth events when adding # to the event's `auth_events` (e.g. joins pointing to previous joins - # when room is publically joinable). Dropping event IDs has the + # when room is publicly joinable). Dropping event IDs has the # advantage that the auth chain for the room grows slower, but we use # the auth chain in state resolution v2 to order events, which means # care must be taken if dropping events to ensure that it doesn't diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 5305038c21..b3bab1aa52 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -17,13 +17,17 @@ """Contains exceptions and error codes.""" import logging +import typing from http import HTTPStatus -from typing import Dict, List +from typing import Dict, List, Optional, Union from canonicaljson import json from twisted.web import http +if typing.TYPE_CHECKING: + from synapse.types import JsonDict + logger = logging.getLogger(__name__) @@ -78,11 +82,11 @@ class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes. Attributes: - code (int): HTTP error code - msg (str): string describing the error + code: HTTP error code + msg: string describing the error """ - def __init__(self, code, msg): + def __init__(self, code: Union[int, HTTPStatus], msg: str): super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) # Some calls to this method pass instances of http.HTTPStatus for `code`. @@ -123,16 +127,16 @@ class SynapseError(CodeMessageException): message (as well as an HTTP status code). Attributes: - errcode (str): Matrix error code e.g 'M_FORBIDDEN' + errcode: Matrix error code e.g 'M_FORBIDDEN' """ - def __init__(self, code, msg, errcode=Codes.UNKNOWN): + def __init__(self, code: int, msg: str, errcode: str = Codes.UNKNOWN): """Constructs a synapse error. Args: - code (int): The integer error code (an HTTP response code) - msg (str): The human-readable error message. - errcode (str): The matrix error code e.g 'M_FORBIDDEN' + code: The integer error code (an HTTP response code) + msg: The human-readable error message. + errcode: The matrix error code e.g 'M_FORBIDDEN' """ super(SynapseError, self).__init__(code, msg) self.errcode = errcode @@ -145,10 +149,16 @@ class ProxiedRequestError(SynapseError): """An error from a general matrix endpoint, eg. from a proxied Matrix API call. Attributes: - errcode (str): Matrix error code e.g 'M_FORBIDDEN' + errcode: Matrix error code e.g 'M_FORBIDDEN' """ - def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): + def __init__( + self, + code: int, + msg: str, + errcode: str = Codes.UNKNOWN, + additional_fields: Optional[Dict] = None, + ): super(ProxiedRequestError, self).__init__(code, msg, errcode) if additional_fields is None: self._additional_fields = {} # type: Dict @@ -164,12 +174,12 @@ class ConsentNotGivenError(SynapseError): privacy policy. """ - def __init__(self, msg, consent_uri): + def __init__(self, msg: str, consent_uri: str): """Constructs a ConsentNotGivenError Args: - msg (str): The human-readable error message - consent_url (str): The URL where the user can give their consent + msg: The human-readable error message + consent_url: The URL where the user can give their consent """ super(ConsentNotGivenError, self).__init__( code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN @@ -185,11 +195,11 @@ class UserDeactivatedError(SynapseError): authenticated endpoint, but the account has been deactivated. """ - def __init__(self, msg): + def __init__(self, msg: str): """Constructs a UserDeactivatedError Args: - msg (str): The human-readable error message + msg: The human-readable error message """ super(UserDeactivatedError, self).__init__( code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED @@ -201,16 +211,16 @@ class FederationDeniedError(SynapseError): is not on its federation whitelist. 
Attributes: - destination (str): The destination which has been denied + destination: The destination which has been denied """ - def __init__(self, destination): + def __init__(self, destination: Optional[str]): """Raised by federation client or server to indicate that we are are deliberately not attempting to contact a given server because it is not on our federation whitelist. Args: - destination (str): the domain in question + destination: the domain in question """ self.destination = destination @@ -228,11 +238,11 @@ class InteractiveAuthIncompleteError(Exception): (This indicates we should return a 401 with 'result' as the body) Attributes: - result (dict): the server response to the request, which should be + result: the server response to the request, which should be passed back to the client """ - def __init__(self, result): + def __init__(self, result: "JsonDict"): super(InteractiveAuthIncompleteError, self).__init__( "Interactive auth not yet complete" ) @@ -245,7 +255,6 @@ class UnrecognizedRequestError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.UNRECOGNIZED - message = None if len(args) == 0: message = "Unrecognized request" else: @@ -256,7 +265,7 @@ class UnrecognizedRequestError(SynapseError): class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" - def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND): + def __init__(self, msg: str = "Not found", errcode: str = Codes.NOT_FOUND): super(NotFoundError, self).__init__(404, msg, errcode=errcode) @@ -282,21 +291,23 @@ class InvalidClientCredentialsError(SynapseError): M_UNKNOWN_TOKEN respectively. """ - def __init__(self, msg, errcode): + def __init__(self, msg: str, errcode: str): super().__init__(code=401, msg=msg, errcode=errcode) class MissingClientTokenError(InvalidClientCredentialsError): """Raised when we couldn't find the access token in a request""" - def __init__(self, msg="Missing access token"): + def __init__(self, msg: str = "Missing access token"): super().__init__(msg=msg, errcode="M_MISSING_TOKEN") class InvalidClientTokenError(InvalidClientCredentialsError): """Raised when we didn't understand the access token in a request""" - def __init__(self, msg="Unrecognised access token", soft_logout=False): + def __init__( + self, msg: str = "Unrecognised access token", soft_logout: bool = False + ): super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") self._soft_logout = soft_logout @@ -314,11 +325,11 @@ class ResourceLimitError(SynapseError): def __init__( self, - code, - msg, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - admin_contact=None, - limit_type=None, + code: int, + msg: str, + errcode: str = Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact: Optional[str] = None, + limit_type: Optional[str] = None, ): self.admin_contact = admin_contact self.limit_type = limit_type @@ -366,10 +377,10 @@ class StoreError(SynapseError): class InvalidCaptchaError(SynapseError): def __init__( self, - code=400, - msg="Invalid captcha.", - error_url=None, - errcode=Codes.CAPTCHA_INVALID, + code: int = 400, + msg: str = "Invalid captcha.", + error_url: Optional[str] = None, + errcode: str = Codes.CAPTCHA_INVALID, ): super(InvalidCaptchaError, self).__init__(code, msg, errcode) self.error_url = error_url @@ -384,10 +395,10 @@ class LimitExceededError(SynapseError): def __init__( self, - code=429, - msg="Too Many Requests", - retry_after_ms=None, - errcode=Codes.LIMIT_EXCEEDED, + code: int = 429, + msg: str = "Too Many Requests", 
+ retry_after_ms: Optional[int] = None, + errcode: str = Codes.LIMIT_EXCEEDED, ): super(LimitExceededError, self).__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms @@ -400,10 +411,10 @@ class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store """ - def __init__(self, current_version): + def __init__(self, current_version: str): """ Args: - current_version (str): the current version of the store they should have used + current_version: the current version of the store they should have used """ super(RoomKeysVersionError, self).__init__( 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION @@ -415,7 +426,7 @@ class UnsupportedRoomVersionError(SynapseError): """The client's request to create a room used a room version that the server does not support.""" - def __init__(self, msg="Homeserver does not support this room version"): + def __init__(self, msg: str = "Homeserver does not support this room version"): super(UnsupportedRoomVersionError, self).__init__( code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, ) @@ -437,7 +448,7 @@ class IncompatibleRoomVersionError(SynapseError): failing. """ - def __init__(self, room_version): + def __init__(self, room_version: str): super(IncompatibleRoomVersionError, self).__init__( code=400, msg="Your homeserver does not support the features required to " @@ -457,8 +468,8 @@ class PasswordRefusedError(SynapseError): def __init__( self, - msg="This password doesn't comply with the server's policy", - errcode=Codes.WEAK_PASSWORD, + msg: str = "This password doesn't comply with the server's policy", + errcode: str = Codes.WEAK_PASSWORD, ): super(PasswordRefusedError, self).__init__( code=400, msg=msg, errcode=errcode, @@ -483,14 +494,14 @@ class RequestSendFailed(RuntimeError): self.can_retry = can_retry -def cs_error(msg, code=Codes.UNKNOWN, **kwargs): +def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): """ Utility method for constructing an error response for client-server interactions. Args: - msg (str): The error message. - code (str): The error code. - kwargs : Additional keys to add to the response. + msg: The error message. + code: The error code. + kwargs: Additional keys to add to the response. Returns: A dict representing the error response JSON. 
""" @@ -512,7 +523,14 @@ class FederationError(RuntimeError): is wrong (e.g., it referred to an invalid event) """ - def __init__(self, level, code, reason, affected, source=None): + def __init__( + self, + level: str, + code: int, + reason: str, + affected: str, + source: Optional[str] = None, + ): if level not in ["FATAL", "ERROR", "WARN"]: raise ValueError("Level is not valid: %s" % (level,)) self.level = level @@ -539,16 +557,16 @@ class HttpResponseException(CodeMessageException): Represents an HTTP-level failure of an outbound request Attributes: - response (bytes): body of response + response: body of response """ - def __init__(self, code, msg, response): + def __init__(self, code: int, msg: str, response: bytes): """ Args: - code (int): HTTP status code - msg (str): reason phrase from HTTP response status line - response (bytes): body of response + code: HTTP status code + msg: reason phrase from HTTP response status line + response: body of response """ super(HttpResponseException, self).__init__(code, msg) self.response = response @@ -573,7 +591,7 @@ class HttpResponseException(CodeMessageException): # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text try: - j = json.loads(self.response) + j = json.loads(self.response.decode("utf-8")) except ValueError: j = {} diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f6792d9fc8..c1b76d827b 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set from typing_extensions import ContextManager -from twisted.internet import address, defer, reactor +from twisted.internet import address, reactor import synapse import synapse.events @@ -111,6 +111,7 @@ from synapse.rest.client.v1.room import ( RoomSendEventRestServlet, RoomStateEventRestServlet, RoomStateRestServlet, + RoomTypingRestServlet, ) from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v2_alpha import groups, sync, user_directory @@ -374,9 +375,8 @@ class GenericWorkerPresence(BasePresenceHandler): return _user_syncing() - @defer.inlineCallbacks - def notify_from_replication(self, states, stream_id): - parties = yield get_interested_parties(self.store, states) + async def notify_from_replication(self, states, stream_id): + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -386,8 +386,7 @@ class GenericWorkerPresence(BasePresenceHandler): users=users_to_states.keys(), ) - @defer.inlineCallbacks - def process_replication_rows(self, token, rows): + async def process_replication_rows(self, token, rows): states = [ UserPresenceState( row.user_id, @@ -405,7 +404,7 @@ class GenericWorkerPresence(BasePresenceHandler): self.user_to_current_state[state.user_id] = state stream_id = token - yield self.notify_from_replication(states, stream_id) + await self.notify_from_replication(states, stream_id) def get_currently_syncing_users_for_replication(self) -> Iterable[str]: return [ @@ -451,37 +450,6 @@ class GenericWorkerPresence(BasePresenceHandler): await self._bump_active_client(user_id=user_id) -class GenericWorkerTyping(object): - def __init__(self, hs): - self._latest_room_serial = 0 - self._reset() - - def _reset(self): - """ - Reset the typing handler's data caches. - """ - # map room IDs to serial numbers - self._room_serials = {} - # map room IDs to sets of users currently typing - self._room_typing = {} - - def process_replication_rows(self, token, rows): - if self._latest_room_serial > token: - # The master has gone backwards. To prevent inconsistent data, just - # clear everything. - self._reset() - - # Set the latest serial token to whatever the server gave us. - self._latest_room_serial = token - - for row in rows: - self._room_serials[row.room_id] = token - self._room_typing[row.room_id] = row.user_ids - - def get_current_token(self) -> int: - return self._latest_room_serial - - class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. @@ -511,25 +479,7 @@ class GenericWorkerSlavedStore( SearchWorkerStore, BaseSlavedStore, ): - def __init__(self, database, db_conn, hs): - super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs) - - # We pull out the current federation stream position now so that we - # always have a known value for the federation position in memory so - # that we don't have to bounce via a deferred once when we start the - # replication streams. - self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) - - def _get_federation_out_pos(self, db_conn): - sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?" 
- sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, ("federation",)) - rows = txn.fetchall() - txn.close() - - return rows[0][0] if rows else -1 + pass class GenericWorkerServer(HomeServer): @@ -576,6 +526,7 @@ class GenericWorkerServer(HomeServer): KeyUploadServlet(self).register(resource) AccountDataServlet(self).register(resource) RoomAccountDataServlet(self).register(resource) + RoomTypingRestServlet(self).register(resource) sync.register_servlets(self, resource) events.register_servlets(self, resource) @@ -687,9 +638,6 @@ class GenericWorkerServer(HomeServer): def build_presence_handler(self): return GenericWorkerPresence(self) - def build_typing_handler(self): - return GenericWorkerTyping(self) - class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): @@ -812,19 +760,11 @@ class FederationSenderHandler(object): self.federation_sender = hs.get_federation_sender() self._hs = hs - # if the worker is restarted, we want to pick up where we left off in - # the replication stream, so load the position from the database. - # - # XXX is this actually worthwhile? Whenever the master is restarted, we'll - # drop some rows anyway (which is mostly fine because we're only dropping - # typing and presence notifications). If the replication stream is - # unreliable, why do we do all this hoop-jumping to store the position in the - # database? See also https://github.com/matrix-org/synapse/issues/7535. - # - self.federation_position = self.store.federation_out_pos_startup + # Stores the latest position in the federation stream we've gotten up + # to. This is always set before we use it. + self.federation_position = None self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - self._last_ack = self.federation_position def on_start(self): # There may be some events that are persisted but haven't been sent, @@ -932,7 +872,6 @@ class FederationSenderHandler(object): # We ACK this token over replication so that the master can drop # its in memory queues self._hs.get_tcp_replication().send_federation_ack(current_position) - self._last_ack = current_position except Exception: logger.exception("Error updating federation stream position") @@ -960,7 +899,7 @@ def start(config_options): ) if config.worker_app == "synapse.app.appservice": - if config.notify_appservices: + if config.appservice.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -970,13 +909,13 @@ def start(config_options): sys.exit(1) # Force the appservice to start since they will be disabled in the main config - config.notify_appservices = True + config.appservice.notify_appservices = True else: # For other worker types we force this to off. - config.notify_appservices = False + config.appservice.notify_appservices = False if config.worker_app == "synapse.app.pusher": - if config.start_pushers: + if config.server.start_pushers: sys.stderr.write( "\nThe pushers must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -986,13 +925,13 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.start_pushers = True + config.server.start_pushers = True else: # For other worker types we force this to off. 
- config.start_pushers = False + config.server.start_pushers = False if config.worker_app == "synapse.app.user_dir": - if config.update_user_directory: + if config.server.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -1002,13 +941,13 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.update_user_directory = True + config.server.update_user_directory = True else: # For other worker types we force this to off. - config.update_user_directory = False + config.server.update_user_directory = False if config.worker_app == "synapse.app.federation_sender": - if config.send_federation: + if config.federation.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -1018,10 +957,10 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.send_federation = True + config.federation.send_federation = True else: # For other worker types we force this to off. - config.send_federation = False + config.federation.send_federation = False synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 09291d86ad..ec7401f911 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -483,8 +483,7 @@ class SynapseService(service.Service): _stats_process = [] -@defer.inlineCallbacks -def phone_stats_home(hs, stats, stats_process=_stats_process): +async def phone_stats_home(hs, stats, stats_process=_stats_process): logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) uptime = int(now - hs.start_time) @@ -522,28 +521,28 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): stats["python_version"] = "{}.{}.{}".format( version.major, version.minor, version.micro ) - stats["total_users"] = yield hs.get_datastore().count_all_users() + stats["total_users"] = await hs.get_datastore().count_all_users() - total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() + total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() stats["total_nonbridged_users"] = total_nonbridged_users - daily_user_type_results = yield hs.get_datastore().count_daily_user_type() + daily_user_type_results = await hs.get_datastore().count_daily_user_type() for name, count in daily_user_type_results.items(): stats["daily_user_type_" + name] = count - room_count = yield hs.get_datastore().get_room_count() + room_count = await hs.get_datastore().get_room_count() stats["total_room_count"] = room_count - stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + stats["daily_active_users"] = await hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = await hs.get_datastore().count_daily_messages() - r30_results = yield hs.get_datastore().count_r30_users() + r30_results = await hs.get_datastore().count_r30_users() for name, count in r30_results.items(): stats["r30_users_" + name] = count - daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size @@ -558,7 +557,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: - yield hs.get_proxied_http_client().put_json( + await hs.get_proxied_http_client().put_json( hs.config.report_stats_endpoint, stats ) except Exception as e: diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f92bfb420b..1e0e4d497d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py
@@ -19,7 +19,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.api.constants import ThirdPartyEntityKind +from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient @@ -207,7 +207,7 @@ class ApplicationServiceApi(SimpleHttpClient): if service.url is None: return True - events = self._serialize(events) + events = self._serialize(service, events) if txn_id is None: logger.warning( @@ -233,6 +233,18 @@ class ApplicationServiceApi(SimpleHttpClient): failed_transactions_counter.labels(service.id).inc() return False - def _serialize(self, events): + def _serialize(self, service, events): time_now = self.clock.time_msec() - return [serialize_event(e, time_now, as_client_event=True) for e in events] + return [ + serialize_event( + e, + time_now, + as_client_event=True, + is_invite=( + e.type == EventTypes.Member + and e.membership == "invite" + and service.is_interested_in_user(e.state_key) + ), + ) + for e in events + ] diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1391e5fc43..fd137853b1 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -19,9 +19,11 @@ import argparse import errno import os from collections import OrderedDict +from hashlib import sha256 from textwrap import dedent -from typing import Any, MutableMapping, Optional +from typing import Any, List, MutableMapping, Optional +import attr import yaml @@ -717,4 +719,36 @@ def find_config_files(search_paths): return config_files -__all__ = ["Config", "RootConfig"] +@attr.s +class ShardedWorkerHandlingConfig: + """Algorithm for choosing which instance is responsible for handling some + sharded work. + + For example, the federation senders use this to determine which instance + handles sending traffic to a given destination (which is used as the `key` + below). + """ + + instances = attr.ib(type=List[str]) + + def should_handle(self, instance_name: str, key: str) -> bool: + """Whether this instance is responsible for handling the given key. + """ + + # If fewer than two instances are defined we always return true. + if not self.instances or len(self.instances) == 1: + return True + + # We shard by taking the hash of the key, reducing it modulo the number + # of instances, and then checking whether this instance matches the + # instance at that index. + # + # (Technically this introduces some bias and is not entirely uniform, + # but since the hash is so large the bias is ridiculously small). + dest_hash = sha256(key.encode("utf8")).digest() + dest_int = int.from_bytes(dest_hash, byteorder="little") + remainder = dest_int % (len(self.instances)) + return self.instances[remainder] == instance_name + + +__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
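Because `should_handle` is deterministic, every process that shares the same `instances` list computes the same owner for a given key. A small self-contained sketch of the same routing rule; the `pick_instance` helper is illustrative only and not part of the class:

    from hashlib import sha256

    def pick_instance(instances, key):
        # With zero or one instances configured, should_handle returns True
        # everywhere; this sketch simplifies that degenerate case.
        if len(instances) <= 1:
            return instances[0] if instances else "master"
        # Hash the key, reduce it modulo the number of instances, and index in:
        # the same rule as ShardedWorkerHandlingConfig.should_handle.
        dest_hash = sha256(key.encode("utf8")).digest()
        dest_int = int.from_bytes(dest_hash, byteorder="little")
        return instances[dest_int % len(instances)]

    senders = ["sender1", "sender2", "sender3"]
    # Exactly one federation sender ends up owning each destination.
    print(pick_instance(senders, "matrix.org"))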
index 9e576060d4..eb911e8f9f 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi
@@ -137,3 +137,8 @@ class Config: def read_config_files(config_files: List[str]): ... def find_config_files(search_paths: List[str]): ... + +class ShardedWorkerHandlingConfig: + instances: List[str] + def __init__(self, instances: List[str]) -> None: ... + def should_handle(self, instance_name: str, key: str) -> bool: ... diff --git a/synapse/config/database.py b/synapse/config/database.py
index 1064c2697b..62bccd9ef5 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py
@@ -55,7 +55,7 @@ DEFAULT_CONFIG = """\ #database: # name: psycopg2 # args: -# user: synapse +# user: synapse_user # password: secretpassword # database: synapse # host: localhost diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index df08bcd1bc..a63acbdc63 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py
@@ -22,6 +22,7 @@ import os from enum import Enum from typing import Optional +import attr import pkg_resources from ._base import Config, ConfigError @@ -32,6 +33,33 @@ Password reset emails are enabled on this homeserver due to a partial %s """ +DEFAULT_SUBJECTS = { + "message_from_person_in_room": "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room...", + "message_from_person": "[%(app)s] You have a message on %(app)s from %(person)s...", + "messages_from_person": "[%(app)s] You have messages on %(app)s from %(person)s...", + "messages_in_room": "[%(app)s] You have messages on %(app)s in the %(room)s room...", + "messages_in_room_and_others": "[%(app)s] You have messages on %(app)s in the %(room)s room and others...", + "messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...", + "invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...", + "invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...", + "password_reset": "[%(server_name)s] Password reset", + "email_validation": "[%(server_name)s] Validate your email", +} + + +@attr.s +class EmailSubjectConfig: + message_from_person_in_room = attr.ib(type=str) + message_from_person = attr.ib(type=str) + messages_from_person = attr.ib(type=str) + messages_in_room = attr.ib(type=str) + messages_in_room_and_others = attr.ib(type=str) + messages_from_person_and_others = attr.ib(type=str) + invite_from_person = attr.ib(type=str) + invite_from_person_to_room = attr.ib(type=str) + password_reset = attr.ib(type=str) + email_validation = attr.ib(type=str) + class EmailConfig(Config): section = "email" @@ -72,7 +100,7 @@ class EmailConfig(Config): template_dir = email_config.get("template_dir") # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxilliary templates like mail.css we will need). + # we don't yet know what auxiliary templates like mail.css we will need). # (Note that loading as package_resources with jinja.PackageLoader doesn't # work for the same reason.) if not template_dir: @@ -294,8 +322,17 @@ class EmailConfig(Config): if not os.path.isfile(p): raise ConfigError("Unable to find email template file %s" % (p,)) + subjects_config = email_config.get("subjects", {}) + subjects = {} + + for key, default in DEFAULT_SUBJECTS.items(): + subjects[key] = subjects_config.get(key, default) + + self.email_subjects = EmailSubjectConfig(**subjects) + def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """\ + return ( + """\ # Configuration for sending emails from Synapse. # email: @@ -323,17 +360,17 @@ class EmailConfig(Config): # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # - # The placeholder '%(app)s' will be replaced by the application name, + # The placeholder '%%(app)s' will be replaced by the application name, # which is normally 'app_name' (below), but may be overridden by the # Matrix client application. # - # Note that the placeholder must be written '%(app)s', including the + # Note that the placeholder must be written '%%(app)s', including the # trailing 's'. # - #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" + #notif_from: "Your Friendly %%(app)s homeserver <noreply@example.com>" - # app_name defines the default value for '%(app)s' in notif_from. It - # defaults to 'Matrix'. 
+ # app_name defines the default value for '%%(app)s' in notif_from and email + # subjects. It defaults to 'Matrix'. # #app_name: my_branded_matrix_server @@ -401,7 +438,76 @@ class EmailConfig(Config): # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates # #template_dir: "res/templates" + + # Subjects to use when sending emails from Synapse. + # + # The placeholder '%%(app)s' will be replaced with the value of the 'app_name' + # setting above, or by a value dictated by the Matrix client application. + # + # If a subject isn't overridden in this configuration file, the value used as + # its example will be used. + # + #subjects: + + # Subjects for notification emails. + # + # On top of the '%%(app)s' placeholder, these can use the following + # placeholders: + # + # * '%%(person)s', which will be replaced by the display name of the user(s) + # that sent the message(s), e.g. "Alice and Bob". + # * '%%(room)s', which will be replaced by the name of the room the + # message(s) have been sent to, e.g. "My super room". + # + # See the example provided for each setting to see which placeholders can be + # used and how to use them. + # + # Subject to use to notify about one message from one or more user(s) in a + # room which has a name. + #message_from_person_in_room: "%(message_from_person_in_room)s" + # + # Subject to use to notify about one message from one or more user(s) in a + # room which doesn't have a name. + #message_from_person: "%(message_from_person)s" + # + # Subject to use to notify about multiple messages from one or more users in + # a room which doesn't have a name. + #messages_from_person: "%(messages_from_person)s" + # + # Subject to use to notify about multiple messages in a room which has a + # name. + #messages_in_room: "%(messages_in_room)s" + # + # Subject to use to notify about multiple messages in multiple rooms. + #messages_in_room_and_others: "%(messages_in_room_and_others)s" + # + # Subject to use to notify about multiple messages from multiple persons in + # multiple rooms. This is similar to the setting above except it's used when + # the room in which the notification was triggered has no name. + #messages_from_person_and_others: "%(messages_from_person_and_others)s" + # + # Subject to use to notify about an invite to a room which has a name. + #invite_from_person_to_room: "%(invite_from_person_to_room)s" + # + # Subject to use to notify about an invite to a room which doesn't have a + # name. + #invite_from_person: "%(invite_from_person)s" + + # Subject for emails related to account administration. + # + # On top of the '%%(app)s' placeholder, these can also use the + # '%%(server_name)s' placeholder, which will be replaced by the value of the + # 'server_name' setting in your Synapse configuration. + # + # Subject to use when sending a password reset email. + #password_reset: "%(password_reset)s" + # + # Subject to use when sending a verification email to assert an address's + # ownership. + #email_validation: "%(email_validation)s" """ + % DEFAULT_SUBJECTS + ) class ThreepidBehaviour(Enum): diff --git a/synapse/config/federation.py b/synapse/config/federation.py new file mode 100644
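The doubled percent signs in this sample config are load-bearing: the template string is itself rendered once with `% DEFAULT_SUBJECTS`, so a literal placeholder has to be written `%%(app)s` to survive that pass. A quick demonstration of the mechanics:

    template = (
        '#notif_from: "Your Friendly %%(app)s homeserver"\n'
        '#password_reset: "%(password_reset)s"'
    )
    print(template % {"password_reset": "[%(server_name)s] Password reset"})
    # #notif_from: "Your Friendly %(app)s homeserver"
    # #password_reset: "[%(server_name)s] Password reset"

Note that substituted values are not re-scanned for placeholders, which is why the `%(server_name)s` inside the default subject comes through intact.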
index 0000000000..82ff9664de --- /dev/null +++ b/synapse/config/federation.py
@@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from netaddr import IPSet + +from ._base import Config, ConfigError, ShardedWorkerHandlingConfig + + +class FederationConfig(Config): + section = "federation" + + def read_config(self, config, **kwargs): + # Whether to send federation traffic out in this process. This only + # applies to some federation traffic, and so shouldn't be used to + # "disable" federation + self.send_federation = config.get("send_federation", True) + + federation_sender_instances = config.get("federation_sender_instances") or [] + self.federation_shard_config = ShardedWorkerHandlingConfig( + federation_sender_instances + ) + + # FIXME: federation_domain_whitelist needs sytests + self.federation_domain_whitelist = None # type: Optional[dict] + federation_domain_whitelist = config.get("federation_domain_whitelist", None) + + if federation_domain_whitelist is not None: + # turn the whitelist into a hash for speed of lookup + self.federation_domain_whitelist = {} + + for domain in federation_domain_whitelist: + self.federation_domain_whitelist[domain] = True + + self.federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", [] + ) + + # Attempt to create an IPSet from the given ranges + try: + self.federation_ip_range_blacklist = IPSet( + self.federation_ip_range_blacklist + ) + + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e + ) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Restrict federation to the following whitelist of domains. + # N.B. we recommend also firewalling your federation listener to limit + # inbound federation traffic as early as possible, rather than relying + # purely on this application-layer restriction. If not specified, the + # default is to whitelist everything. + # + #federation_domain_whitelist: + # - lon.example.com + # - nyc.example.com + # - syd.example.com + + # Prevent federation requests from being sent to the following + # blacklist IP address CIDR ranges. If this option is not specified, or + # specified with an empty list, no ip range blacklist will be enforced. + # + # As of Synapse v1.4.0 this option also affects any outbound requests to identity + # servers provided by user input. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
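For a sense of how the blacklist gets used once parsed, here is a minimal sketch with netaddr; the variable names are illustrative, and the real checks live in Synapse's HTTP client code rather than in this config class:

    from netaddr import IPAddress, IPSet

    blacklist = IPSet(["127.0.0.0/8", "10.0.0.0/8", "fe80::/64"])
    # Always blocked, mirroring read_config above.
    blacklist.update(["0.0.0.0", "::"])

    print(IPAddress("10.1.2.3") in blacklist)       # True - refuse to connect
    print(IPAddress("93.184.216.34") in blacklist)  # False - fine to connect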
index 264c274c52..8e93d31394 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py
@@ -23,6 +23,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig from .key import KeyConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + FederationConfig, CacheConfig, DatabaseConfig, LoggingConfig, diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index fce96b4acf..3252ad9e7f 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py
@@ -32,6 +32,11 @@ class JWTConfig(Config): self.jwt_secret = jwt_config["secret"] self.jwt_algorithm = jwt_config["algorithm"] + # The issuer and audiences are optional; if provided, it is asserted + # that the claims exist on the JWT. + self.jwt_issuer = jwt_config.get("issuer") + self.jwt_audiences = jwt_config.get("audiences") + try: import jwt @@ -42,6 +47,8 @@ class JWTConfig(Config): self.jwt_enabled = False self.jwt_secret = None self.jwt_algorithm = None + self.jwt_issuer = None + self.jwt_audiences = None def generate_config_section(self, **kwargs): return """\ @@ -52,6 +59,9 @@ class JWTConfig(Config): # Each JSON Web Token needs to contain a "sub" (subject) claim, which is # used as the localpart of the mxid. # + # Additionally, the expiration time ("exp"), not before time ("nbf"), + # and issued at ("iat") claims are validated if present. + # # Note that this is a non-standard login type and client support is # expected to be non-existent. # @@ -78,4 +88,22 @@ class JWTConfig(Config): # Required if 'enabled' is true. # #algorithm: "provided-by-your-issuer" + + # The issuer to validate the "iss" claim against. + # + # Optional; if provided, the "iss" claim will be required and + # validated for all JSON web tokens. + # + #issuer: "provided-by-your-issuer" + + # A list of audiences to validate the "aud" claim against. + # + # Optional; if provided, the "aud" claim will be required and + # validated for all JSON web tokens. + # + # Note that if the "aud" claim is included in a JSON web token then + # validation will fail without configuring audiences. + # + #audiences: + # - "provided-by-your-issuer" """ diff --git a/synapse/config/push.py b/synapse/config/push.py
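A sketch of what the new options cause PyJWT to enforce, assuming an HS256 shared secret: passing `issuer`/`audience` to `jwt.decode` makes the corresponding claims required and validated, raising `jwt.InvalidIssuerError` or `jwt.InvalidAudienceError` on mismatch and `jwt.MissingRequiredClaimError` when a required claim is absent.

    import jwt  # PyJWT

    secret = "shared-secret"
    token = jwt.encode(
        {"sub": "alice", "iss": "https://issuer.example.com", "aud": "synapse"},
        secret,
        algorithm="HS256",
    )

    # This is what the jwt_issuer / jwt_audiences config values switch on.
    claims = jwt.decode(
        token,
        secret,
        algorithms=["HS256"],
        issuer="https://issuer.example.com",
        audience="synapse",
    )
    print(claims["sub"])  # alice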
index 6f2b3a7faa..a1f3752c8a 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ShardedWorkerHandlingConfig class PushConfig(Config): @@ -24,6 +24,9 @@ class PushConfig(Config): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) + pusher_instances = config.get("pusher_instances") or [] + self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) + # There was a a 'redact_content' setting but mistakenly read from the # 'email'section'. Check for the flag in the 'push' section, and log, # but do not honour it to avoid nasty surprises when people upgrade. diff --git a/synapse/config/room.py b/synapse/config/room.py
index 6aa4de0672..52cf0b62fc 100644 --- a/synapse/config/room.py +++ b/synapse/config/room.py
@@ -50,7 +50,12 @@ class RoomConfig(Config): RoomCreationPreset.PRIVATE_CHAT, RoomCreationPreset.TRUSTED_PRIVATE_CHAT, ] - elif encryption_for_room_type == RoomDefaultEncryptionTypes.OFF: + elif ( + encryption_for_room_type == RoomDefaultEncryptionTypes.OFF + or encryption_for_room_type is False + ): + # PyYAML translates "off" into False if it's unquoted, so we also need to + # check for encryption_for_room_type being False. self.encryption_enabled_by_default_for_room_presets = [] else: raise ConfigError( diff --git a/synapse/config/server.py b/synapse/config/server.py
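The PyYAML quirk being worked around here is easy to reproduce: YAML 1.1, which PyYAML implements, treats unquoted off/on/no/yes as booleans, so the value never arrives as the string "off" unless quoted:

    import yaml

    print(yaml.safe_load("encryption_enabled_by_default_for_room_type: off"))
    # {'encryption_enabled_by_default_for_room_type': False}
    print(yaml.safe_load("encryption_enabled_by_default_for_room_type: 'off'"))
    # {'encryption_enabled_by_default_for_room_type': 'off'}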
index 8204664883..3747a01ca7 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -23,7 +23,6 @@ from typing import Any, Dict, Iterable, List, Optional import attr import yaml -from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name @@ -136,11 +135,6 @@ class ServerConfig(Config): self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") - # Whether to send federation traffic out in this process. This only - # applies to some federation traffic, and so shouldn't be used to - # "disable" federation - self.send_federation = config.get("send_federation", True) - # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -213,7 +207,7 @@ class ServerConfig(Config): # errors when attempting to search for messages. self.enable_search = config.get("enable_search", True) - self.filter_timeline_limit = config.get("filter_timeline_limit", -1) + self.filter_timeline_limit = config.get("filter_timeline_limit", 100) # Whether we should block invites sent to users on this server # (other than those sent by local server admins) @@ -263,34 +257,6 @@ class ServerConfig(Config): # due to resource constraints self.admin_contact = config.get("admin_contact", None) - # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None # type: Optional[dict] - federation_domain_whitelist = config.get("federation_domain_whitelist", None) - - if federation_domain_whitelist is not None: - # turn the whitelist into a hash for speed of lookup - self.federation_domain_whitelist = {} - - for domain in federation_domain_whitelist: - self.federation_domain_whitelist[domain] = True - - self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [] - ) - - # Attempt to create an IPSet from the given ranges - try: - self.federation_ip_range_blacklist = IPSet( - self.federation_ip_range_blacklist - ) - - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e - ) - if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" @@ -727,7 +693,9 @@ class ServerConfig(Config): #gc_thresholds: [700, 10, 10] # Set the limit on the returned events in the timeline in the get - # and sync operations. The default value is -1, means no upper limit. + # and sync operations. The default value is 100. -1 means no upper limit. + # + # Uncomment the following to increase the limit to 5000. # #filter_timeline_limit: 5000 @@ -743,38 +711,6 @@ class ServerConfig(Config): # #enable_search: false - # Restrict federation to the following whitelist of domains. - # N.B. we recommend also firewalling your federation listener to limit - # inbound federation traffic as early as possible, rather than relying - # purely on this application-layer restriction. If not specified, the - # default is to whitelist everything. - # - #federation_domain_whitelist: - # - lon.example.com - # - nyc.example.com - # - syd.example.com - - # Prevent federation requests from being sent to the following - # blacklist IP address CIDR ranges. If this option is not specified, or - # specified with an empty list, no ip range blacklist will be enforced. - # - # As of Synapse v1.4.0 this option also affects any outbound requests to identity - # servers provided by user input. 
- # - # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly - # listed here, since they correspond to unroutable addresses.) - # - federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -803,7 +739,7 @@ class ServerConfig(Config): # names: a list of names of HTTP resources. See below for a list of # valid resource names. # - # compress: set to true to enable HTTP comression for this resource. + # compress: set to true to enable HTTP compression for this resource. # # additional_resources: Only valid for an 'http' listener. A map of # additional endpoints which should be loaded via dynamic modules. diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index dbc661630c..2574cd3aa1 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py
@@ -34,9 +34,11 @@ class WriterLocations: Attributes: events: The instance that writes to the event and backfill streams. + typing: The instance that writes to the typing stream. """ events = attr.ib(default="master", type=str) + typing = attr.ib(default="master", type=str) class WorkerConfig(Config): @@ -93,16 +95,15 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events also appears in + # Check that the configured writer for events and typing also appears in # `instance_map`. - if ( - self.writers.events != "master" - and self.writers.events not in self.instance_map - ): - raise ConfigError( - "Instance %r is configured to write events but does not appear in `instance_map` config." - % (self.writers.events,) - ) + for stream in ("events", "typing"): + instance = getattr(self.writers, stream) + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) def read_arguments(self, args): # We support a bunch of command line arguments that override options in diff --git a/synapse/event_auth.py b/synapse/event_auth.py
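A hypothetical worker configuration that satisfies the new check: every non-master stream writer has to appear in `instance_map`, otherwise a ConfigError is raised at startup. The worker names and ports below are made up for illustration:

    stream_writers:
      events: event_writer1
      typing: typing_writer1

    instance_map:
      event_writer1:
        host: localhost
        port: 9034
      typing_writer1:
        host: localhost
        port: 9035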
index c582355146..c0981eee62 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py
@@ -65,14 +65,16 @@ def check( room_id = event.room_id - # I'm not really expecting to get auth events in the wrong room, but let's - # sanity-check it + # We need to ensure that the auth events are actually for the same room, to + # stop people from using powers they've been granted in other rooms for + # example. for auth_event in auth_events.values(): if auth_event.room_id != room_id: - raise Exception( + raise AuthError( + 403, "During auth for event %s in room %s, found event %s in the state " "which is in room %s" - % (event.event_id, room_id, auth_event.event_id, auth_event.room_id) + % (event.event_id, room_id, auth_event.event_id, auth_event.room_id), ) if do_sig_check: diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f6b507977f..11f0d34ec8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections +import collections.abc import re from typing import Any, Mapping, Union @@ -424,7 +424,7 @@ def copy_power_levels_contents( Raises: TypeError if the input does not look like a valid power levels event content """ - if not isinstance(old_power_levels, collections.Mapping): + if not isinstance(old_power_levels, collections.abc.Mapping): raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) power_levels = {} @@ -434,7 +434,7 @@ def copy_power_levels_contents( power_levels[k] = v continue - if isinstance(v, collections.Mapping): + if isinstance(v, collections.abc.Mapping): power_levels[k] = h = {} for k1, v1 in v.items(): # we should only have one level of nesting diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
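The `collections.abc` change is a straightforward forward-compatibility fix: aliases such as `collections.Mapping` were deprecated in Python 3.3 and removed in Python 3.10, while the `collections.abc` spellings work on every supported version:

    import collections.abc

    print(isinstance({}, collections.abc.Mapping))  # True
    print(isinstance([], collections.abc.Mapping))  # False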
index 07d41ec03f..994e6c8d5a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -245,7 +245,7 @@ class FederationClient(FederationBase): event_id: event to fetch room_version: version of the room outlier: Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitary point in the context as opposed to part + it's from an arbitrary point in the context as opposed to part of the current block of PDUs. Defaults to `False` timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. @@ -351,7 +351,7 @@ class FederationClient(FederationBase): outlier: bool = False, include_none: bool = False, ) -> List[EventBase]: - """Takes a list of PDUs and checks the signatures and hashs of each + """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of that PDU. @@ -374,29 +374,26 @@ class FederationClient(FederationBase): """ deferreds = self._check_sigs_and_hashes(room_version, pdus) - @defer.inlineCallbacks - def handle_check_result(pdu: EventBase, deferred: Deferred): + async def handle_check_result(pdu: EventBase, deferred: Deferred): try: - res = yield make_deferred_yieldable(deferred) + res = await make_deferred_yieldable(deferred) except SynapseError: res = None if not res: # Check local db. - res = yield self.store.get_event( + res = await self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True ) if not res and pdu.origin != origin: try: - res = yield defer.ensureDeferred( - self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - room_version=room_version, - outlier=outlier, - timeout=10000, - ) + res = await self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + room_version=room_version, + outlier=outlier, + timeout=10000, ) except SynapseError: pass @@ -995,24 +992,25 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") - @defer.inlineCallbacks - def get_room_complexity(self, destination, room_id): + async def get_room_complexity( + self, destination: str, room_id: str + ) -> Optional[dict]: """ Fetch the complexity of a remote room from another server. Args: - destination (str): The remote server - room_id (str): The room ID to ask about. + destination: The remote server + room_id: The room ID to ask about. Returns: - Deferred[dict] or Deferred[None]: Dict contains the complexity - metric versions, while None means we could not fetch the complexity. + Dict contains the complexity metric versions, while None means we + could not fetch the complexity. """ try: - complexity = yield self.transport_layer.get_room_complexity( + complexity = await self.transport_layer.get_room_complexity( destination=destination, room_id=room_id ) - defer.returnValue(complexity) + return complexity except CodeMessageException as e: # We didn't manage to get it -- probably a 404. We are okay if other # servers don't give it to us. @@ -1029,4 +1027,4 @@ class FederationClient(FederationBase): # If we don't manage to find it, return None. It's not an error if a # server doesn't give it to us. - defer.returnValue(None) + return None diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e704cf2f44..11c5d63298 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -15,7 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Match, + Optional, + Tuple, + Union, +) from canonicaljson import json from prometheus_client import Counter, Histogram @@ -56,6 +67,9 @@ from synapse.util import glob_to_regex, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +if TYPE_CHECKING: + from synapse.server import HomeServer + # when processing incoming transactions, we try to handle multiple rooms in # parallel, up to this limit. TRANSACTION_CONCURRENCY_LIMIT = 10 @@ -95,6 +109,9 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) + self._state_ids_resp_cache = ResponseCache( + hs, "state_ids_resp", timeout_ms=30000 + ) async def on_backfill_request( self, origin: str, room_id: str, versions: List[str], limit: int @@ -362,10 +379,16 @@ class FederationServer(FederationBase): if not in_room: raise AuthError(403, "Host not in room.") + resp = await self._state_ids_resp_cache.wrap( + (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id, + ) + + return 200, resp + + async def _on_state_ids_request_compute(self, room_id, event_id): state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) - - return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} + return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( self, room_id: str, event_id: str @@ -526,9 +549,9 @@ class FederationServer(FederationBase): json_result = {} # type: Dict[str, Dict[str, dict]] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) + key_id: json.loads(json_str) } logger.info( @@ -717,7 +740,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warning("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignoring non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. @@ -731,7 +754,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warning("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignoring non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -741,7 +764,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # then the allow list. 
allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warning("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignoring non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): @@ -768,11 +791,30 @@ class FederationHandlerRegistry(object): query type for incoming federation traffic. """ - def __init__(self): - self.edu_handlers = {} - self.query_handlers = {} + def __init__(self, hs: "HomeServer"): + self.config = hs.config + self.http_client = hs.get_simple_http_client() + self.clock = hs.get_clock() + self._instance_name = hs.get_instance_name() - def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]): + # These are safe to load in monolith mode, but will explode if we try + # and use them. However we have guards before we use them to ensure that + # we don't route to ourselves, and in monolith mode that will always be + # the case. + self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) + self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) + + self.edu_handlers = ( + {} + ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] + self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] + + # Map from type to instance name that we should route EDU handling to. + self._edu_type_to_instance = {} # type: Dict[str, str] + + def register_edu_handler( + self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]] + ): """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. @@ -809,66 +851,56 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler + def register_instance_for_edu(self, edu_type: str, instance_name: str): + """Register that the EDU handler is on a different instance than master. + """ + self._edu_type_to_instance[edu_type] = instance_name + async def on_edu(self, edu_type: str, origin: str, content: dict): + if not self.config.use_presence and edu_type == "m.presence": + return + + # Check if we have a handler on this instance handler = self.edu_handlers.get(edu_type) - if not handler: - logger.warning("No handler registered for EDU type %s", edu_type) + if handler: + with start_active_span_from_edu(content, "handle_edu"): + try: + await handler(origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception: + logger.exception("Failed to handle edu %r", edu_type) return - with start_active_span_from_edu(content, "handle_edu"): + # Check if we can route it somewhere else that isn't us + route_to = self._edu_type_to_instance.get(edu_type, "master") + if route_to != self._instance_name: try: - await handler(origin, content) + await self._send_edu( + instance_name=route_to, + edu_type=edu_type, + origin=origin, + content=content, + ) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: logger.exception("Failed to handle edu %r", edu_type) - - def on_query(self, query_type: str, args: dict) -> defer.Deferred: - handler = self.query_handlers.get(query_type) - if not handler: - logger.warning("No handler registered for query type %s", query_type) - raise NotFoundError("No handler for Query type '%s'" % (query_type,)) - - return handler(args) - - -class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): - """A FederationHandlerRegistry for worker processes. 
- - When receiving EDU or queries it will check if an appropriate handler has - been registered on the worker, if there isn't one then it calls off to the - master process. - """ - - def __init__(self, hs): - self.config = hs.config - self.http_client = hs.get_simple_http_client() - self.clock = hs.get_clock() - - self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) - self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) - - super(ReplicationFederationHandlerRegistry, self).__init__() - - async def on_edu(self, edu_type: str, origin: str, content: dict): - """Overrides FederationHandlerRegistry - """ - if not self.config.use_presence and edu_type == "m.presence": return - handler = self.edu_handlers.get(edu_type) - if handler: - return await super(ReplicationFederationHandlerRegistry, self).on_edu( - edu_type, origin, content - ) - - return await self._send_edu(edu_type=edu_type, origin=origin, content=content) + # Oh well, let's just log and move on. + logger.warning("No handler registered for EDU type %s", edu_type) async def on_query(self, query_type: str, args: dict): - """Overrides FederationHandlerRegistry - """ handler = self.query_handlers.get(query_type) if handler: return await handler(args) - return await self._get_query_client(query_type=query_type, args=args) + # Check if we can route it somewhere else that isn't us + if self._instance_name == "master": + return await self._get_query_client(query_type=query_type, args=args) + + # Uh oh, no handler! Let's raise an exception so the request returns an + # error. + logger.warning("No handler registered for query type %s", query_type) + raise NotFoundError("No handler for Query type '%s'" % (query_type,)) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
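The reworked registry collapses to a three-way decision per EDU: handle it locally, forward it over replication to the owning instance, or drop it. A distilled sketch of that decision; the function and its arguments are illustrative only, not Synapse API:

    def route_edu(edu_type, local_handlers, edu_type_to_instance, instance_name):
        # 1. Handle locally if this instance registered a handler.
        if edu_type in local_handlers:
            return "handle locally"
        # 2. Otherwise forward over replication to whichever instance owns
        #    the type, defaulting to master.
        route_to = edu_type_to_instance.get(edu_type, "master")
        if route_to != instance_name:
            return "forward to " + route_to
        # 3. Nobody can handle it: log a warning and drop it.
        return "log warning and drop"

    print(route_edu("m.typing", {}, {"m.typing": "worker1"}, "master"))
    # forward to worker1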
index 6bbd762681..4fc9ff92e5 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py
@@ -55,6 +55,11 @@ class FederationRemoteSendQueue(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id + # We may have multiple federation sender instances, so we need to track + # their positions separately. + self._sender_instances = hs.config.federation.federation_shard_config.instances + self._sender_positions = {} + # Pending presence map user_id -> UserPresenceState self.presence_map = {} # type: Dict[str, UserPresenceState] @@ -261,7 +266,14 @@ class FederationRemoteSendQueue(object): def get_current_token(self): return self.pos - 1 - def federation_ack(self, token): + def federation_ack(self, instance_name, token): + if self._sender_instances: + # If we have configured multiple federation sender instances we need + # to track their positions separately, and only clear the queue up + # to the token all instances have acked. + self._sender_positions[instance_name] = token + token = min(self._sender_positions.values()) + self._clear_queue_before_pos(token) async def get_replication_rows( @@ -359,7 +371,7 @@ class BaseFederationRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = "" # Unique string that ids the type. Must be overriden in sub classes. + TypeId = "" # Unique string that ids the type. Must be overridden in sub classes. @staticmethod def from_data(data): diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
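The `min()` in `federation_ack` is what keeps multiple senders safe: the queue may only be trimmed up to the lowest position that every instance has acknowledged. A tiny worked example:

    positions = {"sender1": 10, "sender2": 7}
    clear_before = min(positions.values())
    print(clear_before)  # 7 - sender2 may still need rows 8 through 10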
index 7afc701109..64282abc60 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py
@@ -69,6 +69,9 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) + self._instance_name = hs.get_instance_name() + self._federation_shard_config = hs.config.federation.federation_shard_config + # map from destination to PerDestinationQueue self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] @@ -206,7 +209,13 @@ class FederationSender(object): ) return - destinations = set(destinations) + destinations = { + d + for d in destinations + if self._federation_shard_config.should_handle( + self._instance_name, d + ) + } if send_on_behalf_of is not None: # If we are sending the event on behalf of another server @@ -340,7 +349,12 @@ class FederationSender(object): # Work out which remote servers should be poked and poke them. domains = yield self.state.get_current_hosts_in_room(room_id) - domains = [d for d in domains if d != self.server_name] + domains = [ + d + for d in domains + if d != self.server_name + and self._federation_shard_config.should_handle(self._instance_name, d) + ] if not domains: return @@ -445,6 +459,10 @@ class FederationSender(object): for destination in destinations: if destination == self.server_name: continue + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + continue self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") @@ -453,12 +471,20 @@ class FederationSender(object): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ - hosts_and_states = yield get_interested_remotes(self.store, states, self.state) + hosts_and_states = yield defer.ensureDeferred( + get_interested_remotes(self.store, states, self.state) + ) for destinations, states in hosts_and_states: for destination in destinations: if destination == self.server_name: continue + + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + continue + self._get_per_destination_queue(destination).send_presence(states) def build_and_send_edu( @@ -480,6 +506,11 @@ class FederationSender(object): logger.info("Not sending EDU to ourselves") return + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + return + edu = Edu( origin=self.server_name, destination=destination, @@ -496,6 +527,11 @@ class FederationSender(object): edu: edu to send key: clobbering key for this edu """ + if not self._federation_shard_config.should_handle( + self._instance_name, edu.destination + ): + return + queue = self._get_per_destination_queue(edu.destination) if key: queue.send_keyed_edu(edu, key) @@ -507,6 +543,11 @@ class FederationSender(object): logger.warning("Not sending device update to ourselves") return + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + return + self._get_per_destination_queue(destination).attempt_new_transaction() def wake_destination(self, destination: str): @@ -520,6 +561,11 @@ class FederationSender(object): logger.warning("Not waking up ourselves") return + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + return + self._get_per_destination_queue(destination).attempt_new_transaction() @staticmethod diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 1093ae0d91..1c18f9841c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py
@@ -77,6 +77,20 @@ class PerDestinationQueue(object): self._clock = hs.get_clock() self._store = hs.get_datastore() self._transaction_manager = transaction_manager + self._instance_name = hs.get_instance_name() + self._federation_shard_config = hs.config.federation.federation_shard_config + + self._should_send_on_this_instance = True + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + # We don't raise an exception here to avoid taking out any other + # processing. We have a guard in `attempt_new_transaction` that + # ensure we don't start sending stuff. + logger.error( + "Create a per destination queue for %s on wrong worker", destination, + ) + self._should_send_on_this_instance = False self._destination = destination self.transmission_loop_running = False @@ -122,7 +136,7 @@ class PerDestinationQueue(object): ) def send_pdu(self, pdu: EventBase, order: int) -> None: - """Add a PDU to the queue, and start the transmission loop if neccessary + """Add a PDU to the queue, and start the transmission loop if necessary Args: pdu: pdu to send @@ -132,7 +146,7 @@ class PerDestinationQueue(object): self.attempt_new_transaction() def send_presence(self, states: Iterable[UserPresenceState]) -> None: - """Add presence updates to the queue. Start the transmission loop if neccessary. + """Add presence updates to the queue. Start the transmission loop if necessary. Args: states: presence to send @@ -183,6 +197,14 @@ class PerDestinationQueue(object): logger.debug("TX [%s] Transaction already in progress", self._destination) return + if not self._should_send_on_this_instance: + # We don't raise an exception here to avoid taking out any other + # processing. + logger.error( + "Trying to start a transaction to %s on wrong worker", self._destination + ) + return + logger.debug("TX [%s] Starting transaction loop", self._destination) run_as_background_process( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index fa6ad9efdc..9c2c6a232d 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py
@@ -65,8 +65,6 @@ class TransactionManager(object): # all the edus in that transaction. This needs to be done since there is # no active span here, so if the edus were not received by the remote the # span would have no causality and it would be forgotten. - # The span_contexts is a generator so that it won't be evaluated if - # opentracing is disabled. (Yay speed!) span_contexts = [] keep_destination = whitelisted_homeserver(destination) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9f99311419..cfdf23d366 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -746,7 +746,7 @@ class TransportLayerClient(object): def remove_user_from_group( self, destination, group_id, requester_user_id, user_id, content ): - """Remove a user fron a group + """Remove a user from a group """ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index bfb7831a02..5e111aa902 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -20,8 +20,6 @@ import logging import re from typing import Optional, Tuple, Type -from twisted.internet.defer import maybeDeferred - import synapse from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.room_versions import RoomVersions @@ -109,7 +107,7 @@ class Authenticator(object): self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist - self.notifer = hs.get_notifier() + self.notifier = hs.get_notifier() self.replication_client = None if hs.config.worker.worker_app: @@ -175,7 +173,7 @@ class Authenticator(object): await self.store.set_destination_retry_timings(origin, None, 0, 0) # Inform the relevant places that the remote server is back up. - self.notifer.notify_remote_server_up(origin) + self.notifier.notify_remote_server_up(origin) if self.replication_client: # If we're on a worker we try and inform master about this. The # replication client doesn't hook into the notifier to avoid @@ -340,6 +338,12 @@ class BaseFederationServlet(object): if origin: with ratelimiter.ratelimit(origin) as d: await d + if request._disconnected: + logger.warning( + "client disconnected before we started processing " + "request" + ) + return -1, None response = await func( origin, content, request.args, *args, **kwargs ) @@ -795,12 +799,8 @@ class PublicRoomList(BaseFederationServlet): # zero is a special value which corresponds to no limit. limit = None - data = await maybeDeferred( - self.handler.get_local_public_room_list, - limit, - since_token, - network_tuple=network_tuple, - from_federation=True, + data = await self.handler.get_local_public_room_list( + limit, since_token, network_tuple=network_tuple, from_federation=True ) return 200, data diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 61dc4beafe..ba2bf99800 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py
@@ -15,8 +15,8 @@ import logging -from twisted.internet import defer - +import synapse.state +import synapse.storage import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter @@ -28,10 +28,6 @@ logger = logging.getLogger(__name__) class BaseHandler(object): """ Common base class for the event handlers. - - Attributes: - store (synapse.storage.DataStore): - state_handler (synapse.state.StateHandler): """ def __init__(self, hs): @@ -39,10 +35,10 @@ class BaseHandler(object): Args: hs (synapse.server.HomeServer): """ - self.store = hs.get_datastore() + self.store = hs.get_datastore() # type: synapse.storage.DataStore self.auth = hs.get_auth() self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() + self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler self.distributor = hs.get_distributor() self.clock = hs.get_clock() self.hs = hs @@ -68,8 +64,7 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() - @defer.inlineCallbacks - def ratelimit(self, requester, update=True, is_admin_redaction=False): + async def ratelimit(self, requester, update=True, is_admin_redaction=False): """Ratelimits requests. Args: @@ -101,7 +96,7 @@ class BaseHandler(object): burst_count = self._rc_message.burst_count # Check if there is a per user override in the DB. - override = yield self.store.get_ratelimit_for_user(user_id) + override = await self.store.get_ratelimit_for_user(user_id) if override: # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a162392e4c..c7d921c21a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
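The logout hunk below has to stay compatible with third-party password providers, whose `on_logged_out` hook may be a plain function or a coroutine. Calling it and then checking `inspect.isawaitable` on the result supports both. A minimal sketch of the idiom (the helper name is hypothetical):

    import inspect

    async def maybe_await(value):
        # Coroutines and other awaitables get awaited; plain return
        # values pass straight through unchanged.
        if inspect.isawaitable(value):
            return await value
        return value

Used as `await maybe_await(provider.on_logged_out(...))`, this blocks the logout only when the provider actually returned something to wait on.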
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 import time
 import unicodedata
@@ -863,11 +864,15 @@
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
             if hasattr(provider, "on_logged_out"):
-                await provider.on_logged_out(
+                # This might return an awaitable; if it does, block the log out
+                # until it completes.
+                result = provider.on_logged_out(
                     user_id=str(user_info["user"]),
                     device_id=user_info["device_id"],
                     access_token=access_token,
                 )
+                if inspect.isawaitable(result):
+                    await result

         # delete pushers associated with this access token
         if user_info["token_id"] is not None:
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index d79ffefdb5..786e608fa2 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py
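The one-character annotation change below records what `_parse_cas_response` actually receives: the raw HTTP response body is `bytes`, not `str`. Standard-library XML parsing copes with that directly, so no explicit decode is needed; a hedged sketch (the real handler's parsing may differ):

    from xml.etree import ElementTree

    def parse_cas_body(cas_response_body: bytes) -> ElementTree.Element:
        # fromstring() accepts bytes and honours the encoding declared in
        # the XML prolog, which a premature .decode() could get wrong.
        return ElementTree.fromstring(cas_response_body)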
@@ -104,7 +104,7 @@ class CasHandler: return user, displayname def _parse_cas_response( - self, cas_response_body: str + self, cas_response_body: bytes ) -> Tuple[str, Dict[str, Optional[str]]]: """ Retrieve the user and other parameters from the CAS response. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 2afb390a92..25169157c1 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
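Two behavioural changes ride along with the annotations below: the user-parting loop now starts only on the main process (workers set `config.worker_app`; the main process leaves it `None`), and a new `activate_account` method reverses a deactivation. A minimal sketch of the single-process guard, assuming the usual `hs` homeserver object (the class name is hypothetical):

    from synapse.metrics.background_process_metrics import run_as_background_process

    class PartingLoopOwner:
        def __init__(self, hs):
            self._user_parter_running = False
            # Every worker runs this constructor, but only the main
            # process (worker_app is None) may drive the loop.
            if hs.config.worker_app is None:
                hs.get_reactor().callWhenRunning(self._start_user_parting)

        def _start_user_parting(self) -> None:
            if not self._user_parter_running:
                run_as_background_process("user_parter_loop", self._user_parter_loop)

        async def _user_parter_loop(self) -> None:
            self._user_parter_running = True
            # ... part deactivated users from their rooms ...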
@@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process @@ -29,6 +30,7 @@ class DeactivateAccountHandler(BaseHandler): def __init__(self, hs): super(DeactivateAccountHandler, self).__init__(hs) + self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() @@ -40,23 +42,25 @@ class DeactivateAccountHandler(BaseHandler): # Start the user parter loop so it can resume parting users from rooms where # it left off (if it has work left to do). - hs.get_reactor().callWhenRunning(self._start_user_parting) + if hs.config.worker_app is None: + hs.get_reactor().callWhenRunning(self._start_user_parting) self._account_validity_enabled = hs.config.account_validity.enabled - async def deactivate_account(self, user_id, erase_data, id_server=None): + async def deactivate_account( + self, user_id: str, erase_data: bool, id_server: Optional[str] = None + ) -> bool: """Deactivate a user's account Args: - user_id (str): ID of user to be deactivated - erase_data (bool): whether to GDPR-erase the user's data - id_server (str|None): Use the given identity server when unbinding + user_id: ID of user to be deactivated + erase_data: whether to GDPR-erase the user's data + id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). Returns: - Deferred[bool]: True if identity server supports removing - threepids, otherwise False. + True if identity server supports removing threepids, otherwise False. """ # FIXME: Theoretically there is a race here wherein user resets # password using threepid. @@ -133,11 +137,11 @@ class DeactivateAccountHandler(BaseHandler): return identity_server_supports_unbinding - async def _reject_pending_invites_for_user(self, user_id): + async def _reject_pending_invites_for_user(self, user_id: str): """Reject pending invites addressed to a given user ID. Args: - user_id (str): The user ID to reject pending invites for. + user_id: The user ID to reject pending invites for. """ user = UserID.from_string(user_id) pending_invites = await self.store.get_invited_rooms_for_local_user(user_id) @@ -165,22 +169,16 @@ class DeactivateAccountHandler(BaseHandler): room.room_id, ) - def _start_user_parting(self): + def _start_user_parting(self) -> None: """ Start the process that goes through the table of users pending deactivation, if it isn't already running. 
- - Returns: - None """ if not self._user_parter_running: run_as_background_process("user_parter_loop", self._user_parter_loop) - async def _user_parter_loop(self): + async def _user_parter_loop(self) -> None: """Loop that parts deactivated users from rooms - - Returns: - None """ self._user_parter_running = True logger.info("Starting user parter") @@ -197,11 +195,8 @@ class DeactivateAccountHandler(BaseHandler): finally: self._user_parter_running = False - async def _part_user(self, user_id): + async def _part_user(self, user_id: str) -> None: """Causes the given user_id to leave all the rooms they're joined to - - Returns: - None """ user = UserID.from_string(user_id) @@ -223,3 +218,31 @@ class DeactivateAccountHandler(BaseHandler): user_id, room_id, ) + + async def activate_account(self, user_id: str) -> None: + """ + Activate an account that was previously deactivated. + + This marks the user as active and not erased in the database, but does + not attempt to rejoin rooms, re-add threepids, etc. + + If enabled, the user will be re-added to the user directory. + + The user will also need a password hash set to actually login. + + Args: + user_id: ID of user to be re-activated + """ + # Add the user to the directory, if necessary. + user = UserID.from_string(user_id) + if self.hs.config.user_directory_search_all_users: + profile = await self.store.get_profileinfo(user.localpart) + await self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + + # Ensure the user is not marked as erased. + await self.store.mark_user_not_erased(user_id) + + # Mark the user as active. + await self.store.set_user_deactivated_status(user_id, False) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 31346b56c3..db417d60de 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
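Besides the wholesale `yield`-to-`await` conversion, two details in the hunks below are easy to miss: `notify_device_update` gains early returns for an empty `device_ids` list and for a falsy stream `position` (nothing happened, so nobody needs waking), and `notifier.on_new_event` is called without `await` after the change, i.e. it is treated as a plain synchronous method. The guard shape, as a sketch condensed from the hunks:

    async def notify_device_update(self, user_id, device_ids):
        if not device_ids:
            # No changes to announce: don't touch the streams at all.
            return

        position = await self.store.add_device_change_to_streams(
            user_id, device_ids, []
        )
        if not position:
            # The store recorded no update, so there is nothing to stream.
            return

        # Synchronous call: note the absence of await.
        self.notifier.on_new_event("device_list_key", position, users=[user_id])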
@@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Optional - -from twisted.internet import defer +from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes @@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: """ Retrieve the given user's devices Args: - user_id (str): + user_id: The user ID to query for devices. Returns: - defer.Deferred: list[dict[str, X]]: info on each device + info on each device """ set_tag("user_id", user_id) - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - @defer.inlineCallbacks - def get_device(self, user_id, device_id): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """ Retrieve the given device Args: - user_id (str): - device_id (str): + user_id: The user to get the device from + device_id: The device to fetch. Returns: - defer.Deferred: dict[str, X]: info on the device + info on the device Raises: errors.NotFoundError: if the device was not found """ try: - device = yield self.store.get_device(user_id, device_id) + device = await self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) set_tag("device", device) @@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler): return device - @measure_func("device.get_user_ids_changed") @trace - @defer.inlineCallbacks - def get_user_ids_changed(self, user_id, from_token): + @measure_func("device.get_user_ids_changed") + async def get_user_ids_changed(self, user_id, from_token): """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. @@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler): set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_key = yield self.store.get_room_events_max_id() + now_room_key = await self.store.get_room_events_max_id() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # First we check if any devices have changed for users that we share # rooms with. 
- users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler): # Always tell the user about their own devices tracked_users.add(user_id) - changed = yield self.store.get_users_whose_devices_changed( + changed = await self.store.get_users_whose_devices_changed( from_token.device_list_key, tracked_users ) # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) - member_events = yield self.store.get_membership_changes_for_user( + member_events = await self.store.get_membership_changes_for_user( user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) @@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler): possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = yield self.store.get_current_state_ids(room_id) + current_state_ids = await self.store.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. @@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler): # Fetch the current state at the time. try: - event_ids = yield self.store.get_forward_extremeties_for_room( + event_ids = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering=stream_ordering ) except errors.StoreError: @@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) + prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. 
@@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler): return result - @defer.inlineCallbacks - def on_federation_query_user_devices(self, user_id): - stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") - self_signing_key = yield self.store.get_e2e_cross_signing_key( + async def on_federation_query_user_devices(self, user_id): + stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" ) @@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - @defer.inlineCallbacks - def check_device_registered( + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): """ @@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler): str: device id (generated if none was supplied) """ if device_id is not None: - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id # if the device id is not specified, we'll autogen one, but loop a few @@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler): attempts = 0 while attempts < 5: device_id = stringutils.random_string(10).upper() - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @trace - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """ Delete the given device Args: - user_id (str): - device_id (str): - - Returns: - defer.Deferred: + user_id: The user to delete the device from. + device_id: The device to delete. 
""" try: - yield self.store.delete_device(user_id, device_id) + await self.store.delete_device(user_id, device_id) except errors.StoreError as e: if e.code == 404: # no match @@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) + await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) @trace - @defer.inlineCallbacks - def delete_all_devices_for_user(self, user_id, except_device_id=None): + async def delete_all_devices_for_user( + self, user_id: str, except_device_id: Optional[str] = None + ) -> None: """Delete all of the user's devices Args: - user_id (str): - except_device_id (str|None): optional device id which should not - be deleted - - Returns: - defer.Deferred: + user_id: The user to remove all devices from + except_device_id: optional device id which should not be deleted """ - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) device_ids = list(device_map) if except_device_id is not None: device_ids = [d for d in device_ids if d != except_device_id] - yield self.delete_devices(user_id, device_ids) + await self.delete_devices(user_id, device_ids) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """ Delete several devices Args: - user_id (str): - device_ids (List[str]): The list of device IDs to delete - - Returns: - defer.Deferred: + user_id: The user to delete devices from. + device_ids: The list of device IDs to delete """ try: - yield self.store.delete_devices(user_id, device_ids) + await self.store.delete_devices(user_id, device_ids) except errors.StoreError as e: if e.code == 404: # no match @@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( + await self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id ) - yield self.notify_device_update(user_id, device_ids) + await self.notify_device_update(user_id, device_ids) - @defer.inlineCallbacks - def update_device(self, user_id, device_id, content): + async def update_device(self, user_id: str, device_id: str, content: dict) -> None: """ Update the given device Args: - user_id (str): - device_id (str): - content (dict): body of update request - - Returns: - defer.Deferred: + user_id: The user to update devices of. + device_id: The device to update. + content: body of update request """ # Reject a new displayname which is too long. 
@@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler): ) try: - yield self.store.update_device( + await self.store.update_device( user_id, device_id, new_display_name=new_display_name ) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: if e.code == 404: raise errors.NotFoundError() @@ -443,12 +417,15 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") - @defer.inlineCallbacks - def notify_device_update(self, user_id, device_ids): + async def notify_device_update(self, user_id, device_ids): """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. """ - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + if not device_ids: + # No changes to notify about, so this is a no-op. + return + + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -459,20 +436,24 @@ class DeviceHandler(DeviceWorkerHandler): set_tag("target_hosts", hosts) - position = yield self.store.add_device_change_to_streams( + position = await self.store.add_device_change_to_streams( user_id, device_ids, list(hosts) ) + if not position: + # This should only happen if there are no updates, so we bail. + return + for device_id in device_ids: logger.debug( "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - yield self.notifier.on_new_event( + self.notifier.on_new_event( "device_list_key", position, users=[user_id], rooms=room_ids ) @@ -484,29 +465,29 @@ class DeviceHandler(DeviceWorkerHandler): self.federation_sender.send_device_messages(host) log_kv({"message": "sent device update to host", "host": host}) - @defer.inlineCallbacks - def notify_user_signature_update(self, from_user_id, user_ids): + async def notify_user_signature_update( + self, from_user_id: str, user_ids: List[str] + ) -> None: """Notify a user that they have made new signatures of other users. Args: - from_user_id (str): the user who made the signature - user_ids (list[str]): the users IDs that have new signatures + from_user_id: the user who made the signature + user_ids: the users IDs that have new signatures """ - position = yield self.store.add_user_signature_change_to_streams( + position = await self.store.add_user_signature_change_to_streams( from_user_id, user_ids ) self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) - @defer.inlineCallbacks - def user_left_room(self, user, room_id): + async def user_left_room(self, user, room_id): user_id = user.to_string() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We no longer share rooms with this user, so we'll no longer # receive device updates. Mark this in DB. 
- yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) def _update_device_from_client_ips(device, client_ips): @@ -549,8 +530,7 @@ class DeviceListUpdater(object): ) @trace - @defer.inlineCallbacks - def incoming_device_list_update(self, origin, edu_content): + async def incoming_device_list_update(self, origin, edu_content): """Called on incoming device list update from federation. Responsible for parsing the EDU and adding to pending updates list. """ @@ -583,7 +563,7 @@ class DeviceListUpdater(object): ) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -608,14 +588,13 @@ class DeviceListUpdater(object): (device_id, stream_id, prev_ids, edu_content) ) - yield self._handle_device_updates(user_id) + await self._handle_device_updates(user_id) @measure_func("_incoming_device_list_update") - @defer.inlineCallbacks - def _handle_device_updates(self, user_id): + async def _handle_device_updates(self, user_id): "Actually handle pending updates." - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -632,7 +611,7 @@ class DeviceListUpdater(object): # Given a list of updates we check if we need to resync. This # happens if we've missed updates. - resync = yield self._need_to_do_resync(user_id, pending_updates) + resync = await self._need_to_do_resync(user_id, pending_updates) if logger.isEnabledFor(logging.INFO): logger.info( @@ -643,16 +622,16 @@ class DeviceListUpdater(object): ) if resync: - yield self.user_device_resync(user_id) + await self.user_device_resync(user_id) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: - yield self.store.update_remote_device_list_cache_entry( + await self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id ) - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user_id, [device_id for device_id, _, _, _ in pending_updates] ) @@ -660,14 +639,13 @@ class DeviceListUpdater(object): stream_id for _, stream_id, _, _ in pending_updates ) - @defer.inlineCallbacks - def _need_to_do_resync(self, user_id, updates): + async def _need_to_do_resync(self, user_id, updates): """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) + extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) logger.debug("Current extremity for %r: %r", user_id, extremity) @@ -692,8 +670,7 @@ class DeviceListUpdater(object): return False @trace - @defer.inlineCallbacks - def _maybe_retry_device_resync(self): + async def _maybe_retry_device_resync(self): """Retry to resync device lists that are out of sync, except if another retry is in progress. """ @@ -705,12 +682,12 @@ class DeviceListUpdater(object): # we don't send too many requests. 
        self._resync_retry_in_progress = True
        # Get all of the users that need resyncing.
-       need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
+       need_resync = await self.store.get_user_ids_requiring_device_list_resync()
        # Iterate over the set of user IDs.
        for user_id in need_resync:
            try:
                # Try to resync the current user's devices list.
-               result = yield self.user_device_resync(
+               result = await self.user_device_resync(
                    user_id=user_id, mark_failed_as_stale=False,
                )
@@ -734,16 +711,17 @@
            # Allow future calls to retry resyncing out of sync device lists.
            self._resync_retry_in_progress = False

-   @defer.inlineCallbacks
-   def user_device_resync(self, user_id, mark_failed_as_stale=True):
+   async def user_device_resync(
+       self, user_id: str, mark_failed_as_stale: bool = True
+   ) -> Optional[dict]:
        """Fetches all devices for a user and updates the device cache with them.

        Args:
-           user_id (str): The user's id whose device_list will be updated.
-           mark_failed_as_stale (bool): Whether to mark the user's device list as stale
+           user_id: The user's id whose device_list will be updated.
+           mark_failed_as_stale: Whether to mark the user's device list as stale
                if the attempt to resync failed.
        Returns:
-           Deferred[dict]: a dict with device info as under the "devices" in the result of this
+           A dict with device info as under the "devices" in the result of this
            request:
            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
        """
@@ -752,12 +730,12 @@
        # Fetch all devices for the user.
        origin = get_domain_from_id(user_id)
        try:
-           result = yield self.federation.query_user_devices(origin, user_id)
+           result = await self.federation.query_user_devices(origin, user_id)
        except NotRetryingDestination:
            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
-               yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+               await self.store.mark_remote_user_device_cache_as_stale(user_id)

            return
        except (RequestSendFailed, HttpResponseException) as e:
@@ -768,7 +746,7 @@
            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
-               yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+               await self.store.mark_remote_user_device_cache_as_stale(user_id)

            # We abort on exceptions rather than accepting the update
            # as otherwise synapse will 'forget' that its device list
@@ -792,7 +770,7 @@
            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
-               yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+               await self.store.mark_remote_user_device_cache_as_stale(user_id)

            return
        log_kv({"result": result})
@@ -833,25 +811,24 @@
            stream_id,
        )

-       yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+       await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
        device_ids = [device["device_id"] for device in devices]

        # Handle cross-signing keys.
- cross_signing_device_ids = yield self.process_cross_signing_key_update( + cross_signing_device_ids = await self.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + cross_signing_device_ids - yield self.device_handler.notify_device_update(user_id, device_ids) + await self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given # point. self._seen_updates[user_id] = {stream_id} - defer.returnValue(result) + return result - @defer.inlineCallbacks - def process_cross_signing_key_update( + async def process_cross_signing_key_update( self, user_id: str, master_key: Optional[Dict[str, Any]], @@ -872,14 +849,14 @@ class DeviceListUpdater(object): device_ids = [] if master_key: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) _, verify_key = get_verify_key_from_cross_signing_key(master_key) # verify_key is a VerifyKey from signedjson, which uses # .version to denote the portion of the key ID after the # algorithm and colon, which is the device ID device_ids.append(verify_key.version) if self_signing_key: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index a7e60cbc26..361dd64cd2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
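Even after the conversion, the per-destination fan-out below keeps its Twisted plumbing: each query coroutine is launched with `run_in_background` (giving it its own log context), and the whole batch is awaited via `make_deferred_yieldable(defer.gatherResults(...))`, which restores the caller's context on resumption. A minimal sketch of the pattern; `consumeErrors=True` is an assumption, since the hunks below crop the `gatherResults` call short:

    from twisted.internet import defer

    from synapse.logging.context import make_deferred_yieldable, run_in_background

    async def fan_out(destinations, query_one):
        # Launch one background task per destination, then wait for all
        # of them without leaking the current log context.
        await make_deferred_yieldable(
            defer.gatherResults(
                [run_in_background(query_one, d) for d in destinations],
                consumeErrors=True,  # assumed; not visible in the hunk
            )
        )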
@@ -77,8 +77,7 @@ class E2eKeysHandler(object): ) @trace - @defer.inlineCallbacks - def query_devices(self, query_body, timeout, from_user_id): + async def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -124,7 +123,7 @@ class E2eKeysHandler(object): failures = {} results = {} if local_query: - local_result = yield self.query_local_devices(local_query) + local_result = await self.query_local_devices(local_query) for user_id, keys in local_result.items(): if user_id in local_query: results[user_id] = keys @@ -142,7 +141,7 @@ class E2eKeysHandler(object): ( user_ids_not_in_cache, remote_results, - ) = yield self.store.get_user_devices_from_cache(query_list) + ) = await self.store.get_user_devices_from_cache(query_list) for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) for device_id, device in devices.items(): @@ -161,14 +160,13 @@ class E2eKeysHandler(object): r[user_id] = remote_queries[user_id] # Get cached cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, from_user_id ) # Now fetch any devices that we don't have in our cache @trace - @defer.inlineCallbacks - def do_remote_query(destination): + async def do_remote_query(destination): """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for @@ -192,7 +190,7 @@ class E2eKeysHandler(object): if device_list: continue - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: continue @@ -201,11 +199,11 @@ class E2eKeysHandler(object): # done an initial sync on the device list so we do it now. 
try: if self._is_master: - user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_devices = await self.device_handler.device_list_updater.user_device_resync( user_id ) else: - user_devices = yield self._user_device_resync_client( + user_devices = await self._user_device_resync_client( user_id=user_id ) @@ -227,7 +225,7 @@ class E2eKeysHandler(object): destination_query.pop(user_id) try: - remote_result = yield self.federation.query_client_keys( + remote_result = await self.federation.query_client_keys( destination, {"device_keys": destination_query}, timeout=timeout ) @@ -251,7 +249,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(do_remote_query, destination) @@ -267,8 +265,7 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks - def get_cross_signing_keys_from_cache(self, query, from_user_id): + async def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database Args: @@ -289,7 +286,7 @@ class E2eKeysHandler(object): user_ids = list(query) - keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) for user_id, user_info in keys.items(): if user_info is None: @@ -315,8 +312,7 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def query_local_devices(self, query): + async def query_local_devices(self, query): """Get E2E device keys for local users Args: @@ -354,7 +350,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = yield self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -364,16 +360,15 @@ class E2eKeysHandler(object): log_kv(results) return result_dict - @defer.inlineCallbacks - def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys(self, query_body): """ Handle a device key query from a federated server """ device_keys_query = query_body.get("device_keys", {}) - res = yield self.query_local_devices(device_keys_query) + res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} # add in the cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) @@ -382,8 +377,7 @@ class E2eKeysHandler(object): return ret @trace - @defer.inlineCallbacks - def claim_one_time_keys(self, query, timeout): + async def claim_one_time_keys(self, query, timeout): local_query = [] remote_queries = {} @@ -399,7 +393,7 @@ class E2eKeysHandler(object): set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) - results = yield self.store.claim_e2e_one_time_keys(local_query) + results = await self.store.claim_e2e_one_time_keys(local_query) json_result = {} failures = {} @@ -411,12 +405,11 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def claim_client_keys(destination): + async def claim_client_keys(destination): set_tag("destination", destination) device_keys = remote_queries[destination] try: - remote_result = yield self.federation.claim_client_keys( + remote_result = await 
self.federation.claim_client_keys( destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): @@ -429,7 +422,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(claim_client_keys, destination) @@ -454,9 +447,8 @@ class E2eKeysHandler(object): log_kv({"one_time_keys": json_result, "failures": failures}) return {"one_time_keys": json_result, "failures": failures} - @defer.inlineCallbacks @tag_args - def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user(self, user_id, device_id, keys): time_now = self.clock.time_msec() @@ -477,12 +469,12 @@ class E2eKeysHandler(object): } ) # TODO: Sign the JSON with the server key - changed = yield self.store.set_e2e_device_keys( + changed = await self.store.set_e2e_device_keys( user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed - yield self.device_handler.notify_device_update(user_id, [device_id]) + await self.device_handler.notify_device_update(user_id, [device_id]) else: log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) one_time_keys = keys.get("one_time_keys", None) @@ -494,7 +486,7 @@ class E2eKeysHandler(object): "device_id": device_id, } ) - yield self._upload_one_time_keys_for_user( + await self._upload_one_time_keys_for_user( user_id, device_id, time_now, one_time_keys ) else: @@ -507,15 +499,14 @@ class E2eKeysHandler(object): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. - yield self.device_handler.check_device_registered(user_id, device_id) + await self.device_handler.check_device_registered(user_id, device_id) - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + result = await self.store.count_e2e_one_time_keys(user_id, device_id) set_tag("one_time_key_counts", result) return {"one_time_key_counts": result} - @defer.inlineCallbacks - def _upload_one_time_keys_for_user( + async def _upload_one_time_keys_for_user( self, user_id, device_id, time_now, one_time_keys ): logger.info( @@ -533,7 +524,7 @@ class E2eKeysHandler(object): key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. 
- existing_key_map = yield self.store.get_e2e_one_time_keys( + existing_key_map = await self.store.get_e2e_one_time_keys( user_id, device_id, [k_id for _, k_id, _ in key_list] ) @@ -556,10 +547,9 @@ class E2eKeysHandler(object): ) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) - yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - @defer.inlineCallbacks - def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user(self, user_id, keys): """Upload signing keys for cross-signing Args: @@ -574,7 +564,7 @@ class E2eKeysHandler(object): _check_cross_signing_key(master_key, user_id, "master") else: - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") # if there is no master key, then we can't do anything, because all the # other cross-signing keys need to be signed by the master key @@ -613,10 +603,10 @@ class E2eKeysHandler(object): # if everything checks out, then store the keys and send notifications deviceids = [] if "master_key" in keys: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) try: @@ -626,23 +616,22 @@ class E2eKeysHandler(object): except ValueError: raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "user_signing", user_signing_key ) # the signature stream matches the semantics that we want for # user-signing key updates: only the user themselves is notified of # their own user-signing key updates - yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + await self.device_handler.notify_user_signature_update(user_id, [user_id]) # master key and self-signing key updates match the semantics of device # list updates: all users who share an encrypted room are notified if len(deviceids): - yield self.device_handler.notify_device_update(user_id, deviceids) + await self.device_handler.notify_device_update(user_id, deviceids) return {} - @defer.inlineCallbacks - def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys(self, user_id, signatures): """Upload device signatures for cross-signing Args: @@ -667,13 +656,13 @@ class E2eKeysHandler(object): self_signatures = signatures.get(user_id, {}) other_signatures = {k: v for k, v in signatures.items() if k != user_id} - self_signature_list, self_failures = yield self._process_self_signatures( + self_signature_list, self_failures = await self._process_self_signatures( user_id, self_signatures ) signature_list.extend(self_signature_list) failures.update(self_failures) - other_signature_list, other_failures = yield self._process_other_signatures( + other_signature_list, other_failures = await self._process_other_signatures( user_id, other_signatures ) signature_list.extend(other_signature_list) @@ -681,21 +670,20 @@ class E2eKeysHandler(object): # store the signature, and send the appropriate notifications for sync logger.debug("upload signature 
failures: %r", failures) - yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) + await self.store.store_e2e_cross_signing_signatures(user_id, signature_list) self_device_ids = [item.target_device_id for item in self_signature_list] if self_device_ids: - yield self.device_handler.notify_device_update(user_id, self_device_ids) + await self.device_handler.notify_device_update(user_id, self_device_ids) signed_users = [item.target_user_id for item in other_signature_list] if signed_users: - yield self.device_handler.notify_user_signature_update( + await self.device_handler.notify_user_signature_update( user_id, signed_users ) return {"failures": failures} - @defer.inlineCallbacks - def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures(self, user_id, signatures): """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -728,7 +716,7 @@ class E2eKeysHandler(object): _, self_signing_key_id, self_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so @@ -738,12 +726,12 @@ class E2eKeysHandler(object): master_key, _, master_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") @@ -853,8 +841,7 @@ class E2eKeysHandler(object): return master_key_signature_list - @defer.inlineCallbacks - def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures(self, user_id, signatures): """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. @@ -882,7 +869,7 @@ class E2eKeysHandler(object): user_signing_key, user_signing_key_id, user_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -905,7 +892,7 @@ class E2eKeysHandler(object): master_key, master_key_id, _, - ) = yield self._get_e2e_cross_signing_verify_key( + ) = await self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) @@ -958,8 +945,7 @@ class E2eKeysHandler(object): return signature_list, failures - @defer.inlineCallbacks - def _get_e2e_cross_signing_verify_key( + async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None ): """Fetch locally or remotely query for a cross-signing public key. 
@@ -983,7 +969,7 @@ class E2eKeysHandler(object): SynapseError: if `user_id` is invalid """ user = UserID.from_string(user_id) - key = yield self.store.get_e2e_cross_signing_key( + key = await self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id ) @@ -1009,15 +995,14 @@ class E2eKeysHandler(object): key, key_id, verify_key, - ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) + ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) if key is None: raise NotFoundError("No %s key found for %s" % (key_type, user_id)) return key, key_id, verify_key - @defer.inlineCallbacks - def _retrieve_cross_signing_keys_for_remote_user( + async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, ): """Queries cross-signing keys for a remote user and saves them to the database @@ -1035,7 +1020,7 @@ class E2eKeysHandler(object): If the key cannot be retrieved, all values in the tuple will instead be None. """ try: - remote_result = yield self.federation.query_user_devices( + remote_result = await self.federation.query_user_devices( user.domain, user.to_string() ) except Exception as e: @@ -1101,14 +1086,14 @@ class E2eKeysHandler(object): desired_key_id = key_id # At the same time, store this key in the db for subsequent queries - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user.to_string(), key_type, key_content ) # Notify clients that new devices for this user have been discovered if retrieved_device_ids: # XXX is this necessary? - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user.to_string(), retrieved_device_ids ) @@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object): iterable=True, ) - @defer.inlineCallbacks - def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update(self, origin, edu_content): """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. @@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object): logger.warning("Got signing key update edu for %r from %r", user_id, origin) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object): (master_key, self_signing_key) ) - yield self._handle_signing_key_updates(user_id) + await self._handle_signing_key_updates(user_id) - @defer.inlineCallbacks - def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id): """Actually handle pending updates. 
Args: @@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object): device_handler = self.e2e_keys_handler.device_handler device_list_updater = device_handler.device_list_updater - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = yield device_list_updater.process_cross_signing_key_update( + new_device_ids = await device_list_updater.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + new_device_ids - yield device_handler.notify_device_update(user_id, device_ids) + await device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f55470a707..0bb983dc28 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py
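Every method in this file serialises its work through `self._upload_linearizer`, and the conversion below only changes how the lock is taken: `with (yield ...)` becomes `with (await ...)`. `queue(user_id)` resolves to a context manager holding a per-user lock, which is what keeps the read-modify-write of backup versions and etags atomic. The shape, as a minimal sketch (the class name is hypothetical):

    from synapse.util.async_helpers import Linearizer

    class RoomKeysSketch:
        def __init__(self, hs):
            self.store = hs.get_datastore()
            # Calls queued under the same user_id run strictly one at a
            # time; different users proceed in parallel.
            self._upload_linearizer = Linearizer("upload_room_keys_lock")

        async def delete_version(self, user_id, version=None):
            with (await self._upload_linearizer.queue(user_id)):
                await self.store.delete_e2e_room_keys_version(user_id, version)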
@@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( Codes, NotFoundError, @@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object): self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - @defer.inlineCallbacks - def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. @@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object): # we deliberately take the lock to get keys so that changing the version # works atomically - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - yield self.store.get_e2e_room_keys_version_info(user_id, version) + await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - results = yield self.store.get_e2e_room_keys( + results = await self.store.get_e2e_room_keys( user_id, version, room_id, session_id ) @@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object): return results @trace - @defer.inlineCallbacks - def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. @@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object): """ # lock for consistency with uploading - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object): else: raise - yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) version_etag = version_info["etag"] + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @trace - @defer.inlineCallbacks - def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys(self, user_id, version, room_keys): """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). @@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? 
- with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # Check that the version we're trying to upload is the current version try: - version_info = yield self.store.get_e2e_room_keys_version_info(user_id) + version_info = await self.store.get_e2e_room_keys_version_info(user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object): if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) # if we get this far, the version must exist @@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object): # submitted. Then compare them with the submitted keys. If the # key is new, insert it; if the key should be updated, then update # it; otherwise, drop it. - existing_keys = yield self.store.get_e2e_room_keys_multi( + existing_keys = await self.store.get_e2e_room_keys_multi( user_id, version, room_keys["rooms"] ) to_insert = [] # batch the inserts together @@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object): # updates are done one at a time in the DB, so send # updates right away rather than batching them up, # like we do with the inserts - yield self.store.update_e2e_room_key( + await self.store.update_e2e_room_key( user_id, version, room_id, session_id, room_key ) changed = True @@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object): changed = True if len(to_insert): - yield self.store.add_e2e_room_keys(user_id, version, to_insert) + await self.store.add_e2e_room_keys(user_id, version, to_insert) version_etag = version_info["etag"] if changed: version_etag = version_etag + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @staticmethod @@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object): return True @trace - @defer.inlineCallbacks - def create_version(self, user_id, version_info): + async def create_version(self, user_id, version_info): """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. @@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. 
# lock everyone out until we've switched version - with (yield self._upload_linearizer.queue(user_id)): - new_version = yield self.store.create_e2e_room_keys_version( + with (await self._upload_linearizer.queue(user_id)): + new_version = await self.store.create_e2e_room_keys_version( user_id, version_info ) return new_version - @defer.inlineCallbacks - def get_version_info(self, user_id, version=None): + async def get_version_info(self, user_id, version=None): """Get the info about a given version of the user's backup Args: @@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object): } """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - res = yield self.store.get_e2e_room_keys_version_info(user_id, version) + res = await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) + res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"]) res["etag"] = str(res["etag"]) return res @trace - @defer.inlineCallbacks - def delete_version(self, user_id, version=None): + async def delete_version(self, user_id, version=None): """Deletes a given version of the user's e2e_room_keys backup Args: @@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object): NotFoundError: if this backup version doesn't exist """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - yield self.store.delete_e2e_room_keys_version(user_id, version) + await self.store.delete_e2e_room_keys_version(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") @@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object): raise @trace - @defer.inlineCallbacks - def update_version(self, user_id, version, version_info): + async def update_version(self, user_id, version, version_info): """Update the info about a given version of the user's backup Args: @@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object): raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - old_info = yield self.store.get_e2e_room_keys_version_info( + old_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object): if old_info["algorithm"] != version_info["algorithm"]: raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, version_info ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ca7da42a3f..71ac5dca99 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
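The first hunk below is a forward-compatibility fix rather than part of the async migration: the ABC aliases re-exported from the top-level `collections` module were deprecated in Python 3.3 and are removed in 3.10, so `Container` has to come from `collections.abc`. Also notable: the device resync triggered by an incoming PDU now runs under `run_as_background_process`, which gives it metrics and its own log context instead of tying it to the request. The import fix in isolation:

    # Deprecated alias, removed in Python 3.10:
    #   from collections import Container

    # The abstract base classes live in collections.abc:
    from collections.abc import Container

    assert isinstance([1, 2, 3], Container)  # lists implement __contains__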
@@ -19,7 +19,7 @@
 import itertools
 import logging
-from collections import Container
+from collections.abc import Container
 from http import HTTPStatus
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -44,6 +44,7 @@ from synapse.api.errors import (
     FederationDeniedError,
     FederationError,
     HttpResponseException,
+    NotFoundError,
     RequestSendFailed,
     SynapseError,
 )
@@ -61,6 +62,7 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.logging.utils import log_function
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
@@ -618,6 +620,11 @@ class FederationHandler(BaseHandler):
            will be omitted from the result. Likewise, any events which turn out
            not to be in the given room.

+       This function *does not* automatically get missing auth events of the
+       newly fetched events. Callers must include the full auth chain
+       of the missing events in the `event_ids` argument, to ensure that any
+       missing auth events are correctly fetched.
+
        Returns:
            map from event_id to event
        """
@@ -784,15 +791,25 @@
                    resync = True

            if resync:
-               await self.store.mark_remote_user_device_cache_as_stale(event.sender)
+               run_as_background_process(
+                   "resync_device_due_to_pdu", self._resync_device, event.sender
+               )

-               # Immediately attempt a resync in the background
-               if self.config.worker_app:
-                   return run_in_background(self._user_device_resync, event.sender)
-               else:
-                   return run_in_background(
-                       self._device_list_updater.user_device_resync, event.sender
-                   )
+   async def _resync_device(self, sender: str) -> None:
+       """We have detected that the device list for the given user may be out
+       of sync, so we try and resync them.
+       """
+
+       try:
+           await self.store.mark_remote_user_device_cache_as_stale(sender)
+
+           # Immediately attempt a resync in the background
+           if self.config.worker_app:
+               await self._user_device_resync(user_id=sender)
+           else:
+               await self._device_list_updater.user_device_resync(sender)
+       except Exception:
+           logger.exception("Failed to resync device for %s", sender)

    @log_function
    async def backfill(self, dest, room_id, limit, extremities):
@@ -1131,12 +1148,16 @@
    ):
        """Fetch the given events from a server, and persist them as outliers.

+       This function *does not* recursively get missing auth events of the
+       newly fetched events. Callers must include in the `events` argument
+       any missing events from the auth chain.
+
        Logs a warning if we can't find the given event.
""" room_version = await self.store.get_room_version(room_id) - event_infos = [] + event_map = {} # type: Dict[str, EventBase] async def get_event(event_id: str): with nested_logging_context(event_id): @@ -1150,17 +1171,7 @@ class FederationHandler(BaseHandler): ) return - # recursively fetch the auth events for this event - auth_events = await self._get_events_from_store_or_dest( - destination, room_id, event.auth_event_ids() - ) - auth = {} - for auth_event_id in event.auth_event_ids(): - ae = auth_events.get(auth_event_id) - if ae: - auth[(ae.type, ae.state_key)] = ae - - event_infos.append(_NewEventInfo(event, None, auth)) + event_map[event.event_id] = event except Exception as e: logger.warning( @@ -1172,6 +1183,32 @@ class FederationHandler(BaseHandler): await concurrently_execute(get_event, events, 5) + # Make a map of auth events for each event. We do this after fetching + # all the events as some of the events' auth events will be in the list + # of requested events. + + auth_events = [ + aid + for event in event_map.values() + for aid in event.auth_event_ids() + if aid not in event_map + ] + persisted_events = await self.store.get_events( + auth_events, allow_rejected=True, + ) + + event_infos = [] + for event in event_map.values(): + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + else: + logger.info("Missing auth event %s", auth_event_id) + + event_infos.append(_NewEventInfo(event, None, auth)) + await self._handle_new_events( destination, event_infos, ) @@ -1403,10 +1440,20 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - event_content = {"membership": Membership.JOIN} - + # checking the room version will check that we've actually heard of the room + # (and return a 404 otherwise) room_version = await self.store.get_room_version_id(room_id) + # now check that we are *still* in the room + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + if not is_in_room: + logger.info( + "Got /make_join request for room %s we are no longer in", room_id, + ) + raise NotFoundError("Not an active room on this server") + + event_content = {"membership": Membership.JOIN} + builder = self.event_builder_factory.new( room_version, { diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index b137f806d5..4b9b80a36d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
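The message.py hunks below apply the commit's core transformation wholesale: an `@defer.inlineCallbacks` generator that `yield`s Deferreds becomes a native coroutine that `await`s them, usually picking up type hints along the way. Schematically (with a stand-in `state_handler` argument rather than Synapse's real handler attributes):

```python
from typing import Optional

from twisted.internet import defer


@defer.inlineCallbacks
def get_room_data_old(state_handler, room_id, event_type=None, state_key=""):
    # Twisted style: a generator decorated with inlineCallbacks, where
    # each yield waits on a Deferred.
    data = yield state_handler.get_current_state(room_id, event_type, state_key)
    return data


async def get_room_data_new(
    state_handler, room_id: str, event_type: Optional[str] = None, state_key: str = ""
) -> dict:
    # Native coroutine: identical control flow, but awaited. Callers that
    # still need a Deferred can bridge with defer.ensureDeferred(...).
    return await state_handler.get_current_state(room_id, event_type, state_key)
```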
@@ -15,12 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from canonicaljson import encode_canonical_json, json -from twisted.internet import defer -from twisted.internet.defer import succeed from twisted.internet.interfaces import IDelayedCall from synapse import event_auth @@ -41,13 +39,22 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events import EventBase +from synapse.events.builder import EventBuilder +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import Collection, RoomAlias, UserID, create_requester +from synapse.types import ( + Collection, + Requester, + RoomAlias, + StreamToken, + UserID, + create_requester, +) from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -84,14 +91,22 @@ class MessageHandler(object): "_schedule_next_expiry", self._schedule_next_expiry ) - @defer.inlineCallbacks - def get_room_data( - self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False - ): + async def get_room_data( + self, + user_id: str = None, + room_id: str = None, + event_type: Optional[str] = None, + state_key: str = "", + is_guest: bool = False, + ) -> dict: """ Get data from a room. Args: - event : The room path event + user_id + room_id + event_type + state_key + is_guest Returns: The path data content. Raises: @@ -100,30 +115,29 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_user_in_room_or_world_readable( + ) = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: - data = yield self.state.get_current_state(room_id, event_type, state_key) + data = await self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) return data - @defer.inlineCallbacks - def get_state_events( + async def get_state_events( self, - user_id, - room_id, - state_filter=StateFilter.all(), - at_token=None, - is_guest=False, - ): + user_id: str, + room_id: str, + state_filter: StateFilter = StateFilter.all(), + at_token: Optional[StreamToken] = None, + is_guest: bool = False, + ) -> List[dict]: """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has left the room return the state events from when they left. If an explicit @@ -131,15 +145,14 @@ class MessageHandler(object): visible. Args: - user_id(str): The user requesting state events. - room_id(str): The room ID to get all state events from. 
- state_filter (StateFilter): The state filter used to fetch state - from the database. - at_token(StreamToken|None): the stream token of the at which we are requesting + user_id: The user requesting state events. + room_id: The room ID to get all state events from. + state_filter: The state filter used to fetch state from the database. + at_token: the stream token at which we are requesting the state. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current state based on the current_state_events table. - is_guest(bool): whether this user is a guest + is_guest: whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] Raises: @@ -153,20 +166,20 @@ class MessageHandler(object): # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=at_token.room_key, limit=1 ) if not last_events: raise NotFoundError("Can't find event for token %s" % (at_token,)) - visible_events = yield filter_events_for_client( + visible_events = await filter_events_for_client( self.storage, user_id, last_events, filter_send_to_client=False ) event = last_events[0] if visible_events: - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -180,23 +193,23 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_user_in_room_or_world_readable( + ) = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: - state_ids = yield self.store.get_filtered_current_state_ids( + state_ids = await self.store.get_filtered_current_state_ids( room_id, state_filter=state_filter ) - room_state = yield self.store.get_events(state_ids.values()) + room_state = await self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( room_state.values(), now, # We don't bother bundling aggregations in when asked for state @@ -205,15 +218,14 @@ class MessageHandler(object): ) return events - @defer.inlineCallbacks - def get_joined_members(self, requester, room_id): + async def get_joined_members(self, requester: Requester, room_id: str) -> dict: """Get all the joined members in the room and their profile information. If the user has left the room return the state events from when they left. Args: - requester(Requester): The user requesting state events. - room_id(str): The room ID to get all state events from. 
Returns: A dict of user_id to profile info """ @@ -221,7 +233,7 @@ class MessageHandler(object): if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. - membership, _ = yield self.auth.check_user_in_room_or_world_readable( + membership, _ = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership != Membership.JOIN: @@ -229,7 +241,7 @@ class MessageHandler(object): "Getting joined members after leaving is not implemented" ) - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there @@ -250,7 +262,7 @@ class MessageHandler(object): for user_id, profile in users_with_profile.items() } - def maybe_schedule_expiry(self, event): + def maybe_schedule_expiry(self, event: EventBase): """Schedule the expiry of an event if there's not already one scheduled, or if the one running is for an event that will expire after the provided timestamp. @@ -259,7 +271,7 @@ class MessageHandler(object): the master process, and therefore needs to be run on there. Args: - event (EventBase): The event to schedule the expiry of. + event: The event to schedule the expiry of. """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) @@ -270,8 +282,7 @@ class MessageHandler(object): # a task scheduled for a timestamp that's sooner than the provided one. self._schedule_expiry_for_event(event.event_id, expiry_ts) - @defer.inlineCallbacks - def _schedule_next_expiry(self): + async def _schedule_next_expiry(self): """Retrieve the ID and the expiry timestamp of the next event to be expired, and schedule an expiry task for it. @@ -279,18 +290,18 @@ class MessageHandler(object): future call to save_expiry_ts can schedule a new expiry task. """ # Try to get the expiry timestamp of the next event to expire. - res = yield self.store.get_next_event_to_expire() + res = await self.store.get_next_event_to_expire() if res: event_id, expiry_ts = res self._schedule_expiry_for_event(event_id, expiry_ts) - def _schedule_expiry_for_event(self, event_id, expiry_ts): + def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int): """Schedule an expiry task for the provided event if there's not already one scheduled at a timestamp that's sooner than the provided one. Args: - event_id (str): The ID of the event to expire. - expiry_ts (int): The timestamp at which to expire the event. + event_id: The ID of the event to expire. + expiry_ts: The timestamp at which to expire the event. """ if self._scheduled_expiry: # If the provided timestamp refers to a time before the scheduled time of the @@ -320,8 +331,7 @@ class MessageHandler(object): event_id, ) - @defer.inlineCallbacks - def _expire_event(self, event_id): + async def _expire_event(self, event_id: str): """Retrieve and expire an event that needs to be expired from the database. If the event doesn't exist in the database, log it and delete the expiry date @@ -336,12 +346,12 @@ class MessageHandler(object): try: # Expire the event if we know about it. This function also deletes the expiry # date from the database in the same database transaction. 
- yield self.store.expire_event(event_id) + await self.store.expire_event(event_id) except Exception as e: logger.error("Could not expire event %s: %r", event_id, e) # Schedule the expiry of the next event to expire. - yield self._schedule_next_expiry() + await self._schedule_next_expiry() # The duration (in ms) after which rooms should be removed @@ -425,16 +435,15 @@ class EventCreationHandler(object): self._dummy_events_threshold = hs.config.dummy_events_threshold - @defer.inlineCallbacks - def create_event( + async def create_event( self, - requester, - event_dict, - token_id=None, - txn_id=None, + requester: Requester, + event_dict: dict, + token_id: Optional[str] = None, + txn_id: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, - require_consent=True, - ): + require_consent: bool = True, + ) -> Tuple[EventBase, EventContext]: """ Given a dict from a client, create a new event. @@ -445,31 +454,29 @@ class EventCreationHandler(object): Args: requester - event_dict (dict): An entire event - token_id (str) - txn_id (str) - + event_dict: An entire event + token_id + txn_id prev_event_ids: the forward extremities to use as the prev_events for the new event. If None, they will be requested from the database. - - require_consent (bool): Whether to check if the requester has - consented to privacy policy. + require_consent: Whether to check if the requester has + consented to the privacy policy. Raises: ResourceLimitError if server is blocked to some resource being exceeded Returns: - Tuple of created event (FrozenEvent), Context + Tuple of created event, Context """ - yield self.auth.check_auth_blocking(requester.user.to_string()) + await self.auth.check_auth_blocking(requester.user.to_string()) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version = event_dict["content"]["room_version"] else: try: - room_version = yield self.store.get_room_version_id( + room_version = await self.store.get_room_version_id( event_dict["room_id"] ) except NotFoundError: @@ -490,11 +497,11 @@ class EventCreationHandler(object): try: if "displayname" not in content: - displayname = yield profile.get_displayname(target) + displayname = await profile.get_displayname(target) if displayname is not None: content["displayname"] = displayname if "avatar_url" not in content: - avatar_url = yield profile.get_avatar_url(target) + avatar_url = await profile.get_avatar_url(target) if avatar_url is not None: content["avatar_url"] = avatar_url except Exception as e: @@ -502,9 +509,9 @@ class EventCreationHandler(object): "Failed to get profile information for %r: %s", target, e ) - is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) + is_exempt = await self._is_exempt_from_privacy_policy(builder, requester) if require_consent and not is_exempt: - yield self.assert_accepted_privacy_policy(requester) + await self.assert_accepted_privacy_policy(requester) if token_id is not None: builder.internal_metadata.token_id = token_id @@ -512,7 +519,7 @@ class EventCreationHandler(object): if txn_id is not None: builder.internal_metadata.txn_id = txn_id - event, context = yield self.create_new_client_event( + event, context = await self.create_new_client_event( builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) @@ -528,10 +535,10 @@ class EventCreationHandler(object): # federation as well as those created locally. 
As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( - yield self.store.get_event(prev_event_id, allow_none=True) + await self.store.get_event(prev_event_id, allow_none=True) if prev_event_id else None ) @@ -554,37 +561,36 @@ class EventCreationHandler(object): return (event, context) - def _is_exempt_from_privacy_policy(self, builder, requester): + async def _is_exempt_from_privacy_policy( + self, builder: EventBuilder, requester: Requester + ) -> bool: """Determine if an event to be sent is exempt from having to consent to the privacy policy Args: - builder (synapse.events.builder.EventBuilder): event being created - requester (Requster): user requesting this event + builder: event being created + requester: user requesting this event Returns: - Deferred[bool]: true if the event can be sent without the user - consenting + true if the event can be sent without the user consenting """ # the only thing the user can do is join the server notices room. if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - return self._is_server_notices_room(builder.room_id) + return await self._is_server_notices_room(builder.room_id) elif membership == Membership.LEAVE: # the user is always allowed to leave (but not kick people) return builder.state_key == requester.user.to_string() - return succeed(False) + return False - @defer.inlineCallbacks - def _is_server_notices_room(self, room_id): + async def _is_server_notices_room(self, room_id: str) -> bool: if self.config.server_notices_mxid is None: return False - user_ids = yield self.store.get_users_in_room(room_id) + user_ids = await self.store.get_users_in_room(room_id) return self.config.server_notices_mxid in user_ids - @defer.inlineCallbacks - def assert_accepted_privacy_policy(self, requester): + async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy Called when the given user is about to do something that requires @@ -593,12 +599,10 @@ class EventCreationHandler(object): raised. Args: - requester (synapse.types.Requester): - The user making the request + requester: The user making the request Returns: - Deferred[None]: returns normally if the user has consented or is - exempt + Returns normally if the user has consented or is exempt Raises: ConsentNotGivenError: if the user has not given consent yet @@ -619,7 +623,7 @@ class EventCreationHandler(object): ): return - u = yield self.store.get_user_by_id(user_id) + u = await self.store.get_user_by_id(user_id) assert u is not None if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): # support and bot users are not required to consent @@ -637,16 +641,20 @@ class EventCreationHandler(object): raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) async def send_nonmember_event( - self, requester, event, context, ratelimit=True + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, ) -> int: """ Persists and notifies local clients and federation of an event. Args: - event (FrozenEvent) the event to send. - context (Context) the context of the event. - ratelimit (bool): Whether to rate limit this send. 
- is_guest (bool): Whether the sender is a guest + requester + event: the event to send. + context: the context of the event. + ratelimit: Whether to rate limit this send. Return: The stream_id of the persisted event. @@ -674,19 +682,20 @@ class EventCreationHandler(object): requester=requester, event=event, context=context, ratelimit=ratelimit ) - @defer.inlineCallbacks - def deduplicate_state_event(self, event, context): + async def deduplicate_state_event( + self, event: EventBase, context: EventContext + ) -> None: """ Checks whether event is in the latest resolved state in context. If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -698,7 +707,11 @@ class EventCreationHandler(object): return async def create_and_send_nonmember_event( - self, requester, event_dict, ratelimit=True, txn_id=None + self, + requester: Requester, + event_dict: dict, + ratelimit: bool = True, + txn_id: Optional[str] = None, ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -728,17 +741,17 @@ class EventCreationHandler(object): return event, stream_id @measure_func("create_new_client_event") - @defer.inlineCallbacks - def create_new_client_event( - self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None - ): + async def create_new_client_event( + self, + builder: EventBuilder, + requester: Optional[Requester] = None, + prev_event_ids: Optional[Collection[str]] = None, + ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client Args: - builder (EventBuilder): - - requester (synapse.types.Requester|None): - + builder: + requester: prev_event_ids: the forward extremities to use as the prev_events for the new event. If None, they will be requested from the database. 
Returns: - Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] + Tuple of created event, context """ if prev_event_ids is not None: @@ -755,10 +768,10 @@ class EventCreationHandler(object): % (len(prev_event_ids),) ) else: - prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id) + prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) - event = yield builder.build(prev_event_ids=prev_event_ids) - context = yield self.state.compute_event_context(event) + event = await builder.build(prev_event_ids=prev_event_ids) + context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service @@ -772,7 +785,7 @@ class EventCreationHandler(object): relates_to = relation["event_id"] aggregation_key = relation["key"] - already_exists = yield self.store.has_user_annotated_event( + already_exists = await self.store.has_user_annotated_event( relates_to, event.type, aggregation_key, event.sender ) if already_exists: @@ -784,7 +797,12 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: List[UserID] = [], ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -793,11 +811,11 @@ class EventCreationHandler(object): processing. Args: - requester (Requester) - event (FrozenEvent) - context (EventContext) - ratelimit (bool) - extra_users (list(UserID)): Any extra users to notify about event + requester + event + context + ratelimit + extra_users: Any extra users to notify about event Return: The stream_id of the persisted event. @@ -876,10 +894,9 @@ class EventCreationHandler(object): self.store.remove_push_actions_from_staging, event.event_id ) - @defer.inlineCallbacks - def _validate_canonical_alias( - self, directory_handler, room_alias_str, expected_room_id - ): + async def _validate_canonical_alias( + self, directory_handler, room_alias_str: str, expected_room_id: str + ) -> None: """ Ensure that the given room alias points to the expected room ID. @@ -890,9 +907,7 @@ class EventCreationHandler(object): """ room_alias = RoomAlias.from_string(room_alias_str) try: - mapping = yield defer.ensureDeferred( - directory_handler.get_association(room_alias) - ) + mapping = await directory_handler.get_association(room_alias) except SynapseError as e: # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. if e.errcode == Codes.NOT_FOUND: @@ -911,7 +926,12 @@ class EventCreationHandler(object): ) async def persist_and_notify_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: List[UserID] = [], ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1104,7 +1124,7 @@ class EventCreationHandler(object): return event_stream_id - async def _bump_active_time(self, user): + async def _bump_active_time(self, user: UserID) -> None: try: presence = self.hs.get_presence_handler() await presence.bump_presence_active_time(user) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d2f25ae12a..8e99c83d9d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
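The presence.py hunks below move type information out of docstrings and into the signatures. Note the two annotation styles in play: signatures get inline annotations, while local variables keep `# type:` comments, presumably because Synapse at the time still supported Python versions without variable-annotation syntax. A small, self-contained example of the same style:

```python
from typing import Dict, List


def bucket_users_by_domain(user_ids: List[str]) -> Dict[str, List[str]]:
    # A type comment rather than a variable annotation, matching the
    # style used in the hunks below:
    buckets = {}  # type: Dict[str, List[str]]
    for user_id in user_ids:
        # Matrix user IDs have the form "@localpart:domain".
        buckets.setdefault(user_id.split(":", 1)[1], []).append(user_id)
    return buckets
```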
@@ -30,8 +30,6 @@ from typing import Dict, Iterable, List, Set, Tuple from prometheus_client import Counter from typing_extensions import ContextManager -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError @@ -39,6 +37,8 @@ from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateHandler +from synapse.storage.data_stores.main import DataStore from synapse.storage.presence import UserPresenceState from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer @@ -895,16 +895,9 @@ class PresenceHandler(BasePresenceHandler): await self._on_user_joined_room(room_id, state_key) - async def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: """Called when we detect a user joining the room via the current state delta stream. - - Args: - room_id (str) - user_id (str) - - Returns: - Deferred """ if self.is_mine_id(user_id): @@ -1296,22 +1289,24 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): return new_state, persist_and_notify, federation_ping -@defer.inlineCallbacks -def get_interested_parties(store, states): +async def get_interested_parties( + store: DataStore, states: List[UserPresenceState] +) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: """Given a list of states return which entities (rooms, users) are interested in the given states. Args: - states (list(UserPresenceState)) + store + states Returns: - 2-tuple: `(room_ids_to_states, users_to_states)`, + A 2-tuple of `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: - room_ids = yield store.get_rooms_for_user(state.user_id) + room_ids = await store.get_rooms_for_user(state.user_id) for room_id in room_ids: room_ids_to_states.setdefault(room_id, []).append(state) @@ -1321,20 +1316,22 @@ def get_interested_parties(store, states): return room_ids_to_states, users_to_states -@defer.inlineCallbacks -def get_interested_remotes(store, states, state_handler): +async def get_interested_remotes( + store: DataStore, states: List[UserPresenceState], state_handler: StateHandler +) -> List[Tuple[List[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. All the presence states should be for local users only. Args: - store (DataStore) - states (list(UserPresenceState)) + store + states + state_handler Returns: - Deferred list of ([destinations], [UserPresenceState]), where for - each row the list of UserPresenceState should be sent to each + A list of 2-tuples of destinations and states, where for + each tuple the list of UserPresenceState should be sent to each destination """ hosts_and_states = [] @@ -1342,10 +1339,10 @@ def get_interested_remotes(store, states, state_handler): # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. 
- room_ids_to_states, users_to_states = yield get_interested_parties(store, states) + room_ids_to_states, users_to_states = await get_interested_parties(store, states) for room_id, states in room_ids_to_states.items(): - hosts = yield state_handler.get_current_hosts_in_room(room_id) + hosts = await state_handler.get_current_hosts_in_room(room_id) hosts_and_states.append((hosts, states)) for user_id, states in users_to_states.items(): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4b1e3073a8..31a2e5ea18 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
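The profile.py hunks below are a mechanical conversion, and every getter keeps the same two-armed shape: local users are read from our own profile table, remote users are queried over federation. Condensed, with toy `hs`, `store` and `federation` arguments standing in for the handler's attributes:

```python
async def get_displayname_sketch(hs, store, federation, target_user) -> str:
    if hs.is_mine(target_user):
        # Local user: read straight from our own profile table.
        return await store.get_profile_displayname(target_user.localpart)
    # Remote user: ask their homeserver via a federation profile query.
    result = await federation.make_query(
        destination=target_user.domain,
        query_type="profile",
        args={"user_id": target_user.to_string(), "field": "displayname"},
    )
    return result["displayname"]
```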
@@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( AuthError, Codes, @@ -54,16 +52,15 @@ class BaseProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() - @defer.inlineCallbacks - def get_profile(self, user_id): + async def get_profile(self, user_id): target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -74,7 +71,7 @@ class BaseProfileHandler(BaseHandler): return {"displayname": displayname, "avatar_url": avatar_url} else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": user_id}, @@ -86,8 +83,7 @@ class BaseProfileHandler(BaseHandler): except HttpResponseException as e: raise e.to_synapse_error() - @defer.inlineCallbacks - def get_profile_from_cache(self, user_id): + async def get_profile_from_cache(self, user_id): """Get the profile information from our local cache. If the user is ours then the profile information will always be correct. Otherwise, it may be out of date/missing. """ target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -108,14 +104,13 @@ class BaseProfileHandler(BaseHandler): return {"displayname": displayname, "avatar_url": avatar_url} else: - profile = yield self.store.get_from_remote_profile_cache(user_id) + profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - @defer.inlineCallbacks - def get_displayname(self, target_user): + async def get_displayname(self, target_user): if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) except StoreError as e: @@ -126,7 +121,7 @@ class BaseProfileHandler(BaseHandler): return displayname else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": target_user.to_string(), "field": "displayname"}, @@ -189,11 +184,10 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - @defer.inlineCallbacks - def get_avatar_url(self, target_user): + async def get_avatar_url(self, target_user): if self.hs.is_mine(target_user): try: - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -203,7 +197,7 @@ class BaseProfileHandler(BaseHandler): return avatar_url else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": target_user.to_string(), "field": "avatar_url"}, @@ -253,8 +247,7 @@ class 
BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - @defer.inlineCallbacks - def on_profile_query(self, args): + async def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -264,12 +257,12 @@ class BaseProfileHandler(BaseHandler): response = {} try: if just_field is None or just_field == "displayname": - response["displayname"] = yield self.store.get_profile_displayname( + response["displayname"] = await self.store.get_profile_displayname( user.localpart ) if just_field is None or just_field == "avatar_url": - response["avatar_url"] = yield self.store.get_profile_avatar_url( + response["avatar_url"] = await self.store.get_profile_avatar_url( user.localpart ) except StoreError as e: @@ -304,8 +297,7 @@ class BaseProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e) ) - @defer.inlineCallbacks - def check_profile_query_allowed(self, target_user, requester=None): + async def check_profile_query_allowed(self, target_user, requester=None): """Checks whether a profile query is allowed. If the 'require_auth_for_profile_requests' config flag is set to True and a 'requester' is provided, the query is only allowed if the two users @@ -337,8 +329,8 @@ class BaseProfileHandler(BaseHandler): return try: - requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) - target_user_rooms = yield self.store.get_rooms_for_user( + requester_rooms = await self.store.get_rooms_for_user(requester.to_string()) + target_user_rooms = await self.store.get_rooms_for_user( target_user.to_string() ) @@ -371,25 +363,24 @@ class MasterProfileHandler(BaseProfileHandler): "Update remote profile", self._update_remote_profile_cache ) - @defer.inlineCallbacks - def _update_remote_profile_cache(self): + async def _update_remote_profile_cache(self): """Called periodically to check profiles of remote users we haven't checked in a while. """ - entries = yield self.store.get_remote_profile_cache_entries_that_expire( + entries = await self.store.get_remote_profile_cache_entries_that_expire( last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS ) for user_id, displayname, avatar_url in entries: - is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( + is_subscribed = await self.store.is_subscribed_remote_profile_for_user( user_id ) if not is_subscribed: - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) continue try: - profile = yield self.federation.make_query( + profile = await self.federation.make_query( destination=get_domain_from_id(user_id), query_type="profile", args={"user_id": user_id}, @@ -398,7 +389,7 @@ class MasterProfileHandler(BaseProfileHandler): except Exception: logger.exception("Failed to get avatar_url") - yield self.store.update_remote_profile_cache( + await self.store.update_remote_profile_cache( user_id, displayname, avatar_url ) continue @@ -407,4 +398,4 @@ class MasterProfileHandler(BaseProfileHandler): new_avatar = profile.get("avatar_url") # We always hit update to update the last_check timestamp - yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) + await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 8bc100db42..f922d8a545 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py
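One detail worth noticing in the receipts.py hunks below: the old code `yield`ed `self.get_current_key()`, but that method just returns the store's max receipt stream ID as a plain value, so the converted code calls it without `await`. The resulting event-source shape, roughly (toy arguments in place of the handler's attributes):

```python
async def get_new_events_sketch(store, event_source, from_key, room_ids):
    from_key = int(from_key)
    to_key = event_source.get_current_key()  # synchronous counter read
    if from_key == to_key:
        # Nothing has changed since the client's token; short-circuit.
        return [], to_key
    events = await store.get_linearized_receipts_for_rooms(
        room_ids, from_key=from_key, to_key=to_key
    )
    return events, to_key
```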
@@ -14,8 +14,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id from synapse.util.async_helpers import maybe_awaitable @@ -129,15 +127,14 @@ class ReceiptEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) - to_key = yield self.get_current_key() + to_key = self.get_current_key() if from_key == to_key: return [], to_key - events = yield self.store.get_linearized_receipts_for_rooms( + events = await self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) @@ -146,8 +143,7 @@ class ReceiptEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): + async def get_pagination_rows(self, user, config, key): to_key = int(config.from_key) if config.to_key: @@ -155,8 +151,8 @@ class ReceiptEventSource(object): else: from_key = None - room_ids = yield self.store.get_rooms_for_user(user.to_string()) - events = yield self.store.get_linearized_receipts_for_rooms( + room_ids = await self.store.get_rooms_for_user(user.to_string()) + events = await self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 78c3772ac1..501f0fe795 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
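The register.py hunks below delete the handler-side user-ID generator in favour of a single `self.store.generate_user_id()` call. The removed code kept a process-local counter, initialised once behind a `Linearizer` with a double-check so that concurrent callers did not race the initial lookup; simplified from the removed lines, it looked like this:

```python
from synapse.util.async_helpers import Linearizer


class GuestIdAllocatorSketch:
    """Simplified version of the logic being removed below."""

    def __init__(self, store):
        self.store = store
        self._next_generated_user_id = None
        self._linearizer = Linearizer(name="_generate_user_id_linearizer")

    async def generate_user_id(self) -> str:
        if self._next_generated_user_id is None:
            with await self._linearizer.queue(()):
                # Double-checked: an earlier waiter may have initialised it.
                if self._next_generated_user_id is None:
                    self._next_generated_user_id = (
                        await self.store.find_next_generated_user_id_localpart()
                    )
        id = self._next_generated_user_id
        self._next_generated_user_id += 1
        return str(id)
```

Pushing this into the datastore means the counter no longer lives in one process's memory, which seems aimed at multi-worker deployments.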
@@ -28,7 +28,6 @@ from synapse.replication.http.register import ( ) from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester -from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -50,14 +49,7 @@ class RegistrationHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() self.identity_handler = self.hs.get_handlers().identity_handler self.ratelimiter = hs.get_registration_ratelimiter() - - self._next_generated_user_id = None - self.macaroon_gen = hs.get_macaroon_generator() - - self._generate_user_id_linearizer = Linearizer( - name="_generate_user_id_linearizer" - ) self._server_notices_mxid = hs.config.server_notices_mxid if hs.config.worker_app: @@ -219,7 +211,7 @@ class RegistrationHandler(BaseHandler): if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = await self._generate_user_id() + localpart = await self.store.generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) @@ -510,18 +502,6 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - async def _generate_user_id(self): - if self._next_generated_user_id is None: - with await self._generate_user_id_linearizer.queue(()): - if self._next_generated_user_id is None: - self._next_generated_user_id = ( - await self.store.find_next_generated_user_id_localpart() - ) - - id = self._next_generated_user_id - self._next_generated_user_id += 1 - return str(id) - def check_registration_ratelimit(self, address): """A simple helper method to check whether the registration rate limit has been hit for a given IP address diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 950a84acd0..fb37d371ad 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
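Most of the room.py hunks below consist of the new `RoomShutdownHandler`. A hypothetical admin-side invocation, using only the parameters and return keys documented in its docstring (how the handler instance is obtained from the `HomeServer` is assumed here, not shown in the diff):

```python
async def shut_down_abusive_room(shutdown_handler, room_id: str, admin_user_id: str) -> dict:
    result = await shutdown_handler.shutdown_room(
        room_id=room_id,
        requester_user_id=admin_user_id,
        new_room_user_id=admin_user_id,  # move local users into a new room...
        new_room_name="Content Violation Notification",
        message="This room has been shut down for sharing illegal content.",
        block=True,  # ...and block future attempts to join the old one
    )
    # result has the keys: kicked_users, failed_to_kick_users,
    # local_aliases and new_room_id.
    return result
```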
@@ -22,11 +22,12 @@ import logging import math import string from collections import OrderedDict -from typing import Tuple +from typing import Optional, Tuple from synapse.api.constants import ( EventTypes, JoinRules, + Membership, RoomCreationPreset, RoomEncryptionAlgorithms, ) @@ -43,9 +44,10 @@ from synapse.types import ( StateMap, StreamToken, UserID, + create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, maybe_awaitable from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -1089,3 +1091,205 @@ class RoomEventSource(object): def get_current_key_for_room(self, room_id): return self.store.get_room_events_max_id(room_id) + + +class RoomShutdownHandler(object): + + DEFAULT_MESSAGE = ( + "Sharing illegal content on this server is not permitted and rooms in" + " violation will be blocked." + ) + DEFAULT_ROOM_NAME = "Content Violation Notification" + + def __init__(self, hs): + self.hs = hs + self.room_member_handler = hs.get_room_member_handler() + self._room_creation_handler = hs.get_room_creation_handler() + self._replication = hs.get_replication_data_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.state = hs.get_state_handler() + self.store = hs.get_datastore() + + async def shutdown_room( + self, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + ) -> dict: + """ + Shuts down a room. Moves all local users and room aliases automatically + to a new room if `new_room_user_id` is set. Otherwise local users only + leave the room, without any notification. + + The new room will be created with the user specified by the + `new_room_user_id` parameter as room administrator and will contain a + message explaining what happened. Users invited to the new room will + have power level `-10` by default, and thus be unable to speak. + + The local server will only have the power to move local users and room + aliases to the new room. Users on other servers will be unaffected. + + Args: + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action and put the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + + Returns: a dict containing the following keys: + kicked_users: An array of users (`user_id`) that were kicked. + failed_to_kick_users: + An array of users (`user_id`) that were not kicked. + local_aliases: + An array of strings representing the local aliases that were + migrated from the old room to the new. 
+ new_room_id: A string representing the room ID of the new room. + """ + + if not new_room_name: + new_room_name = self.DEFAULT_ROOM_NAME + if not message: + message = self.DEFAULT_MESSAGE + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + if not await self.store.get_room(room_id): + raise NotFoundError("Unknown room id %s" % (room_id,)) + + # This will work even if the room is already blocked, but that is + # desirable in case the first attempt at blocking the room failed below. + if block: + await self.store.block_room(room_id, requester_user_id) + + if new_room_user_id is not None: + if not self.hs.is_mine_id(new_room_user_id): + raise SynapseError( + 400, "User must be our own: %s" % (new_room_user_id,) + ) + + room_creator_requester = create_requester(new_room_user_id) + + info, stream_id = await self._room_creation_handler.create_room( + room_creator_requester, + config={ + "preset": RoomCreationPreset.PUBLIC_CHAT, + "name": new_room_name, + "power_level_content_override": {"users_default": -10}, + }, + ratelimit=False, + ) + new_room_id = info["room_id"] + + logger.info( + "Shutting down room %r, joining to new room: %r", room_id, new_room_id + ) + + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propagated before + # we try and auto join below. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + else: + new_room_id = None + logger.info("Shutting down room %r", room_id) + + users = await self.state.get_current_users_in_room(room_id) + kicked_users = [] + failed_to_kick_users = [] + for user_id in users: + if not self.hs.is_mine_id(user_id): + continue + + logger.info("Kicking %r from %r...", user_id, room_id) + + try: + # Kick users from room + target_requester = create_requester(user_id) + _, stream_id = await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=room_id, + action=Membership.LEAVE, + content={}, + ratelimit=False, + require_consent=False, + ) + + # Wait for leave to come in over replication before trying to forget. 
+ await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + + await self.room_member_handler.forget(target_requester.user, room_id) + + # Join users to new room + if new_room_user_id: + await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=new_room_id, + action=Membership.JOIN, + content={}, + ratelimit=False, + require_consent=False, + ) + + kicked_users.append(user_id) + except Exception: + logger.exception( + "Failed to leave old room and join new room for %r", user_id + ) + failed_to_kick_users.append(user_id) + + # Send message in new room and move aliases + if new_room_user_id: + await self.event_creation_handler.create_and_send_nonmember_event( + room_creator_requester, + { + "type": "m.room.message", + "content": {"body": message, "msgtype": "m.text"}, + "room_id": new_room_id, + "sender": new_room_user_id, + }, + ratelimit=False, + ) + + aliases_for_room = await maybe_awaitable( + self.store.get_aliases_for_room(room_id) + ) + + await self.store.update_aliases_for_room( + room_id, new_room_id, requester_user_id + ) + else: + aliases_for_room = [] + + return { + "kicked_users": kicked_users, + "failed_to_kick_users": failed_to_kick_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + } diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index e3dbbcc052..0d678eee17 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py
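In the room_list.py hunks below, `get_local_public_room_list` becomes a coroutine and now `await`s `ResponseCache.wrap(...)` instead of returning its Deferred. The wrap pattern makes concurrent identical requests share one computation; in miniature (class, method names and timeout are illustrative):

```python
from synapse.util.caches.response_cache import ResponseCache


class PublicRoomListerSketch:
    def __init__(self, hs):
        # Identical concurrent requests share a single computation while
        # the cached entry is alive.
        self.response_cache = ResponseCache(hs, "room_list", timeout_ms=60 * 1000)

    async def get_list(self, limit, since_token):
        key = (limit, since_token)
        return await self.response_cache.wrap(
            key, self._get_list_uncached, limit, since_token
        )

    async def _get_list_uncached(self, limit, since_token):
        ...  # the expensive database walk, run once per distinct key at a time
```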
@@ -20,12 +20,10 @@ from typing import Any, Dict, Optional import msgpack from unpaddedbase64 import decode_base64, encode_base64 -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, HttpResponseException from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler @@ -48,7 +46,7 @@ class RoomListHandler(BaseHandler): hs, "remote_room_list", timeout_ms=30 * 1000 ) - def get_local_public_room_list( + async def get_local_public_room_list( self, limit=None, since_token=None, @@ -73,7 +71,7 @@ class RoomListHandler(BaseHandler): API """ if not self.enable_room_list_search: - return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) + return {"chunk": [], "total_room_count_estimate": 0} logger.info( "Getting public room list: limit=%r, since=%r, search=%r, network=%r", @@ -88,7 +86,7 @@ class RoomListHandler(BaseHandler): # appservice specific lists. logger.info("Bypassing cache as search request.") - return self._get_public_room_list( + return await self._get_public_room_list( limit, since_token, search_filter, @@ -97,7 +95,7 @@ class RoomListHandler(BaseHandler): ) key = (limit, since_token, network_tuple) - return self.response_cache.wrap( + return await self.response_cache.wrap( key, self._get_public_room_list, limit, @@ -106,8 +104,7 @@ class RoomListHandler(BaseHandler): from_federation=from_federation, ) - @defer.inlineCallbacks - def _get_public_room_list( + async def _get_public_room_list( self, limit: Optional[int] = None, since_token: Optional[str] = None, @@ -146,7 +143,7 @@ class RoomListHandler(BaseHandler): # we request one more than wanted to see if there are more pages to come probing_limit = limit + 1 if limit is not None else None - results = yield self.store.get_largest_public_rooms( + results = await self.store.get_largest_public_rooms( network_tuple, search_filter, probing_limit, @@ -222,44 +219,44 @@ class RoomListHandler(BaseHandler): response["chunk"] = results - response["total_room_count_estimate"] = yield self.store.count_public_rooms( + response["total_room_count_estimate"] = await self.store.count_public_rooms( network_tuple, ignore_non_federatable=from_federation ) return response - @cachedInlineCallbacks(num_args=1, cache_context=True) - def generate_room_entry( + @cached(num_args=1, cache_context=True) + async def generate_room_entry( self, - room_id, - num_joined_users, + room_id: str, + num_joined_users: int, cache_context, - with_alias=True, - allow_private=False, - ): + with_alias: bool = True, + allow_private: bool = False, + ) -> Optional[dict]: """Returns the entry for a room Args: - room_id (str): The room's ID. - num_joined_users (int): Number of users in the room. + room_id: The room's ID. + num_joined_users: Number of users in the room. cache_context: Information for cached responses. - with_alias (bool): Whether to return the room's aliases in the result. - allow_private (bool): Whether invite-only rooms should be shown. + with_alias: Whether to return the room's aliases in the result. + allow_private: Whether invite-only rooms should be shown. Returns: - Deferred[dict|None]: Returns a room entry as a dictionary, or None if this + Returns a room entry as a dictionary, or None if this room was determined not to be shown publicly. 
""" result = {"room_id": room_id, "num_joined_members": num_joined_users} if with_alias: - aliases = yield self.store.get_aliases_for_room( + aliases = await self.store.get_aliases_for_room( room_id, on_invalidate=cache_context.invalidate ) if aliases: result["aliases"] = aliases - current_state_ids = yield self.store.get_current_state_ids( + current_state_ids = await self.store.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) @@ -267,7 +264,7 @@ class RoomListHandler(BaseHandler): # We're not in the room, so may as well bail out here. return result - event_map = yield self.store.get_events( + event_map = await self.store.get_events( [ event_id for key, event_id in current_state_ids.items() @@ -337,8 +334,7 @@ class RoomListHandler(BaseHandler): return result - @defer.inlineCallbacks - def get_remote_public_room_list( + async def get_remote_public_room_list( self, server_name, limit=None, @@ -357,7 +353,7 @@ class RoomListHandler(BaseHandler): # to a locally-filtered search if we must. try: - res = yield self._get_remote_list_cached( + res = await self._get_remote_list_cached( server_name, limit=limit, since_token=since_token, @@ -382,7 +378,7 @@ class RoomListHandler(BaseHandler): limit = None since_token = None - res = yield self._get_remote_list_cached( + res = await self._get_remote_list_cached( server_name, limit=limit, since_token=since_token, @@ -401,7 +397,7 @@ class RoomListHandler(BaseHandler): return res - def _get_remote_list_cached( + async def _get_remote_list_cached( self, server_name, limit=None, @@ -413,7 +409,7 @@ class RoomListHandler(BaseHandler): repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search - return repl_layer.get_public_rooms( + return await repl_layer.get_public_rooms( server_name, limit=limit, since_token=since_token, @@ -429,7 +425,7 @@ class RoomListHandler(BaseHandler): include_all_networks, third_party_instance_id, ) - return self.remote_response_cache.wrap( + return await self.remote_response_cache.wrap( key, repl_layer.get_public_rooms, server_name, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 843c4ae010..0010f48577 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -110,7 +110,7 @@ class RoomMemberHandler(object): txn_id: Optional[str], requester: Requester, content: JsonDict, - ) -> Tuple[Optional[str], int]: + ) -> Tuple[str, int]: """ Rejects an out-of-band invite we have received from a remote server @@ -269,7 +269,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Tuple[Optional[str], int]: + ) -> Tuple[str, int]: key = (room_id,) as_id = object() @@ -319,7 +319,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Tuple[Optional[str], int]: + ) -> Tuple[str, int]: content_specified = bool(content) if content is None: content = {} @@ -1026,7 +1026,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): txn_id: Optional[str], requester: Requester, content: JsonDict, - ) -> Tuple[Optional[str], int]: + ) -> Tuple[str, int]: """ Rejects an out-of-band invite received from a remote user diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index ac03f15166..897338fd54 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py
@@ -67,7 +67,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): txn_id: Optional[str], requester: Requester, content: dict, - ) -> Tuple[Optional[str], int]: + ) -> Tuple[str, int]: """ Rejects an out-of-band invite received from a remote user diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f631c34f92..99dd4ee948 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
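The sync.py hunks below mainly add debug breadcrumbs between the stages of sync-response generation, and turn the per-room callback handed to `concurrently_execute` into an explicitly async wrapper. That helper runs an async function over a collection with bounded parallelism; the shape used here, in isolation (the `generate_one` callback is a stand-in for the real `_generate_room_entry`):

```python
import logging

from synapse.util.async_helpers import concurrently_execute

logger = logging.getLogger(__name__)


async def generate_all_room_entries(room_entries, generate_one):
    async def handle_room_entry(room_entry):
        logger.debug("Generating room entry for %s", room_entry.room_id)
        res = await generate_one(room_entry)
        logger.debug("Generated room entry for %s", room_entry.room_id)
        return res

    # Process at most ten room entries concurrently, as the hunk below does.
    await concurrently_execute(handle_room_entry, room_entries, 10)
```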
@@ -286,6 +286,7 @@ class SyncHandler(object): timeout, full_state, ) + logger.debug("Returning sync response for %s", user_id) return res async def _wait_for_sync_for_user( @@ -993,10 +994,14 @@ class SyncHandler(object): joined_room_ids=joined_room_ids, ) + logger.debug("Fetching account data") + account_data_by_room = await self._generate_sync_entry_for_account_data( sync_result_builder ) + logger.debug("Fetching room data") + res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) @@ -1007,10 +1012,12 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: + logger.debug("Fetching presence data") await self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ) + logger.debug("Fetching to-device data") await self._generate_sync_entry_for_to_device(sync_result_builder) device_lists = await self._generate_sync_entry_for_device_list( @@ -1021,6 +1028,7 @@ class SyncHandler(object): newly_left_users=newly_left_users, ) + logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_key_counts = {} # type: JsonDict if device_id: @@ -1028,6 +1036,7 @@ class SyncHandler(object): user_id, device_id ) + logger.debug("Fetching group data") await self._generate_sync_entry_for_groups(sync_result_builder) # debug for https://github.com/matrix-org/synapse/issues/4422 @@ -1038,6 +1047,7 @@ class SyncHandler(object): "Sync result for newly joined room %s: %r", room_id, joined_room ) + logger.debug("Sync response calculation complete") return SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -1410,8 +1420,9 @@ class SyncHandler(object): newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms - def handle_room_entries(room_entry): - return self._generate_room_entry( + async def handle_room_entries(room_entry): + logger.debug("Generating room entry for %s", room_entry.room_id) + res = await self._generate_room_entry( sync_result_builder, ignored_users, room_entry, @@ -1420,6 +1431,8 @@ class SyncHandler(object): account_data=account_data_by_room.get(room_entry.room_id, {}), always_include=sync_result_builder.full_state, ) + logger.debug("Generated room entry for %s", room_entry.room_id) + return res await concurrently_execute(handle_room_entries, room_entries, 10) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 879c4c07c6..a86ac0150e 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -15,15 +15,19 @@ import logging from collections import namedtuple -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Set, Tuple from synapse.api.errors import AuthError, SynapseError -from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.streams import TypingStream from synapse.types import UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -39,48 +43,48 @@ FEDERATION_TIMEOUT = 60 * 1000 FEDERATION_PING_INTERVAL = 40 * 1000 -class TypingHandler(object): - def __init__(self, hs): +class FollowerTypingHandler: + """A typing handler on a different process than the writer that is updated + via replication. + """ + + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.server_name = hs.config.server_name - self.auth = hs.get_auth() - self.is_mine_id = hs.is_mine_id - self.notifier = hs.get_notifier() - self.state = hs.get_state_handler() - - self.hs = hs - self.clock = hs.get_clock() - self.wheel_timer = WheelTimer(bucket_size=5000) + self.is_mine_id = hs.is_mine_id - self.federation = hs.get_federation_sender() + self.federation = None + if hs.should_send_federation(): + self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + if hs.config.worker.writers.typing != hs.get_instance_name(): + hs.get_federation_registry().register_instance_for_edu( + "m.typing", hs.config.worker.writers.typing, + ) - hs.get_distributor().observe("user_left_room", self.user_left_room) + # map room IDs to serial numbers + self._room_serials = {} + # map room IDs to sets of users currently typing + self._room_typing = {} - self._member_typing_until = {} # clock time we expect to stop self._member_last_federation_poke = {} - + self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 - self._reset() - - # caches which room_ids changed at which serials - self._typing_stream_change_cache = StreamChangeCache( - "TypingStreamChangeCache", self._latest_room_serial - ) self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self): - """ - Reset the typing handler's data caches. + """Reset the typing handler's data caches. """ # map room IDs to serial numbers self._room_serials = {} # map room IDs to sets of users currently typing self._room_typing = {} + self._member_last_federation_poke = {} + self.wheel_timer = WheelTimer(bucket_size=5000) + def _handle_timeouts(self): logger.debug("Checking for typing timeouts") @@ -89,30 +93,140 @@ class TypingHandler(object): members = set(self.wheel_timer.fetch(now)) for member in members: - if not self.is_typing(member): - # Nothing to do if they're no longer typing - continue - - until = self._member_typing_until.get(member, None) - if not until or until <= now: - logger.info("Timing out typing for: %s", member.user_id) - self._stopped_typing(member) - continue - - # Check if we need to resend a keep alive over federation for this - # user. 
- if self.hs.is_mine_id(member.user_id): - last_fed_poke = self._member_last_federation_poke.get(member, None) - if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_in_background(self._push_remote, member=member, typing=True) - - # Add a paranoia timer to ensure that we always have a timer for - # each person typing. - self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) + self._handle_timeout_for_member(now, member) + + def _handle_timeout_for_member(self, now: int, member: RoomMember): + if not self.is_typing(member): + # Nothing to do if they're no longer typing + return + + # Check if we need to resend a keep alive over federation for this + # user. + if self.federation and self.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: + run_as_background_process( + "typing._push_remote", self._push_remote, member=member, typing=True + ) + + # Add a paranoia timer to ensure that we always have a timer for + # each person typing. + self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) + async def _push_remote(self, member, typing): + if not self.federation: + return + + try: + users = await self.store.get_users_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() + + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, obj=member, then=now + FEDERATION_PING_INTERVAL + ) + + for domain in {get_domain_from_id(u) for u in users}: + if domain != self.server_name: + logger.debug("sending typing update to %s", domain) + self.federation.build_and_send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) + except Exception: + logger.exception("Error pushing typing notif to remotes") + + def process_replication_rows( + self, token: int, rows: List[TypingStream.TypingStreamRow] + ): + """Should be called whenever we receive updates for typing stream. + """ + + if self._latest_room_serial > token: + # The master has gone backwards. To prevent inconsistent data, just + # clear everything. + self._reset() + + # Set the latest serial token to whatever the server gave us. + self._latest_room_serial = token + + for row in rows: + self._room_serials[row.room_id] = token + + prev_typing = set(self._room_typing.get(row.room_id, [])) + now_typing = set(row.user_ids) + self._room_typing[row.room_id] = row.user_ids + + run_as_background_process( + "_handle_change_in_typing", + self._handle_change_in_typing, + row.room_id, + prev_typing, + now_typing, + ) + + async def _handle_change_in_typing( + self, room_id: str, prev_typing: Set[str], now_typing: Set[str] + ): + """Process a change in typing of a room from replication, sending EDUs + for any local users. 
+ """ + for user_id in now_typing - prev_typing: + if self.is_mine_id(user_id): + await self._push_remote(RoomMember(room_id, user_id), True) + + for user_id in prev_typing - now_typing: + if self.is_mine_id(user_id): + await self._push_remote(RoomMember(room_id, user_id), False) + + def get_current_token(self): + return self._latest_room_serial + + +class TypingWriterHandler(FollowerTypingHandler): + def __init__(self, hs): + super().__init__(hs) + + assert hs.config.worker.writers.typing == hs.get_instance_name() + + self.auth = hs.get_auth() + self.notifier = hs.get_notifier() + + self.hs = hs + + hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + + hs.get_distributor().observe("user_left_room", self.user_left_room) + + self._member_typing_until = {} # clock time we expect to stop + + # caches which room_ids changed at which serials + self._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", self._latest_room_serial + ) + + def _handle_timeout_for_member(self, now: int, member: RoomMember): + super()._handle_timeout_for_member(now, member) + + if not self.is_typing(member): + # Nothing to do if they're no longer typing + return + + until = self._member_typing_until.get(member, None) + if not until or until <= now: + logger.info("Timing out typing for: %s", member.user_id) + self._stopped_typing(member) + return + async def started_typing(self, target_user, auth_user, room_id, timeout): target_user_id = target_user.to_string() auth_user_id = auth_user.to_string() @@ -179,35 +293,11 @@ class TypingHandler(object): def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - run_in_background(self._push_remote, member, typing) - - self._push_update_local(member=member, typing=typing) - - async def _push_remote(self, member, typing): - try: - users = await self.state.get_current_users_in_room(member.room_id) - self._member_last_federation_poke[member] = self.clock.time_msec() - - now = self.clock.time_msec() - self.wheel_timer.insert( - now=now, obj=member, then=now + FEDERATION_PING_INTERVAL + run_as_background_process( + "typing._push_remote", self._push_remote, member, typing ) - for domain in {get_domain_from_id(u) for u in users}: - if domain != self.server_name: - logger.debug("sending typing update to %s", domain) - self.federation.build_and_send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": member.room_id, - "user_id": member.user_id, - "typing": typing, - }, - key=member, - ) - except Exception: - logger.exception("Error pushing typing notif to remotes") + self._push_update_local(member=member, typing=typing) async def _recv_edu(self, origin, content): room_id = content["room_id"] @@ -224,7 +314,7 @@ class TypingHandler(object): ) return - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: @@ -304,8 +394,11 @@ class TypingHandler(object): return rows, current_id, limited - def get_current_token(self): - return self._latest_room_serial + def process_replication_rows( + self, token: int, rows: List[TypingStream.TypingStreamRow] + ): + # The writing process should never get updates from replication. 
+ raise Exception("Typing writer instance got typing info over replication") class TypingNotificationEventSource(object): diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8b24a73319..a011e9fe29 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py
@@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging +from typing import Any from canonicaljson import json -from twisted.internet import defer from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType @@ -32,25 +33,25 @@ class UserInteractiveAuthChecker: def __init__(self, hs): pass - def is_enabled(self): + def is_enabled(self) -> bool: """Check if the configuration of the homeserver allows this checker to work Returns: - bool: True if this login type is enabled. + True if this login type is enabled. """ - def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: """Given the authentication dict from the client, attempt to check this step Args: - authdict (dict): authentication dictionary from the client - clientip (str): The IP address of the client. + authdict: authentication dictionary from the client + clientip: The IP address of the client. Raises: SynapseError if authentication failed Returns: - Deferred: the result of authentication (to pass back to the client?) + The result of authentication (to pass back to the client?) """ raise NotImplementedError() @@ -61,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - def check_auth(self, authdict, clientip): - return defer.succeed(True) + async def check_auth(self, authdict, clientip): + return True class TermsAuthChecker(UserInteractiveAuthChecker): @@ -71,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - def check_auth(self, authdict, clientip): - return defer.succeed(True) + async def check_auth(self, authdict, clientip): + return True class RecaptchaAuthChecker(UserInteractiveAuthChecker): @@ -88,8 +89,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return self._enabled - @defer.inlineCallbacks - def check_auth(self, authdict, clientip): + async def check_auth(self, authdict, clientip): try: user_response = authdict["response"] except KeyError: @@ -106,7 +106,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): # TODO: get this from the homeserver rather than creating a new one for # each request try: - resp_body = yield self._http_client.post_urlencoded_get_json( + resp_body = await self._http_client.post_urlencoded_get_json( self._url, args={ "secret": self._secret, @@ -117,7 +117,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): except PartialDownloadError as pde: # Twisted is silly data = pde.response - resp_body = json.loads(data) + resp_body = json.loads(data.decode("utf-8")) if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly @@ -218,8 +218,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec ThreepidBehaviour.LOCAL, ) - def check_auth(self, authdict, clientip): - return defer.ensureDeferred(self._check_threepid("email", authdict)) + async def check_auth(self, authdict, clientip): + return await self._check_threepid("email", authdict) class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): @@ -232,8 +232,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): def is_enabled(self): return bool(self.hs.config.account_threepid_delegate_msisdn) - def check_auth(self, authdict, clientip): - return 
defer.ensureDeferred(self._check_threepid("msisdn", authdict)) + async def check_auth(self, authdict, clientip): + return await self._check_threepid("msisdn", authdict) INTERACTIVE_AUTH_CHECKERS = [ diff --git a/synapse/http/client.py b/synapse/http/client.py
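The checkers above drop defer.succeed and defer.ensureDeferred in favour of native coroutines. A small before/after sketch of the conversion (hypothetical class names; twisted is used as in the module itself):

    from twisted.internet import defer

    class OldStyleChecker:
        def check_auth(self, authdict, clientip):
            # Pre-change style: hand back an already-fired Deferred.
            return defer.succeed(True)

    class NewStyleChecker:
        async def check_auth(self, authdict, clientip):
            # Post-change style: a native coroutine, awaited by the caller.
            return True

    OldStyleChecker().check_auth({}, "1.2.3.4").addCallback(print)

    # ensureDeferred remains the bridge where legacy Twisted call sites
    # still expect a Deferred rather than a coroutine.
    d = defer.ensureDeferred(NewStyleChecker().check_auth({}, "1.2.3.4"))
    d.addCallback(print)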
index 8743e9839d..6bc51202cd 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -31,6 +31,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IResolutionReceiver, ) +from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone from twisted.web.client import Agent, HTTPConnectionPool, readBody @@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): return False +_EPSILON = 0.00000001 + + +def _make_scheduler(reactor): + """Makes a schedular suitable for a Cooperator using the given reactor. + + (This is effectively just a copy from `twisted.internet.task`) + """ + + def _scheduler(x): + return reactor.callLater(_EPSILON, x) + + return _scheduler + + class IPBlacklistingResolver(object): """ A proxy for reactor.nameResolver which only produces non-blacklisted IP @@ -212,6 +228,10 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) + # We use this for our body producers to ensure that they use the correct + # reactor. + self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor())) + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: @@ -292,7 +312,9 @@ class SimpleHttpClient(object): try: body_producer = None if data is not None: - body_producer = QuieterFileBodyProducer(BytesIO(data)) + body_producer = QuieterFileBodyProducer( + BytesIO(data), cooperator=self._cooperator, + ) request_deferred = treq.request( method, @@ -371,7 +393,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json.loads(body.decode("utf-8")) else: raise HttpResponseException(response.code, response.phrase, body) @@ -412,7 +434,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json.loads(body.decode("utf-8")) else: raise HttpResponseException(response.code, response.phrase, body) @@ -441,7 +463,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) body = yield self.get_raw(uri, args, headers=headers) - return json.loads(body) + return json.loads(body.decode("utf-8")) @defer.inlineCallbacks def put_json(self, uri, json_body, args={}, headers=None): @@ -485,7 +507,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json.loads(body.decode("utf-8")) else: raise HttpResponseException(response.code, response.phrase, body) @@ -503,7 +525,7 @@ class SimpleHttpClient(object): header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the - HTTP body at text. + HTTP body as bytes. Raises: HttpResponseException on a non-2xx HTTP response. """ diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
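The SimpleHttpClient change above builds a Cooperator whose scheduler is bound to the homeserver's reactor, so body producers iterate on the correct reactor rather than the global one. A standalone sketch of the same wiring (closely mirroring _make_scheduler from the diff; treat it as illustrative rather than canonical):

    from io import BytesIO

    from twisted.internet import reactor, task
    from twisted.web.client import FileBodyProducer

    _EPSILON = 0.00000001

    def make_scheduler(reactor_):
        # A scheduler is just "run this callable very soon on *this* reactor".
        def _scheduler(work):
            return reactor_.callLater(_EPSILON, work)

        return _scheduler

    # A Cooperator built this way drives its tasks on the reactor we chose,
    # instead of the module-global default.
    cooperator = task.Cooperator(scheduler=make_scheduler(reactor))

    # e.g. a body producer feeding a request body in cooperative chunks:
    body = FileBodyProducer(BytesIO(b'{"hello": "world"}'), cooperator=cooperator)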
index c5fc746f2f..0c02648015 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py
@@ -15,6 +15,7 @@ import logging import urllib +from typing import List from netaddr import AddrFormatError, IPAddress from zope.interface import implementer @@ -236,11 +237,10 @@ class MatrixHostnameEndpoint(object): return run_in_background(self._do_connect, protocol_factory) - @defer.inlineCallbacks - def _do_connect(self, protocol_factory): + async def _do_connect(self, protocol_factory): first_exception = None - server_list = yield self._resolve_server() + server_list = await self._resolve_server() for server in server_list: host = server.host @@ -251,7 +251,7 @@ class MatrixHostnameEndpoint(object): endpoint = HostnameEndpoint(self._reactor, host, port) if self._tls_options: endpoint = wrapClientTLS(self._tls_options, endpoint) - result = yield make_deferred_yieldable( + result = await make_deferred_yieldable( endpoint.connect(protocol_factory) ) @@ -271,13 +271,9 @@ class MatrixHostnameEndpoint(object): # to try and if that doesn't work then we'll have an exception. raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) - @defer.inlineCallbacks - def _resolve_server(self): + async def _resolve_server(self) -> List[Server]: """Resolves the server name to a list of hosts and ports to attempt to connect to. - - Returns: - Deferred[list[Server]] """ if self._parsed_uri.scheme != b"matrix": @@ -298,7 +294,7 @@ class MatrixHostnameEndpoint(object): if port or _is_ip_literal(host): return [Server(host, port or 8448)] - server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) + server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host) if server_list: return server_list diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
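_do_connect above, now a native coroutine, tries each resolved server in order and keeps the first exception to re-raise if nothing connects. A generic sketch of that fallback loop (asyncio and hypothetical names used purely for self-containedness):

    import asyncio

    async def connect_first_available(servers, connect):
        """Try each (host, port) in order; return the first success.

        Re-raises the first failure if every candidate fails, mirroring
        the agent's _do_connect loop above.
        """
        first_exception = None
        for host, port in servers:
            try:
                return await connect(host, port)
            except Exception as e:
                if not first_exception:
                    first_exception = e

        if first_exception:
            raise first_exception

        # No servers at all: resolution itself produced nothing.
        raise Exception("Failed to resolve any servers")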
index 021b233a7d..2ede90a9b1 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py
@@ -17,10 +17,10 @@ import logging import random import time +from typing import List import attr -from twisted.internet import defer from twisted.internet.error import ConnectError from twisted.names import client, dns from twisted.names.error import DNSNameError, DomainError @@ -113,16 +113,14 @@ class SrvResolver(object): self._cache = cache self._get_time = get_time - @defer.inlineCallbacks - def resolve_service(self, service_name): + async def resolve_service(self, service_name: bytes) -> List[Server]: """Look up a SRV record Args: service_name (bytes): record to look up Returns: - Deferred[list[Server]]: - a list of the SRV records, or an empty list if none found + a list of the SRV records, or an empty list if none found """ now = int(self._get_time()) @@ -136,7 +134,7 @@ class SrvResolver(object): return _sort_server_list(servers) try: - answers, _, _ = yield make_deferred_yieldable( + answers, _, _ = await make_deferred_yieldable( self._dns_client.lookupService(service_name) ) except DNSNameError: diff --git a/synapse/http/server.py b/synapse/http/server.py
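resolve_service above checks a TTL-based cache before touching DNS. A simplified sketch of that check (the real resolver also handles DNSNameError, keeps a stale cache as a fallback, and sorts results by SRV priority and weight):

    import time
    from typing import Dict, List, NamedTuple

    class Server(NamedTuple):
        host: bytes
        port: int
        expires: int

    _cache = {}  # type: Dict[bytes, List[Server]]

    async def resolve_service(service_name: bytes) -> List[Server]:
        now = int(time.time())
        cached = _cache.get(service_name)
        if cached and all(s.expires > now for s in cached):
            # Every cached record is still within its TTL; skip the lookup.
            return list(cached)
        # Otherwise fall through to the real DNS query (elided) and
        # repopulate _cache with records carrying now + ttl expiries.
        return []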
index 2b35f86066..8e003689c4 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -217,7 +217,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return NOT_DONE_YET @wrap_async_request_handler - async def _async_render_wrapper(self, request): + async def _async_render_wrapper(self, request: SynapseRequest): """This is a wrapper that delegates to `_async_render` and handles exceptions, return values, metrics, etc. """ @@ -237,7 +237,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): f = failure.Failure() self._send_error_response(f, request) - async def _async_render(self, request): + async def _async_render(self, request: Request): """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if no appropriate method exists. Can be overriden in sub classes for different routing. @@ -278,7 +278,7 @@ class DirectServeJsonResource(_AsyncResource): """ def _send_response( - self, request, code, response_object, + self, request: Request, code: int, response_object: Any, ): """Implements _AsyncResource._send_response """ @@ -507,14 +507,29 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect): def respond_with_json( - request, - code, - json_object, - send_cors=False, - response_code_message=None, - pretty_print=False, - canonical_json=True, + request: Request, + code: int, + json_object: Any, + send_cors: bool = False, + pretty_print: bool = False, + canonical_json: bool = True, ): + """Sends encoded JSON in response to the given request. + + Args: + request: The http request to respond to. + code: The HTTP response code. + json_object: The object to serialize to JSON. + send_cors: Whether to send Cross-Origin Resource Sharing headers + https://fetch.spec.whatwg.org/#http-cors-protocol + pretty_print: Whether to include indentation and line-breaks in the + resulting JSON bytes. + canonical_json: Whether to use the canonicaljson algorithm when encoding + the JSON bytes. + + Returns: + twisted.web.server.NOT_DONE_YET if the request is still active. + """ # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. @@ -522,7 +537,7 @@ def respond_with_json( logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None if pretty_print: json_bytes = encode_pretty_printed_json(json_object) + b"\n" @@ -533,30 +548,26 @@ def respond_with_json( else: json_bytes = json.dumps(json_object).encode("utf-8") - return respond_with_json_bytes( - request, - code, - json_bytes, - send_cors=send_cors, - response_code_message=response_code_message, - ) + return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors) def respond_with_json_bytes( - request, code, json_bytes, send_cors=False, response_code_message=None + request: Request, code: int, json_bytes: bytes, send_cors: bool = False, ): """Sends encoded JSON in response to the given request. Args: - request (twisted.web.http.Request): The http request to respond to. - code (int): The HTTP response code. - json_bytes (bytes): The json bytes to use as the response body. - send_cors (bool): Whether to send Cross-Origin Resource Sharing headers + request: The http request to respond to. + code: The HTTP response code. + json_bytes: The json bytes to use as the response body. + send_cors: Whether to send Cross-Origin Resource Sharing headers https://fetch.spec.whatwg.org/#http-cors-protocol + Returns: - twisted.web.server.NOT_DONE_YET""" + twisted.web.server.NOT_DONE_YET if the request is still active. 
+ """ - request.setResponseCode(code, message=response_code_message) + request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") @@ -564,8 +575,8 @@ def respond_with_json_bytes( if send_cors: set_cors_headers(request) - # todo: we can almost certainly avoid this copy and encode the json straight into - # the bytesIO, but it would involve faffing around with string->bytes wrappers. + # note that this is zero-copy (the bytesio shares a copy-on-write buffer with + # the original `bytes`). bytes_io = BytesIO(json_bytes) producer = NoRangeStaticProducer(request, bytes_io) @@ -573,12 +584,12 @@ def respond_with_json_bytes( return NOT_DONE_YET -def set_cors_headers(request): - """Set the CORs headers so that javascript running in a web browsers can +def set_cors_headers(request: Request): + """Set the CORS headers so that javascript running in a web browsers can use this API Args: - request (twisted.web.http.Request): The http request to add CORs to. + request: The http request to add CORS to. """ request.setHeader(b"Access-Control-Allow-Origin", b"*") request.setHeader( @@ -643,7 +654,7 @@ def set_clickjacking_protection_headers(request: Request): request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") -def finish_request(request): +def finish_request(request: Request): """ Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the @@ -662,7 +673,7 @@ def finish_request(request): logger.info("Connection disconnected before response was written: %r", e) -def _request_user_agent_is_curl(request): +def _request_user_agent_is_curl(request: Request) -> bool: user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) for user_agent in user_agents: if b"curl" in user_agent: diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 13fcb408a6..a34e5ead88 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -214,16 +214,8 @@ def parse_json_value_from_request(request, allow_empty_body=False): if not content_bytes and allow_empty_body: return None - # Decode to Unicode so that simplejson will return Unicode strings on - # Python 2 try: - content_unicode = content_bytes.decode("utf8") - except UnicodeDecodeError: - logger.warning("Unable to decode UTF-8") - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - try: - content = json.loads(content_unicode) + content = json.loads(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/http/site.py b/synapse/http/site.py
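The servlet change collapses the separate decode and parse steps into one try block, so both UnicodeDecodeError and JSON parse failures fall into the same 400 response path. A tiny sketch of the resulting shape (ValueError standing in for SynapseError):

    import json

    def parse_json_bytes(content_bytes: bytes):
        try:
            # Decoding and parsing in one block: a bad-UTF-8 body and a
            # bad-JSON body now produce the same "Content not JSON" error.
            return json.loads(content_bytes.decode("utf-8"))
        except Exception as e:
            raise ValueError("Content not JSON: %s" % (e,))

    assert parse_json_bytes(b'{"ok": true}') == {"ok": True}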
index cbc37eac6e..6f3b2258cc 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -215,9 +215,7 @@ class SynapseRequest(Request): # It's useful to log it here so that we can get an idea of when # the client disconnects. with PreserveLoggingContext(self.logcontext): - logger.warning( - "Error processing request %r: %s %s", self, reason.type, reason.value - ) + logger.info("Connection from client lost before response was sent") if not self._is_processing: self._finished_processing() diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 8b9c4e38bd..cbeeb870cb 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -566,36 +566,33 @@ class LoggingContextFilter(logging.Filter): return True -class PreserveLoggingContext(object): - """Captures the current logging context and restores it when the scope is - exited. Used to restore the context after a function using - @defer.inlineCallbacks is resumed by a callback from the reactor.""" +class PreserveLoggingContext: + """Context manager which replaces the logging context - __slots__ = ["current_context", "new_context", "has_parent"] + The previous logging context is restored on exit.""" + + __slots__ = ["_old_context", "_new_context"] def __init__( self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT ) -> None: - self.new_context = new_context + self._new_context = new_context def __enter__(self) -> None: - """Captures the current logging context""" - self.current_context = set_current_context(self.new_context) - - if self.current_context: - self.has_parent = self.current_context.previous_context is not None + self._old_context = set_current_context(self._new_context) def __exit__(self, type, value, traceback) -> None: - """Restores the current logging context""" - context = set_current_context(self.current_context) + context = set_current_context(self._old_context) - if context != self.new_context: + if context != self._new_context: if not context: - logger.warning("Expected logging context %s was lost", self.new_context) + logger.warning( + "Expected logging context %s was lost", self._new_context + ) else: logger.warning( "Expected logging context %s but found %s", - self.new_context, + self._new_context, context, ) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
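The simplified PreserveLoggingContext is a plain swap-and-restore context manager that warns when code inside the block leaks a different context. A generic sketch of the pattern, using a module-level variable in place of Synapse's logcontext machinery (set_current_context returning the previous context is what makes the restore a one-liner):

    import logging

    logger = logging.getLogger(__name__)

    _current_context = "sentinel"

    def set_current_context(new):
        """Set the context, returning the previous one."""
        global _current_context
        old = _current_context
        _current_context = new
        return old

    class PreserveContext:
        def __init__(self, new_context="sentinel"):
            self._new_context = new_context

        def __enter__(self):
            self._old_context = set_current_context(self._new_context)

        def __exit__(self, exc_type, exc_value, tb):
            context = set_current_context(self._old_context)
            if context != self._new_context:
                # Something inside the block replaced the context and never
                # restored it; that's a bug worth flagging.
                logger.warning(
                    "Expected context %s but found %s", self._new_context, context
                )

    with PreserveContext("request-1"):
        assert _current_context == "request-1"
    assert _current_context == "sentinel"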
index c6c0e623c1..2101517575 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -733,37 +733,54 @@ def trace(func=None, opname=None): _opname = opname if opname else func.__name__ - @wraps(func) - def _trace_inner(*args, **kwargs): - if opentracing is None: - return func(*args, **kwargs) + if inspect.iscoroutinefunction(func): - scope = start_active_span(_opname) - scope.__enter__() + @wraps(func) + async def _trace_inner(*args, **kwargs): + if opentracing is None: + return await func(*args, **kwargs) - try: - result = func(*args, **kwargs) - if isinstance(result, defer.Deferred): + with start_active_span(_opname) as scope: + try: + return await func(*args, **kwargs) + except Exception: + scope.span.set_tag(tags.ERROR, True) + raise - def call_back(result): - scope.__exit__(None, None, None) - return result + else: + # The other case here handles both sync functions and those + # decorated with inlineDeferred. + @wraps(func) + def _trace_inner(*args, **kwargs): + if opentracing is None: + return func(*args, **kwargs) - def err_back(result): - scope.span.set_tag(tags.ERROR, True) - scope.__exit__(None, None, None) - return result + scope = start_active_span(_opname) + scope.__enter__() + + try: + result = func(*args, **kwargs) + if isinstance(result, defer.Deferred): + + def call_back(result): + scope.__exit__(None, None, None) + return result - result.addCallbacks(call_back, err_back) + def err_back(result): + scope.span.set_tag(tags.ERROR, True) + scope.__exit__(None, None, None) + return result - else: - scope.__exit__(None, None, None) + result.addCallbacks(call_back, err_back) + + else: + scope.__exit__(None, None, None) - return result + return result - except Exception as e: - scope.__exit__(type(e), None, e.__traceback__) - raise + except Exception as e: + scope.__exit__(type(e), None, e.__traceback__) + raise return _trace_inner diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
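The trace decorator above now dispatches on inspect.iscoroutinefunction: native coroutines get a clean async wrapper, while everything else keeps the Deferred-aware addCallbacks path. A stripped-down sketch of the dispatch, with timing standing in for opentracing spans so the example stays self-contained:

    import functools
    import inspect
    import logging
    import time

    logger = logging.getLogger(__name__)

    def trace(func):
        opname = func.__name__

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def _trace_inner(*args, **kwargs):
                start = time.monotonic()
                try:
                    return await func(*args, **kwargs)
                finally:
                    logger.debug("%s took %.3fs", opname, time.monotonic() - start)

            return _trace_inner

        # Plain functions (and, in Synapse, Deferred-returning ones, which
        # additionally need the addCallbacks dance shown in the diff above).
        @functools.wraps(func)
        def _trace_inner(*args, **kwargs):
            start = time.monotonic()
            try:
                return func(*args, **kwargs)
            finally:
                logger.debug("%s took %.3fs", opname, time.monotonic() - start)

        return _trace_inner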
index 99049bb5d8..fea774e2e5 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py
@@ -14,9 +14,7 @@ # limitations under the License. -import inspect import logging -import time from functools import wraps from inspect import getcallargs @@ -74,127 +72,3 @@ def log_function(f): wrapped.__name__ = func_name return wrapped - - -def time_function(f): - func_name = f.__name__ - - @wraps(f) - def wrapped(*args, **kwargs): - global _TIME_FUNC_ID - id = _TIME_FUNC_ID - _TIME_FUNC_ID += 1 - - start = time.clock() - - try: - _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) - - r = f(*args, **kwargs) - finally: - end = time.clock() - _log_debug_as_f( - f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) - ) - - return r - - return wrapped - - -def trace_function(f): - func_name = f.__name__ - linenum = f.func_code.co_firstlineno - pathname = f.func_code.co_filename - - @wraps(f) - def wrapped(*args, **kwargs): - name = f.__module__ - logger = logging.getLogger(name) - level = logging.DEBUG - - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - - s = frame.f_back - - to_print = [ - "\t%s:%s %s. Args: args=%s, kwargs=%s" - % (pathname, linenum, func_name, args, kwargs) - ] - while s: - if True or s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_print.append( - "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) - ) - - s = s.f_back - - msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print) - - record = logging.LogRecord( - name=name, - level=level, - pathname=pathname, - lineno=lineno, - msg=msg, - args=(), - exc_info=None, - ) - - logger.handle(record) - - return f(*args, **kwargs) - - wrapped.__name__ = func_name - return wrapped - - -def get_previous_frames(): - - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - - s = frame.f_back.f_back - to_return = [] - while s: - if s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_return.append( - "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) - ) - - s = s.f_back - - return ", ".join(to_return) - - -def get_previous_frame(ignore=[]): - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - s = frame.f_back.f_back - - while s: - if s.f_globals["__name__"].startswith("synapse"): - if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - return "{{ %s:%d %s - Args: %s }}" % ( - filename, - lineno, - function, - args_string, - ) - - s = s.f_back - - return None diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 13785038ad..a9269196b3 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer +from twisted.python.failure import Failure from synapse.logging.context import LoggingContext, PreserveLoggingContext @@ -212,7 +213,14 @@ def run_as_background_process(desc, func, *args, **kwargs): return (yield result) except Exception: - logger.exception("Background process '%s' threw an exception", desc) + # failure.Failure() fishes the original Failure out of our stack, and + # thus gives us a sensible stack trace. + f = Failure() + logger.error( + "Background process '%s' threw an exception", + desc, + exc_info=(f.type, f.value, f.getTracebackObject()), + ) finally: _background_process_in_flight_count.labels(desc).dec() diff --git a/synapse/notifier.py b/synapse/notifier.py
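run_as_background_process above trades logger.exception for an explicit twisted Failure, so the logged traceback shows where the exception was actually raised rather than the wrapper frame. A minimal sketch of the pattern:

    import logging

    from twisted.python.failure import Failure

    logger = logging.getLogger(__name__)

    def run_and_log(desc, func):
        try:
            return func()
        except Exception:
            # Failure() fishes the in-flight exception out of the stack and
            # keeps a richer traceback than the bare `except` frame, which
            # we then hand to the logger via exc_info.
            f = Failure()
            logger.error(
                "Background process '%s' threw an exception",
                desc,
                exc_info=(f.type, f.value, f.getTracebackObject()),
            )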
index 87c120a59c..bd41f77852 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py
@@ -83,7 +83,7 @@ class _NotifierUserStream(object): self.current_token = current_token # The last token for which we should wake up any streams that have a - # token that comes before it. This gets updated everytime we get poked. + # token that comes before it. This gets updated every time we get poked. # We start it at the current token since if we get any streams # that have a token from before we have no idea whether they should be # woken up or not, so lets just wake them up. diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index dda560b2c2..af117fddf9 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py
@@ -27,6 +27,7 @@ import jinja2 from synapse.api.constants import EventTypes from synapse.api.errors import StoreError +from synapse.config.emailconfig import EmailSubjectConfig from synapse.logging.context import make_deferred_yieldable from synapse.push.presentable_names import ( calculate_room_name, @@ -42,23 +43,6 @@ logger = logging.getLogger(__name__) T = TypeVar("T") -MESSAGE_FROM_PERSON_IN_ROOM = ( - "You have a message on %(app)s from %(person)s in the %(room)s room..." -) -MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." -MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." -MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..." -MESSAGES_IN_ROOM_AND_OTHERS = ( - "You have messages on %(app)s in the %(room)s room and others..." -) -MESSAGES_FROM_PERSON_AND_OTHERS = ( - "You have messages on %(app)s from %(person)s and others..." -) -INVITE_FROM_PERSON_TO_ROOM = ( - "%(person)s has invited you to join the %(room)s room on %(app)s..." -) -INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." - CONTEXT_BEFORE = 1 CONTEXT_AFTER = 1 @@ -121,6 +105,7 @@ class Mailer(object): self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() self.app_name = app_name + self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig logger.info("Created Mailer for app_name %s" % app_name) @@ -147,7 +132,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Password Reset" % self.hs.config.server_name, + self.email_subjects.password_reset + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -174,7 +160,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Register your Email Address" % self.hs.config.server_name, + self.email_subjects.email_validation + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -202,7 +189,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Validate Your Email" % self.hs.config.server_name, + self.email_subjects.email_validation + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -269,16 +257,13 @@ class Mailer(object): user_id, app_id, email_address ), "summary_text": summary_text, - "app_name": self.app_name, "rooms": rooms, "reason": reason, } - await self.send_email( - email_address, "[%s] %s" % (self.app_name, summary_text), template_vars - ) + await self.send_email(email_address, summary_text, template_vars) - async def send_email(self, email_address, subject, template_vars): + async def send_email(self, email_address, subject, extra_template_vars): """Send an email with the given information and template text""" try: from_string = self.hs.config.email_notif_from % {"app": self.app_name} @@ -291,6 +276,13 @@ class Mailer(object): if raw_to == "": raise RuntimeError("Invalid 'to' address") + template_vars = { + "app_name": self.app_name, + "server_name": self.hs.config.server.server_name, + } + + template_vars.update(extra_template_vars) + html_text = self.template_html.render(**template_vars) html_part = MIMEText(html_text, "html", "utf8") @@ -476,12 +468,12 @@ class Mailer(object): inviter_name = name_from_member_event(inviter_member_event) if room_name is None: - return INVITE_FROM_PERSON % { + return self.email_subjects.invite_from_person % { "person": inviter_name, "app": self.app_name, } else: - return INVITE_FROM_PERSON_TO_ROOM % { + return self.email_subjects.invite_from_person_to_room % { "person": inviter_name, "room": room_name, "app": 
self.app_name, @@ -499,13 +491,13 @@ class Mailer(object): sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: - return MESSAGE_FROM_PERSON_IN_ROOM % { + return self.email_subjects.message_from_person_in_room % { "person": sender_name, "room": room_name, "app": self.app_name, } elif sender_name is not None: - return MESSAGE_FROM_PERSON % { + return self.email_subjects.message_from_person % { "person": sender_name, "app": self.app_name, } @@ -513,7 +505,10 @@ class Mailer(object): # There's more than one notification for this room, so just # say there are several if room_name is not None: - return MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name} + return self.email_subjects.messages_in_room % { + "room": room_name, + "app": self.app_name, + } else: # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" @@ -531,7 +526,7 @@ class Mailer(object): ] ) - return MESSAGES_FROM_PERSON % { + return self.email_subjects.messages_from_person % { "person": descriptor_from_member_events(member_events.values()), "app": self.app_name, } @@ -540,7 +535,7 @@ class Mailer(object): # ...but we still refer to the 'reason' room which triggered the mail if reason["room_name"] is not None: - return MESSAGES_IN_ROOM_AND_OTHERS % { + return self.email_subjects.messages_in_room_and_others % { "room": reason["room_name"], "app": self.app_name, } @@ -560,7 +555,7 @@ class Mailer(object): [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] ) - return MESSAGES_FROM_PERSON_AND_OTHERS % { + return self.email_subjects.messages_from_person_and_others % { "person": descriptor_from_member_events(member_events.values()), "app": self.app_name, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
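The mailer changes above replace the hard-coded subject constants with templates from EmailSubjectConfig, %-formatted with named placeholders. A short sketch of that substitution style (the template strings here are illustrative, not the shipped defaults):

    # %-style templates take named placeholders from a dict, which is what
    # lets admins reorder or drop fields in the config without code changes.
    password_reset = "[%(server_name)s] Password Reset"  # hypothetical default
    assert password_reset % {"server_name": "example.com"} == "[example.com] Password Reset"

    invite_from_person = "%(person)s has invited you to chat on %(app)s..."
    print(invite_from_person % {"person": "Alice", "app": "Matrix"})
    # A missing key raises KeyError, so each template may only name
    # placeholders that the corresponding make_*_subject call supplies.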
index f6a5458681..2456f12f46 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py
@@ -15,13 +15,12 @@ # limitations under the License. import logging -from collections import defaultdict -from threading import Lock -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Dict, Union + +from prometheus_client import Gauge from twisted.internet import defer -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher from synapse.push.pusher import PusherFactory from synapse.util.async_helpers import concurrently_execute +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) +synapse_pushers = Gauge( + "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"] +) + + class PusherPool: """ The pusher pool. This is responsible for dispatching notifications of new events to @@ -47,36 +55,20 @@ class PusherPool: Pusher.on_new_receipts are not expected to return deferreds. """ - def __init__(self, _hs): - self.hs = _hs - self.pusher_factory = PusherFactory(_hs) - self._should_start_pushers = _hs.config.start_pushers + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.pusher_factory = PusherFactory(hs) + self._should_start_pushers = hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() + # We shard the handling of push notifications by user ID. + self._pusher_shard_config = hs.config.push.pusher_shard_config + self._instance_name = hs.get_instance_name() + # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] - # a lock for the pushers dict, since `count_pushers` is called from an different - # and we otherwise get concurrent modification errors - self._pushers_lock = Lock() - - def count_pushers(): - results = defaultdict(int) # type: Dict[Tuple[str, str], int] - with self._pushers_lock: - for pushers in self.pushers.values(): - for pusher in pushers.values(): - k = (type(pusher).__name__, pusher.app_id) - results[k] += 1 - return results - - LaterGauge( - name="synapse_pushers", - desc="the number of active pushers", - labels=["kind", "app_id"], - caller=count_pushers, - ) - def start(self): """Starts the pushers off in a background process. 
""" @@ -104,6 +96,7 @@ class PusherPool: Returns: Deferred[EmailPusher|HttpPusher] """ + time_now_msec = self.clock.time_msec() # we try to create the pusher just to validate the config: it @@ -176,6 +169,9 @@ class PusherPool: access_tokens (Iterable[int]): access token *ids* to remove pushers for """ + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + tokens = set(access_tokens) for p in (yield self.store.get_pushers_by_user_id(user_id)): if p["access_token"] in tokens: @@ -237,6 +233,9 @@ class PusherPool: if not self._should_start_pushers: return + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None @@ -275,6 +274,11 @@ class PusherPool: Returns: Deferred[EmailPusher|HttpPusher] """ + if not self._pusher_shard_config.should_handle( + self._instance_name, pusherdict["user_name"] + ): + return + try: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: @@ -298,11 +302,12 @@ class PusherPool: appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) - with self._pushers_lock: - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) - if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + + synapse_pushers.labels(type(p).__name__, p.app_id).inc() # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to @@ -330,9 +335,10 @@ class PusherPool: if appid_pushkey in byuser: logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - byuser[appid_pushkey].on_stop() - with self._pushers_lock: - del byuser[appid_pushkey] + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() yield self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 0843d28d4b..fb0dd04f88 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py
@@ -92,11 +92,11 @@ class ReplicationEndpoint(object): # assert here that sub classes don't try and use the name. assert ( "instance_name" not in self.PATH_ARGS - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert ( "instance_name" not in signature(self.__class__._serialize_payload).parameters - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert self.METHOD in ("PUT", "POST", "GET") diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index bd394f6b00..a8a16dbc71 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -26,7 +26,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, database: Database, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_max_stream_id", "stream_id" + db_conn, "device_inbox", "stream_id" ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 523a1358d4..1b8718b11d 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py
@@ -25,7 +25,7 @@ Structure of the module: * command.py - the definitions of all the valid commands * protocol.py - the TCP protocol classes * resource.py - handles streaming stream updates to replications - * streams/ - the definitons of all the valid streams + * streams/ - the definitions of all the valid streams The general interaction of the classes are: diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0f453ff0a8..f33801f883 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -47,7 +47,7 @@ class Command(metaclass=abc.ABCMeta): @abc.abstractmethod def to_line(self) -> str: - """Serialises the comamnd for the wire. Does not include the command + """Serialises the command for the wire. Does not include the command prefix. """ @@ -293,20 +293,22 @@ class FederationAckCommand(Command): Format:: - FEDERATION_ACK <token> + FEDERATION_ACK <instance_name> <token> """ NAME = "FEDERATION_ACK" - def __init__(self, token): + def __init__(self, instance_name, token): + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - return cls(int(line)) + instance_name, token = line.split(" ") + return cls(instance_name, int(token)) def to_line(self): - return str(self.token) + return "%s %s" % (self.instance_name, self.token) class RemovePusherCommand(Command): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
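FEDERATION_ACK above gains an instance_name field, so the wire format becomes two space-separated tokens and from_line/to_line must agree. A round-trip check of the command encoding, restating just the two methods from the diff:

    class FederationAckCommand:
        NAME = "FEDERATION_ACK"

        def __init__(self, instance_name, token):
            self.instance_name = instance_name
            self.token = token

        @classmethod
        def from_line(cls, line):
            instance_name, token = line.split(" ")
            return cls(instance_name, int(token))

        def to_line(self):
            return "%s %s" % (self.instance_name, self.token)

    cmd = FederationAckCommand("federation_sender1", 1234)
    assert cmd.to_line() == "federation_sender1 1234"
    rt = FederationAckCommand.from_line(cmd.to_line())
    assert (rt.instance_name, rt.token) == ("federation_sender1", 1234)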
index 55b3b79008..1de590bba2 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -14,9 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter +from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory @@ -42,8 +54,8 @@ from synapse.replication.tcp.streams import ( EventsStream, FederationStream, Stream, + TypingStream, ) -from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -55,12 +67,16 @@ inbound_rdata_count = Counter( user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) + user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") +# the type of the entries in _command_queues_by_stream +_StreamCommandQueue = Deque[ + Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] +] + + class ReplicationCommandHandler: """Handles incoming commands from replication as well as sending commands back out to connections. @@ -96,6 +112,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, TypingStream): + # Only add TypingStream as a source on the instance in charge of + # typing. + if hs.config.worker.writers.typing == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue @@ -107,10 +131,6 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) - self._position_linearizer = Linearizer( - "replication_position", clock=self._clock - ) - # Map of stream name to batched updates. See RdataCommand for info on # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] @@ -122,10 +142,6 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming stream names that are coming from - # that connection. - self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] - LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -133,6 +149,32 @@ class ReplicationCommandHandler: lambda: len(self._connections), ) + # When POSITION or RDATA commands arrive, we stick them in a queue and process + # them in order in a separate background process. + + # the streams which are currently being processed by _unsafe_process_stream + self._processing_streams = set() # type: Set[str] + + # for each stream, a queue of commands that are awaiting processing, and the + # connection that they arrived on. + self._command_queues_by_stream = { + stream_name: _StreamCommandQueue() for stream_name in self._streams + } + + # For each connection, the incoming stream names that have received a POSITION + # from that connection. 
+ self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + + LaterGauge( + "synapse_replication_tcp_command_queue", + "Number of inbound RDATA/POSITION commands queued for processing", + ["stream_name"], + lambda: { + (stream_name,): len(queue) + for stream_name, queue in self._command_queues_by_stream.items() + }, + ) + self._is_master = hs.config.worker_app is None self._federation_sender = None @@ -143,6 +185,64 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + async def _add_command_to_stream_queue( + self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + ) -> None: + """Queue the given received command for processing + + Adds the given command to the per-stream queue, and processes the queue if + necessary + """ + stream_name = cmd.stream_name + queue = self._command_queues_by_stream.get(stream_name) + if queue is None: + logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) + return + + # if we're already processing this stream, stick the new command in the + # queue, and we're done. + if stream_name in self._processing_streams: + queue.append((cmd, conn)) + return + + # otherwise, process the new command. + + # arguably we should start off a new background process here, but nothing + # will be too upset if we don't return for ages, so let's save the overhead + # and use the existing logcontext. + + self._processing_streams.add(stream_name) + try: + # might as well skip the queue for this one, since it must be empty + assert not queue + await self._process_command(cmd, conn, stream_name) + + # now process any other commands that have built up while we were + # dealing with that one. + while queue: + cmd, conn = queue.popleft() + try: + await self._process_command(cmd, conn, stream_name) + except Exception: + logger.exception("Failed to handle command %s", cmd) + + finally: + self._processing_streams.discard(stream_name) + + async def _process_command( + self, + cmd: Union[PositionCommand, RdataCommand], + conn: AbstractConnection, + stream_name: str, + ) -> None: + if isinstance(cmd, PositionCommand): + await self._process_position(stream_name, conn, cmd) + elif isinstance(cmd, RdataCommand): + await self._process_rdata(stream_name, conn, cmd) + else: + # This shouldn't be possible + raise Exception("Unrecognised command %s in stream queue", cmd.NAME) + def start_replication(self, hs): """Helper method to start a replication connection to the remote server using TCP. @@ -238,7 +338,7 @@ class ReplicationCommandHandler: federation_ack_counter.inc() if self._federation_sender: - self._federation_sender.federation_ack(cmd.token) + self._federation_sender.federation_ack(cmd.instance_name, cmd.token) async def on_REMOVE_PUSHER( self, conn: AbstractConnection, cmd: RemovePusherCommand @@ -276,63 +376,71 @@ class ReplicationCommandHandler: stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) - raise - - # We linearize here for two reasons: + # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the # same stream, and # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. 
- with await self._position_linearizer.queue(cmd.stream_name): - # make sure that we've processed a POSITION for this stream *on this - # connection*. (A POSITION on another connection is no good, as there - # is no guarantee that we have seen all the intermediate updates.) - sbc = self._streams_by_connection.get(conn) - if not sbc or stream_name not in sbc: - # Let's drop the row for now, on the assumption we'll receive a - # `POSITION` soon and we'll catch up correctly then. - logger.debug( - "Discarding RDATA for unconnected stream %s -> %s", - stream_name, - cmd.token, - ) - return - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token). - self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got RDATA for unknown stream: %s", stream_name) - return - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # Discard this data if this token is earlier than the current - # position. Note that streams can be reset (in which case you - # expect an earlier token), but that must be preceded by a - # POSITION command. - if cmd.token <= current_token: - logger.debug( - "Discarding RDATA from stream %s at position %s before previous position %s", - stream_name, - cmd.token, - current_token, - ) - else: - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + await self._add_command_to_stream_queue(conn, cmd) + + async def _process_rdata( + self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + ) -> None: + """Process an RDATA command + + Called after the command has been popped off the queue of inbound commands + """ + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception as e: + raise Exception( + "Failed to parse RDATA: %r %r" % (stream_name, cmd.row) + ) from e + + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.debug( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + return + + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + + stream = self._streams[stream_name] + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. 
+ if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -358,67 +466,65 @@ class ReplicationCommandHandler: logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) - stream_name = cmd.stream_name - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", stream_name) - return + await self._add_command_to_stream_queue(conn, cmd) - # We protect catching up with a linearizer in case the replication - # connection reconnects under us. - with await self._position_linearizer.queue(stream_name): - # We're about to go and catch up with the stream, so remove from set - # of connected streams. - for streams in self._streams_by_connection.values(): - streams.discard(stream_name) - - # We clear the pending batches for the stream as the fetching of the - # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(stream_name, []) - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # If the position token matches our current token then we're up to - # date and there's nothing to do. Otherwise, fetch all updates - # between then and now. - missing_updates = cmd.token != current_token - while missing_updates: - logger.info( - "Fetching replication rows for '%s' between %i and %i", - stream_name, - current_token, - cmd.token, - ) - ( - updates, - current_token, - missing_updates, - ) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token - ) + async def _process_position( + self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + ) -> None: + """Process a POSITION command - # TODO: add some tests for this + Called after the command has been popped off the queue of inbound commands + """ + stream = self._streams[stream_name] - # Some streams return multiple rows with the same stream IDs, - # which need to be processed in batches. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) - for token, rows in _batch_updates(updates): - await self.on_rdata( - stream_name, - cmd.instance_name, - token, - [stream.parse_row(row) for row in rows], - ) + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(stream_name, []) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) - # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. 
+ missing_updates = cmd.token != current_token + while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, ) + (updates, current_token, missing_updates) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) + + # TODO: add some tests for this + + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): + await self.on_rdata( + stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], + ) + + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) - self._streams_by_connection.setdefault(conn, set()).add(stream_name) + self._streams_by_connection.setdefault(conn, set()).add(stream_name) async def on_REMOTE_SERVER_UP( self, conn: AbstractConnection, cmd: RemoteServerUpCommand @@ -527,7 +633,7 @@ class ReplicationCommandHandler: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. """ - self.send_command(FederationAckCommand(token)) + self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
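Stepping back from the ReplicationCommandHandler changes above: the old position linearizer is replaced by an explicit FIFO per stream plus a set recording which streams are currently being drained, so commands for the same stream are never handled concurrently. (One observation on the fallback path: `raise Exception("Unrecognised command %s in stream queue", cmd.NAME)` passes the name as a stray second argument rather than interpolating it into the message.) A minimal, self-contained sketch of the queueing pattern, with hypothetical names — `process` stands in for `_process_command`:

import logging
from collections import deque
from typing import Any, Awaitable, Callable, Deque, Dict, Set

logger = logging.getLogger(__name__)

class PerStreamQueue:
    """Sketch: never process two commands for the same stream concurrently."""

    def __init__(self, stream_names):
        self._queues = {name: deque() for name in stream_names}  # type: Dict[str, Deque]
        self._processing = set()  # type: Set[str]

    async def add(
        self, stream_name: str, cmd: Any, process: Callable[[Any], Awaitable[None]]
    ) -> None:
        queue = self._queues.get(stream_name)
        if queue is None:
            logger.error("Got command for unknown stream: %s", stream_name)
            return

        if stream_name in self._processing:
            # someone is already draining this stream's queue: just enqueue
            queue.append(cmd)
            return

        self._processing.add(stream_name)
        try:
            # the queue must be empty at this point, so skip it for this command
            await process(cmd)

            # drain anything that arrived while we were busy
            while queue:
                next_cmd = queue.popleft()
                try:
                    await process(next_cmd)
                except Exception:
                    logger.exception("Failed to handle command %s", next_cmd)
        finally:
            self._processing.discard(stream_name)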
index 4198eece71..23191e3218 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
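The protocol changes below give each connection a single long-lived BackgroundProcessLoggingContext, re-entered around every inbound line and closed manually when the connection is dropped; a `with` block cannot be used because the context outlives any one callback. The shape of the pattern, reduced to a sketch (parsing elided):

from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import BackgroundProcessLoggingContext

class ConnectionSketch:
    def __init__(self, conn_id: str):
        # declared as a background process so its CPU usage is reported
        # to prometheus alongside other background work
        self._logging_context = BackgroundProcessLoggingContext(
            "replication_command_handler-%s" % (conn_id,)
        )

    def lineReceived(self, line: bytes):
        # re-enter the same per-connection context around every callback
        with PreserveLoggingContext(self._logging_context):
            self._parse_and_dispatch_line(line)

    def _parse_and_dispatch_line(self, line: bytes):
        ...  # command parsing elided

    def connectionLost(self, reason):
        # the context outlives any single `with` block, so it is closed
        # manually, exactly once, when the connection goes away
        self._logging_context.__exit__(None, None, None)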
@@ -57,8 +57,12 @@ from prometheus_client import Counter from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure +from synapse.logging.context import PreserveLoggingContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, @@ -160,6 +164,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler-%s" % self.conn_id + ) + def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -210,6 +220,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def lineReceived(self, line: bytes): """Called when we've received a line """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_line(line) + + def _parse_and_dispatch_line(self, line: bytes): if line.strip() == "": # Ignore blank lines return @@ -317,7 +331,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: @@ -397,6 +411,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: self.transport.unregisterProducer() + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def __str__(self): addr = None if self.transport: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index e776b63183..b5c533a607 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -18,8 +18,11 @@ from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( Command, ReplicateCommand, @@ -66,6 +69,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): stream_name = None # type: str outbound_redis_connection = None # type: txredisapi.RedisProtocol + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler" + ) + def connectionMade(self): logger.info("Connected to redis") super().connectionMade() @@ -92,7 +104,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_message(message) + def _parse_and_dispatch_message(self, message: str): if message.strip() == "": # Ignore blank lines return @@ -145,6 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): super().connectionLost(reason) self.handler.lost_connection(self) + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def send_command(self, cmd: Command): """Send a command if connection has been established. @@ -177,7 +195,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): Args: hs outbound_redis_connection: A connection to redis that will be used to - send outbound commands (this is seperate to the redis connection + send outbound commands (this is separate to the redis connection used to subscribe). """ diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9076bbe9f1..7a42de3f7d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -294,11 +294,12 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - if hs.config.worker_app is None: - # on the master, query the typing handler + writer_instance = hs.config.worker.writers.typing + if writer_instance == hs.get_instance_name(): + # On the writer, query the typing handler update_function = typing_handler.get_all_typing_updates else: - # Query master process + # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index bdddb62ad6..16c63ff4ec 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
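The one-line import fix below is more than cosmetic: importing ABCs from the `collections` top level has been deprecated since Python 3.3 and fails outright on Python 3.10+, so only the `collections.abc` form is future-proof:

from collections.abc import Iterable  # works on every Python 3 version

# `from collections import Iterable` emits a DeprecationWarning on 3.3-3.9
# and raises ImportError on Python 3.10 and later.
assert isinstance((1, 2, 3), Iterable)
assert not isinstance(42, Iterable)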
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import heapq -from collections import Iterable +from collections.abc import Iterable from typing import List, Tuple, Type import attr @@ -62,7 +62,7 @@ class BaseEventsStreamRow(object): Specifies how to identify, serialize and deserialize the different types. """ - # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overridden in sub classes. TypeId = None # type: str @classmethod diff --git a/synapse/res/templates/mail-Element.css b/synapse/res/templates/mail-Element.css new file mode 100644
index 0000000000..6a3e36eda1 --- /dev/null +++ b/synapse/res/templates/mail-Element.css
@@ -0,0 +1,7 @@ +.header { + border-bottom: 4px solid #e4f7ed ! important; +} + +.notif_link a, .footer a { + color: #76CFA6 ! important; +} diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index 6b94d8c367..d87311f659 100644 --- a/synapse/res/templates/notice_expiry.html +++ b/synapse/res/templates/notice_expiry.html
@@ -22,6 +22,8 @@ <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> {% elif app_name == "Vector" %} <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> + {% elif app_name == "Element" %} + <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/> {% else %} <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> {% endif %} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index 019506e5fb..a2dfeb9e9f 100644 --- a/synapse/res/templates/notif_mail.html +++ b/synapse/res/templates/notif_mail.html
@@ -22,6 +22,8 @@ <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> {% elif app_name == "Vector" %} <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> + {% elif app_name == "Element" %} + <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/> {% else %} <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> {% endif %} diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 9eda592de9..1c88c93f38 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -35,8 +35,10 @@ from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( + DeleteRoomRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, + RoomMembersRestServlet, RoomRestServlet, ShutdownRoomRestServlet, ) @@ -200,6 +202,8 @@ def register_servlets(hs, http_server): register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + RoomMembersRestServlet(hs).register(http_server) + DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index e07c32118d..b8c95d045a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
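For context on the servlets below: the new DeleteRoomRestServlet exposes POST /rooms/<room_id>/delete under the admin prefix, taking the same optional fields as the shutdown endpoint plus a boolean `block`. A hypothetical invocation — the host, token and `/_synapse/admin/v1` prefix are assumptions for illustration; the body fields come straight from the handler:

import requests

ROOM_ID = "!roomid:example.com"       # hypothetical room
ADMIN_TOKEN = "<admin access token>"  # hypothetical token

resp = requests.post(
    "https://homeserver.example/_synapse/admin/v1/rooms/%s/delete" % ROOM_ID,
    headers={"Authorization": "Bearer %s" % ADMIN_TOKEN},
    json={
        "new_room_user_id": "@admin:example.com",  # optional: local users are moved to a new room owned by this user
        "room_name": "Content Violation Notification",  # optional
        "message": "This room has been shut down.",  # optional
        "block": True,  # must be a boolean if given, otherwise 400 BAD_JSON
    },
)
resp.raise_for_status()
print(resp.json())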
@@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import List, Optional -from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -32,7 +33,6 @@ from synapse.rest.admin._base import ( ) from synapse.storage.data_stores.main.room import RoomSortOrder from synapse.types import RoomAlias, RoomID, UserID, create_requester -from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -46,20 +46,10 @@ class ShutdownRoomRestServlet(RestServlet): PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)") - DEFAULT_MESSAGE = ( - "Sharing illegal content on this server is not permitted and rooms in" - " violation will be blocked." - ) - def __init__(self, hs): self.hs = hs - self.store = hs.get_datastore() - self.state = hs.get_state_handler() - self._room_creation_handler = hs.get_room_creation_handler() - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - self._replication = hs.get_replication_data_handler() + self.room_shutdown_handler = hs.get_room_shutdown_handler() async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -67,116 +57,65 @@ class ShutdownRoomRestServlet(RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ["new_room_user_id"]) - new_room_user_id = content["new_room_user_id"] - - room_creator_requester = create_requester(new_room_user_id) - - message = content.get("message", self.DEFAULT_MESSAGE) - room_name = content.get("room_name", "Content Violation Notification") - - info, stream_id = await self._room_creation_handler.create_room( - room_creator_requester, - config={ - "preset": RoomCreationPreset.PUBLIC_CHAT, - "name": room_name, - "power_level_content_override": {"users_default": -10}, - }, - ratelimit=False, - ) - new_room_id = info["room_id"] - - requester_user_id = requester.user.to_string() - logger.info( - "Shutting down room %r, joining to new room: %r", room_id, new_room_id + ret = await self.room_shutdown_handler.shutdown_room( + room_id=room_id, + new_room_user_id=content["new_room_user_id"], + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=True, ) - # This will work even if the room is already blocked, but that is - # desirable in case the first attempt at blocking the room failed below. - await self.store.block_room(room_id, requester_user_id) - - # We now wait for the create room to come back in via replication so - # that we can assume that all the joins/invites have propogated before - # we try and auto join below. 
- # - # TODO: Currently the events stream is written to from master - await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id - ) + return (200, ret) - users = await self.state.get_current_users_in_room(room_id) - kicked_users = [] - failed_to_kick_users = [] - for user_id in users: - if not self.hs.is_mine_id(user_id): - continue - logger.info("Kicking %r from %r...", user_id, room_id) +class DeleteRoomRestServlet(RestServlet): + """Delete a room from server. It is a combination and improvement of + shut down and purge room. + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto + joined to the new room. + It will remove all trace of a room from the database. + """ - try: - target_requester = create_requester(user_id) - _, stream_id = await self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=room_id, - action=Membership.LEAVE, - content={}, - ratelimit=False, - require_consent=False, - ) + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$") - # Wait for leave to come in over replication before trying to forget. - await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id - ) + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.room_shutdown_handler = hs.get_room_shutdown_handler() + self.pagination_handler = hs.get_pagination_handler() - await self.room_member_handler.forget(target_requester.user, room_id) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) - await self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=new_room_id, - action=Membership.JOIN, - content={}, - ratelimit=False, - require_consent=False, - ) + content = parse_json_object_from_request(request) - kicked_users.append(user_id) - except Exception: - logger.exception( - "Failed to leave old room and join new room for %r", user_id - ) - failed_to_kick_users.append(user_id) - - await self.event_creation_handler.create_and_send_nonmember_event( - room_creator_requester, - { - "type": "m.room.message", - "content": {"body": message, "msgtype": "m.text"}, - "room_id": new_room_id, - "sender": new_room_user_id, - }, - ratelimit=False, - ) + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) - aliases_for_room = await maybe_awaitable( - self.store.get_aliases_for_room(room_id) + ret = await self.room_shutdown_handler.shutdown_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, ) - await self.store.update_aliases_for_room( - room_id, new_room_id, requester_user_id - ) + # Purge room + await self.pagination_handler.purge_room(room_id) - return ( - 200, - { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, - "new_room_id": new_room_id, - }, - ) + return (200, ret) class ListRoomRestServlet(RestServlet): @@ -292,6 
+231,31 @@ class RoomRestServlet(RestServlet): return 200, ret +class RoomMembersRestServlet(RestServlet): + """ + Get members list of a room. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + await assert_requester_is_admin(self.auth, request) + + ret = await self.store.get_room(room_id) + if not ret: + raise NotFoundError("Room not found") + + members = await self.store.get_users_in_room(room_id) + ret = {"members": members, "total": len(members)} + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index e4330c39d6..cc0bdfa5c9 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
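The users.py hunk below lets the v2 admin user endpoint re-activate a deactivated account, but only when a new password accompanies the request. A hypothetical call (host, token and URL prefix assumed for illustration):

import requests

resp = requests.put(
    "https://homeserver.example/_synapse/admin/v2/users/@alice:example.com",
    headers={"Authorization": "Bearer <admin access token>"},
    # "password" is mandatory when flipping "deactivated" back to False;
    # omitting it yields "Must provide a password to re-activate an account."
    json={"deactivated": False, "password": "a-new-password"},
)
resp.raise_for_status()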
@@ -239,6 +239,15 @@ class UserRestServletV2(RestServlet): await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False ) + elif not deactivate and user["deactivated"]: + if "password" not in body: + raise SynapseError( + 400, "Must provide a password to re-activate an account." + ) + + await self.deactivate_account_handler.activate_account( + target_user.to_string() + ) user = await self.admin_handler.get_user(target_user) return 200, user @@ -254,7 +263,6 @@ class UserRestServletV2(RestServlet): admin = body.get("admin", None) user_type = body.get("user_type", None) displayname = body.get("displayname", None) - threepids = body.get("threepids", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 64d5c58b65..379f668d6f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
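The JWT login rework below delegates issuer/audience checks to PyJWT and folds every validation failure into a single 403. Its core is plain PyJWT usage, roughly as in this sketch (Synapse's LoginError swapped for a stand-in exception):

import jwt  # PyJWT

def user_from_jwt(token, secret, algorithm, issuer=None, audiences=None):
    try:
        # jwt.decode verifies the signature, expiry and (when given)
        # issuer/audience; any failure raises a jwt.PyJWTError subclass
        payload = jwt.decode(
            token,
            secret,
            algorithms=[algorithm],
            issuer=issuer,
            audience=audiences,
        )
    except jwt.PyJWTError as e:
        raise PermissionError("JWT validation failed: %s" % (e,))

    user = payload.get("sub")
    if user is None:
        raise PermissionError("Invalid JWT: missing 'sub' claim")
    return user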
@@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet): def __init__(self, hs): super(LoginRestServlet, self).__init__() self.hs = hs + + # JWT configuration variables. self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.jwt_issuer = hs.config.jwt_issuer + self.jwt_audiences = hs.config.jwt_audiences + + # SSO configuration. self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -364,24 +371,28 @@ class LoginRestServlet(RestServlet): token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN ) import jwt - from jwt.exceptions import InvalidTokenError try: payload = jwt.decode( - token, self.jwt_secret, algorithms=[self.jwt_algorithm] + token, + self.jwt_secret, + algorithms=[self.jwt_algorithm], + issuer=self.jwt_issuer, + audience=self.jwt_audiences, + ) + except jwt.PyJWTError as e: + # A JWT error occurred, return some info back to the client. + raise LoginError( + 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN, ) - except jwt.ExpiredSignatureError: - raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) - except InvalidTokenError: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user = payload.get("sub", None) if user is None: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 46811abbfa..26d5a51cb2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -15,6 +15,7 @@ # limitations under the License. """ This module contains REST servlets to do with rooms: /rooms/<paths> """ + import logging import re from typing import List, Optional @@ -217,10 +218,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) event_id = event.event_id - ret = {} # type: dict - if event_id: - set_tag("event_id", event_id) - ret = {"event_id": event_id} + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -517,9 +516,9 @@ class RoomMessageListRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args - filter_bytes = parse_string(request, b"filter", encoding=None) - if filter_bytes: - filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] if ( event_filter @@ -629,9 +628,9 @@ class RoomEventContextServlet(RestServlet): limit = parse_integer(request, "limit", default=10) # picking the API shape for symmetry with /messages - filter_bytes = parse_string(request, "filter") - if filter_bytes: - filter_json = urlparse.unquote(filter_bytes) + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] else: event_filter = None @@ -818,9 +817,18 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() + # If we're not on the typing writer instance we should scream if we get + # requests. + self._is_typing_writer = ( + hs.config.worker.writers.typing == hs.get_instance_name() + ) + async def on_PUT(self, request, room_id, user_id): requester = await self.auth.get_user_by_req(request) + if not self._is_typing_writer: + raise Exception("Got /typing request on instance that is not typing writer") + room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index bc11b4dda4..b21538766d 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py
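As the new docstring below says, set_timeline_upper_limit mutates the filter dict in place, clamping any client-supplied room.timeline.limit down to the configured cap. Equivalent standalone logic, as a sketch (assuming, as in the current implementation, that a missing limit is left untouched):

def cap_timeline_limit(filter_json: dict, max_limit: int) -> None:
    if max_limit < 0:
        return  # a negative cap means "no enforcement"
    timeline = filter_json.get("room", {}).get("timeline", {})
    if "limit" in timeline:
        timeline["limit"] = min(timeline["limit"], max_limit)

f = {"room": {"timeline": {"limit": 500}}}
cap_timeline_limit(f, 100)
assert f == {"room": {"timeline": {"limit": 100}}}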
@@ -22,6 +22,7 @@ from twisted.internet import defer from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -51,7 +52,15 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): return patterns -def set_timeline_upper_limit(filter_json, filter_timeline_limit): +def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: + """ + Enforces a maximum limit of a timeline query. + + Params: + filter_json: The timeline query to modify. + filter_timeline_limit: The maximum limit to allow, passing -1 will + disable enforcing a maximum limit. + """ if filter_timeline_limit < 0: return # no upper limits timeline = filter_json.get("room", {}).get("timeline", {}) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 8fa68dd37f..a5c24fbd63 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py
@@ -178,14 +178,22 @@ class SyncRestServlet(RestServlet): full_state=full_state, ) + # the client may have disconnected by now; don't bother to serialize the + # response if so. + if request._disconnected: + logger.info("Client has disconnected; not serializing response.") + return 200, {} + time_now = self.clock.time_msec() response_content = await self.encode_response( time_now, sync_result, requester.access_token_id, filter_collection ) + logger.debug("Event formatting complete") return 200, response_content async def encode_response(self, time_now, sync_result, access_token_id, filter): + logger.debug("Formatting events in sync response") if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == "federation": @@ -213,6 +221,7 @@ class SyncRestServlet(RestServlet): event_formatter, ) + logger.debug("building sync response dict") return { "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index e149ac1733..9b3f85b306 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
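The two comments added below record the same gotcha: with psycopg2, bytea columns come back as memoryview, so the value is copied into bytes before being added to the set, and decoded explicitly before json.loads (bytes support in json.loads only arrived in Python 3.6). In miniature:

import json

mv = memoryview(b'{"server_name": "example.org"}')   # what psycopg2 hands back
raw = bytes(mv)                        # copy into a real bytes object
obj = json.loads(raw.decode("utf-8"))  # decode first: 3.5's json.loads rejects bytes
assert obj["server_name"] == "example.org"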
@@ -202,9 +202,11 @@ class RemoteKey(DirectServeJsonResource): if miss: cache_misses.setdefault(server_name, set()).add(key_id) + # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: for ts_added, result in results: + # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) if cache_misses and query_remote_on_cache_miss: @@ -213,7 +215,7 @@ class RemoteKey(DirectServeJsonResource): else: signed_keys = [] for key_json in json_results: - key_json = json.loads(key_json) + key_json = json.loads(key_json.decode("utf-8")) for signing_key in self.config.key_server_signing_keys: key_json = sign_json(key_json, self.config.server_name, signing_key) diff --git a/synapse/server.py b/synapse/server.py
index 6acce2e23f..8e41112530 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -44,7 +44,6 @@ from synapse.federation.federation_client import FederationClient from synapse.federation.federation_server import ( FederationHandlerRegistry, FederationServer, - ReplicationFederationHandlerRegistry, ) from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.sender import FederationSender @@ -73,14 +72,18 @@ from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler -from synapse.handlers.room import RoomContextHandler, RoomCreationHandler +from synapse.handlers.room import ( + RoomContextHandler, + RoomCreationHandler, + RoomShutdownHandler, +) from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler -from synapse.handlers.typing import TypingHandler +from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient @@ -102,7 +105,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import DataStores, Storage +from synapse.storage import DataStore, DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -144,6 +147,7 @@ class HomeServer(object): "handlers", "auth", "room_creation_handler", + "room_shutdown_handler", "state_handler", "state_resolution_handler", "presence_handler", @@ -307,7 +311,7 @@ class HomeServer(object): def get_clock(self): return self.clock - def get_datastore(self): + def get_datastore(self) -> DataStore: return self.datastores.main def get_datastores(self): @@ -357,6 +361,9 @@ class HomeServer(object): def build_room_creation_handler(self): return RoomCreationHandler(self) + def build_room_shutdown_handler(self): + return RoomShutdownHandler(self) + def build_sendmail(self): return sendmail @@ -370,7 +377,10 @@ class HomeServer(object): return PresenceHandler(self) def build_typing_handler(self): - return TypingHandler(self) + if self.config.worker.writers.typing == self.get_instance_name(): + return TypingWriterHandler(self) + else: + return FollowerTypingHandler(self) def build_sync_handler(self): return SyncHandler(self) @@ -526,10 +536,7 @@ class HomeServer(object): return RoomMemberMasterHandler(self) def build_federation_registry(self): - if self.config.worker_app: - return ReplicationFederationHandlerRegistry(self) - else: - return FederationHandlerRegistry() + return FederationHandlerRegistry(self) def build_server_notices_manager(self): if self.config.worker_app: diff --git a/synapse/server.pyi b/synapse/server.pyi
index fe8024d2d4..90a673778f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi
@@ -20,6 +20,7 @@ import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client +import synapse.http.matrixfederationclient import synapse.notifier import synapse.push.pusherpool import synapse.replication.tcp.client @@ -71,6 +72,8 @@ class HomeServer(object): pass def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass + def get_room_shutdown_handler(self) -> synapse.handlers.room.RoomShutdownHandler: + pass def get_event_creation_handler( self, ) -> synapse.handlers.message.EventCreationHandler: @@ -141,3 +144,9 @@ class HomeServer(object): pass def get_replication_streams(self) -> Dict[str, Stream]: pass + def get_http_client( + self, + ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient: + pass + def should_send_federation(self) -> bool: + pass diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index bfce541ca7..985a042869 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
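For reference while reading the many json.loads → db_to_json swaps in the storage hunks that follow: db_to_json normalises whatever the database driver returns before deserialising. Judging from the hunk below and the call sites, its core looks roughly like this (the warning-and-reraise tail is an assumption):

import json
import logging

logger = logging.getLogger(__name__)

def db_to_json(db_content):
    # psycopg2 returns bytea columns as memoryview
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()

    # decode to str before json.loads: Python 3.5 cannot deserialize bytes
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf8")

    try:
        return json.loads(db_content)
    except Exception:
        logger.warning("Tried to decode '%r' as JSON and failed", db_content)
        raise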
@@ -100,8 +100,8 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # Decode it to a Unicode string before feeding it to json.loads, so we - # consistenty get a Unicode-containing object out. + # Decode it to a Unicode string before feeding it to json.loads, since + # Python 3.5 does not support deserializing bytes. if isinstance(db_content, (bytes, bytearray)): db_content = db_content.decode("utf8") diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 59f3394b0a..018826ef69 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
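The local import below is the usual way to break an import cycle: resolving the name lazily at call time sidesteps the circular dependency that a module-level import would create. The pattern in isolation (a sketch):

def _load_progress(progress_json):
    # deferred import: synapse.storage._base cannot be imported at module
    # scope here without completing a cycle, so resolve it at call time
    from synapse.storage._base import db_to_json

    return db_to_json(progress_json)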
@@ -249,7 +249,10 @@ class BackgroundUpdater(object): retcol="progress_json", ) - progress = json.loads(progress_json) + # Avoid a circular import. + from synapse.storage._base import db_to_json + + progress = db_to_json(progress_json) time_start = self._clock.time_msec() items_updated = await update_handler(progress, batch_size) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 4b4763c701..932458f651 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py
@@ -128,7 +128,7 @@ class DataStore( db_conn, "presence_stream", "stream_id" ) self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_max_stream_id", "stream_id" + db_conn, "device_inbox", "stream_id" ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index b58f04d00d..33cc372dfd 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,7 +22,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -77,7 +77,7 @@ class AccountDataWorkerStore(SQLBaseStore): ) global_account_data = { - row["account_data_type"]: json.loads(row["content"]) for row in rows + row["account_data_type"]: db_to_json(row["content"]) for row in rows } rows = self.db.simple_select_list_txn( @@ -90,7 +90,7 @@ class AccountDataWorkerStore(SQLBaseStore): by_room = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) - room_data[row["account_data_type"]] = json.loads(row["content"]) + room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room @@ -113,7 +113,7 @@ class AccountDataWorkerStore(SQLBaseStore): ) if result: - return json.loads(result) + return db_to_json(result) else: return None @@ -137,7 +137,7 @@ class AccountDataWorkerStore(SQLBaseStore): ) return { - row["account_data_type"]: json.loads(row["content"]) for row in rows + row["account_data_type"]: db_to_json(row["content"]) for row in rows } return self.db.runInteraction( @@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore): allow_none=True, ) - return json.loads(content_json) if content_json else None + return db_to_json(content_json) if content_json else None return self.db.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn @@ -255,7 +255,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) - global_account_data = {row[0]: json.loads(row[1]) for row in txn} + global_account_data = {row[0]: db_to_json(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" @@ -267,7 +267,7 @@ class AccountDataWorkerStore(SQLBaseStore): account_data_by_room = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) - room_account_data[row[1]] = json.loads(row[2]) + room_account_data[row[1]] = db_to_json(row[2]) return global_account_data, account_data_by_room diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 7a1fe8cdd2..56659fed37 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py
@@ -22,7 +22,7 @@ from twisted.internet import defer from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import Database @@ -303,7 +303,7 @@ class ApplicationServiceTransactionWorkerStore( if not entry: return None - event_ids = json.loads(entry["event_ids"]) + event_ids = db_to_json(entry["event_ids"]) events = yield self.get_events_as_list(event_ids) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index d313b9705f..da297b31fb 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,7 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache @@ -65,7 +65,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(json.loads(row[1])) + messages.append(db_to_json(row[1])) if len(messages) < limit: stream_pos = current_stream_id return messages, stream_pos @@ -173,7 +173,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(json.loads(row[1])) + messages.append(db_to_json(row[1])) if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id @@ -424,9 +424,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): - sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" - txn.execute(sql, (stream_id, stream_id)) - local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 343cf9a2d5..45581a6500 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py
@@ -577,7 +577,7 @@ class DeviceWorkerStore(SQLBaseStore): rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - return {user for row in rows for user in json.loads(row[0])} + return {user for row in rows for user in db_to_json(row[0])} else: return set() diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 23f4570c4b..615364f018 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +from canonicaljson import json from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json class EndToEndRoomKeyStore(SQLBaseStore): @@ -148,7 +148,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "forwarded_count": row["forwarded_count"], # is_verified must be returned to the client as a boolean "is_verified": bool(row["is_verified"]), - "session_data": json.loads(row["session_data"]), + "session_data": db_to_json(row["session_data"]), } return sessions @@ -222,7 +222,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": row[2], "forwarded_count": row[3], "is_verified": row[4], - "session_data": json.loads(row[5]), + "session_data": db_to_json(row[5]), } return ret @@ -319,7 +319,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, retcols=("version", "algorithm", "auth_data", "etag"), ) - result["auth_data"] = json.loads(result["auth_data"]) + result["auth_data"] = db_to_json(result["auth_data"]) result["version"] = str(result["version"]) if result["etag"] is None: result["etag"] = 0 diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 6c3cff82e1..317c07a829 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -366,7 +366,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for row in rows: user_id = row["user_id"] key_type = row["keytype"] - key = json.loads(row["keydata"]) + key = db_to_json(row["keydata"]) user_info = result.setdefault(user_id, {}) user_info[key_type] = key diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index bc9f4f08ea..504babaa7e 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -21,7 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks @@ -58,7 +58,7 @@ def _deserialize_action(actions, is_highlight): """Custom deserializer for actions. This allows us to "compress" common actions """ if actions: - return json.loads(actions) + return db_to_json(actions) if is_highlight: return DEFAULT_HIGHLIGHT_ACTION diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 230fb5cd7f..6f2e0d15cc 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py
@@ -17,11 +17,9 @@ import itertools import logging from collections import OrderedDict, namedtuple -from functools import wraps from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple import attr -from canonicaljson import json from prometheus_client import Counter from twisted.internet import defer @@ -33,7 +31,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function -from synapse.storage._base import make_in_list_sql_clause +from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.data_stores.main.search import SearchEntry from synapse.storage.database import Database, LoggingTransaction from synapse.storage.util.id_generators import StreamIdGenerator @@ -69,27 +67,6 @@ def encode_json(json_object): _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) -def _retry_on_integrity_error(func): - """Wraps a database function so that it gets retried on IntegrityError, - with `delete_existing=True` passed in. - - Args: - func: function that returns a Deferred and accepts a `delete_existing` arg - """ - - @wraps(func) - @defer.inlineCallbacks - def f(self, *args, **kwargs): - try: - res = yield func(self, *args, delete_existing=False, **kwargs) - except self.database_engine.module.IntegrityError: - logger.exception("IntegrityError, retrying.") - res = yield func(self, *args, delete_existing=True, **kwargs) - return res - - return f - - @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -134,7 +111,6 @@ class PersistEventsStore: hs.config.worker.writers.events == hs.get_instance_name() ), "Can only instantiate EventsStore on master" - @_retry_on_integrity_error @defer.inlineCallbacks def _persist_events_and_state_updates( self, @@ -143,7 +119,6 @@ class PersistEventsStore: state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], backfilled: bool = False, - delete_existing: bool = False, ): """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -157,7 +132,6 @@ class PersistEventsStore: new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. 
backfilled - delete_existing Returns: Deferred: resolves when the events have been persisted @@ -197,7 +171,6 @@ class PersistEventsStore: self._persist_events_txn, events_and_contexts=events_and_contexts, backfilled=backfilled, - delete_existing=delete_existing, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) @@ -262,7 +235,7 @@ class PersistEventsStore: ) txn.execute(sql + clause, args) - results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) + results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): yield self.db.runInteraction( @@ -323,7 +296,7 @@ class PersistEventsStore: if prev_event_id in existing_prevs: continue - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: to_recursively_check.append(prev_event_id) existing_prevs.add(prev_event_id) @@ -341,7 +314,6 @@ class PersistEventsStore: txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool, - delete_existing: bool = False, state_delta_for_room: Dict[str, DeltaState] = {}, new_forward_extremeties: Dict[str, List[str]] = {}, ): @@ -393,13 +365,6 @@ class PersistEventsStore: # From this point onwards the events are only events that we haven't # seen before. - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) - self._store_event_txn(txn, events_and_contexts=events_and_contexts) # Insert into event_to_state_groups. @@ -617,7 +582,7 @@ class PersistEventsStore: txn.execute(sql, (room_id, EventTypes.Create, "")) row = txn.fetchone() if row: - event_json = json.loads(row[0]) + event_json = db_to_json(row[0]) content = event_json.get("content", {}) creator = content.get("creator") room_version_id = content.get("room_version", RoomVersions.V1.identifier) @@ -797,39 +762,6 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] - @classmethod - def _delete_existing_rows_txn(cls, txn, events_and_contexts): - if not events_and_contexts: - # nothing to do here - return - - logger.info("Deleting existing") - - for table in ( - "events", - "event_auth", - "event_json", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "event_to_state_groups", - "state_events", - "rejections", - "redactions", - "room_memberships", - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts], - ) - - for table in ("event_push_actions",): - txn.executemany( - "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), - [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], - ) - def _store_event_txn(self, txn, events_and_contexts): """Insert new events into the event and event_json tables diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 62d28f44dc..663c94b24f 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -15,12 +15,10 @@ import logging -from canonicaljson import json - from twisted.internet import defer from synapse.api.constants import EventContentFields -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): for row in rows: try: event_id = row[1] - event_json = json.loads(row[2]) + event_json = db_to_json(row[2]) sender = event_json["sender"] content = event_json["content"] @@ -208,7 +206,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): for row in ev_rows: event_id = row["event_id"] - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) try: origin_server_ts = event_json["origin_server_ts"] except (KeyError, AttributeError): @@ -317,7 +315,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): soft_failed = False if metadata: - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: soft_failed_events_to_lookup.add(event_id) @@ -358,7 +356,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): graph[event_id] = {prev_event_id} - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: soft_failed_events_to_lookup.add(event_id) else: @@ -543,7 +541,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): last_row_event_id = "" for (event_id, event_json_raw) in results: try: - event_json = json.loads(event_json_raw) + event_json = db_to_json(event_json_raw) self.db.simple_insert_many_txn( txn=txn, diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 01cad7d4fa..e812c67078 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py
@@ -21,7 +21,6 @@ import threading from collections import namedtuple from typing import List, Optional, Tuple -from canonicaljson import json from constantly import NamedConstant, Names from twisted.internet import defer @@ -40,7 +39,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id @@ -611,8 +610,8 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected and rejected_reason: continue - d = json.loads(row["json"]) - internal_metadata = json.loads(row["internal_metadata"]) + d = db_to_json(row["json"]) + internal_metadata = db_to_json(row["internal_metadata"]) format_version = row["format_version"] if format_version is None: @@ -640,7 +639,7 @@ class EventsWorkerStore(SQLBaseStore): else: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) if not room_version: - logger.error( + logger.warning( "Event %s in room %s has unknown room version %s", event_id, d["room_id"], diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 4fb9f9850c..01ff561e1a 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py
@@ -21,7 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -197,7 +197,7 @@ class GroupServerWorkerStore(SQLBaseStore): categories = { row[0]: { "is_public": row[1], - "profile": json.loads(row[2]), + "profile": db_to_json(row[2]), "order": row[3], } for row in txn @@ -221,7 +221,7 @@ class GroupServerWorkerStore(SQLBaseStore): return { row["category_id"]: { "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + "profile": db_to_json(row["profile"]), } for row in rows } @@ -235,7 +235,7 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_group_category", ) - category["profile"] = json.loads(category["profile"]) + category["profile"] = db_to_json(category["profile"]) return category @@ -251,7 +251,7 @@ class GroupServerWorkerStore(SQLBaseStore): return { row["role_id"]: { "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + "profile": db_to_json(row["profile"]), } for row in rows } @@ -265,7 +265,7 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_group_role", ) - role["profile"] = json.loads(role["profile"]) + role["profile"] = db_to_json(role["profile"]) return role @@ -333,7 +333,7 @@ class GroupServerWorkerStore(SQLBaseStore): roles = { row[0]: { "is_public": row[1], - "profile": json.loads(row[2]), + "profile": db_to_json(row[2]), "order": row[3], } for row in txn @@ -462,7 +462,7 @@ class GroupServerWorkerStore(SQLBaseStore): now = int(self._clock.time_msec()) if row and now < row["valid_until_ms"]: - return json.loads(row["attestation_json"]) + return db_to_json(row["attestation_json"]) return None @@ -489,7 +489,7 @@ class GroupServerWorkerStore(SQLBaseStore): "group_id": row[0], "type": row[1], "membership": row[2], - "content": json.loads(row[3]), + "content": db_to_json(row[3]), } for row in txn ] @@ -519,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore): "group_id": group_id, "membership": membership, "type": gtype, - "content": json.loads(content_json), + "content": db_to_json(content_json), } for group_id, membership, gtype, content_json in txn ] @@ -567,7 +567,7 @@ class GroupServerWorkerStore(SQLBaseStore): """ txn.execute(sql, (last_id, current_id, limit)) updates = [ - (stream_id, (group_id, user_id, gtype, json.loads(content_json))) + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) for stream_id, group_id, user_id, gtype, content_json in txn ] diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index f6e78ca590..d181488db7 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py
@@ -24,7 +24,7 @@ from twisted.internet import defer from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.pusher import PusherWorkerStore @@ -43,8 +43,8 @@ def _load_rules(rawrules, enabled_map): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) + rule["conditions"] = db_to_json(rawrule["conditions"]) + rule["actions"] = db_to_json(rawrule["actions"]) rule["default"] = False ruleslist.append(rule) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
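For orientation, `_load_rules` consumes rows whose `conditions` and `actions` columns hold JSON text. A quick illustration of the transformation, with an entirely invented row (the patched code uses `db_to_json` where plain `json.loads` appears here):

import json

# Invented example row: conditions/actions are stored as JSON strings.
rawrule = {
    "rule_id": "global/underride/.m.rule.message",
    "priority_class": 1,
    "conditions": '[{"kind": "event_match", "key": "type", "pattern": "m.room.message"}]',
    "actions": '["notify"]',
}

rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])  # db_to_json in the patched code
rule["actions"] = json.loads(rawrule["actions"])
rule["default"] = False
assert rule["actions"] == ["notify"]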
index 5461016240..e18f1ca87c 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py
@@ -17,11 +17,11 @@ import logging from typing import Iterable, Iterator, List, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class PusherWorkerStore(SQLBaseStore): for r in rows: dataJson = r["data"] try: - r["data"] = json.loads(dataJson) + r["data"] = db_to_json(dataJson) except Exception as e: logger.warning( "Invalid JSON in data for pusher %d: %s, %s", diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 8f5505bd67..1d723f2d34 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,7 +22,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.async_helpers import ObservableDeferred @@ -203,7 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore): for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] - ] = json.loads(row["data"]) + ] = db_to_json(row["data"]) return [{"type": "m.receipt", "room_id": room_id, "content": content}] @@ -260,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore): event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) - receipt_type[row["user_id"]] = json.loads(row["data"]) + receipt_type[row["user_id"]] = db_to_json(row["data"]) results = { room_id: [results[room_id]] if room_id in results else [] @@ -329,7 +329,7 @@ class ReceiptsWorkerStore(SQLBaseStore): """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn] + updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] limited = False upper_bound = current_id diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 587d4b91c1..27d2c5028c 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py
@@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore): self.config = hs.config self.clock = hs.get_clock() + self._user_id_seq = build_sequence_generator( + database.engine, find_max_generated_user_id_localpart, "user_id_seq", + ) + @cached() def get_user_by_id(self, user_id): return self.db.simple_select_one( @@ -481,39 +487,17 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self.db.runInteraction("count_real_users", _count_users) return ret - @defer.inlineCallbacks - def find_next_generated_user_id_localpart(self): - """ - Gets the localpart of the next generated user ID. + async def generate_user_id(self) -> str: + """Generate a suitable localpart for a guest user - Generated user IDs are integers, so we find the largest integer user ID - already taken and return that plus one. + Returns: a (hopefully) free localpart """ - - def _find_next_generated_user_id(txn): - # We bound between '@0' and '@a' to avoid pulling the entire table - # out. - txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") - - regex = re.compile(r"^@(\d+):") - - max_found = 0 - - for (user_id,) in txn: - match = regex.search(user_id) - if match: - max_found = max(int(match.group(1)), max_found) - - return max_found + 1 - - return ( - ( - yield self.db.runInteraction( - "find_next_generated_user_id", _find_next_generated_user_id - ) - ) + next_id = await self.db.runInteraction( + "generate_user_id", self._user_id_seq.get_next_id_txn ) + return str(next_id) + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: """Returns user id from threepid @@ -1573,3 +1557,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): keyvalues={"user_id": user_id}, values={"expiration_ts_ms": expiration_ts, "email_sent": False}, ) + + +def find_max_generated_user_id_localpart(cur: Cursor) -> int: + """ + Gets the localpart of the max current generated user ID. + + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that. + """ + + # We bound between '@0' and '@a' to avoid pulling the entire table + # out. + cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") + + regex = re.compile(r"^@(\d+):") + + max_found = 0 + + for (user_id,) in cur: + match = regex.search(user_id) + if match: + max_found = max(int(match.group(1)), max_found) + return max_found diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
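With this change, guest localpart allocation goes through the sequence machinery added later in this commit (`synapse/storage/util/sequence.py`) instead of re-scanning the users table on every call. A hedged sketch of the calling pattern; the wrapper function and server name are hypothetical:

async def register_guest(store, server_name="example.com"):
    # On postgres this becomes SELECT nextval('user_id_seq'); on sqlite the
    # generator seeds itself once from find_max_generated_user_id_localpart
    # and then counts up under a local lock.
    localpart = await store.generate_user_id()
    return "@%s:%s" % (localpart, server_name)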
index c473cf158f..d2e1e36e7f 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py
@@ -28,7 +28,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.database import Database, LoggingTransaction from synapse.types import ThirdPartyInstanceID @@ -118,7 +118,12 @@ class RoomWorkerStore(SQLBaseStore): WHERE room_id = ? """ txn.execute(sql, [room_id]) - res = self.db.cursor_to_dict(txn)[0] + # Catch error if sql returns empty result to return "None" instead of an error + try: + res = self.db.cursor_to_dict(txn)[0] + except IndexError: + return None + res["federatable"] = bool(res["federatable"]) res["public"] = bool(res["public"]) return res @@ -665,7 +670,7 @@ class RoomWorkerStore(SQLBaseStore): next_token = None for stream_ordering, content_json in txn: next_token = stream_ordering - event_json = json.loads(content_json) + event_json = db_to_json(content_json) content = event_json["content"] content_url = content.get("url") thumbnail_url = content.get("info", {}).get("thumbnail_url") @@ -910,8 +915,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore): if not row["json"]: retention_policy = {} else: - ev = json.loads(row["json"]) - retention_policy = json.dumps(ev["content"]) + ev = db_to_json(row["json"]) + retention_policy = ev["content"] self.db.simple_insert_txn( txn=txn, @@ -966,7 +971,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): updates = [] for room_id, event_json in txn: - event_dict = json.loads(event_json) + event_dict = db_to_json(event_json) room_version_id = event_dict.get("content", {}).get( "room_version", RoomVersions.V1.identifier ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
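The IndexError guard above makes the room-lookup method return None for an unknown room rather than raising. A hypothetical caller, to show the contract the change creates (the method name is assumed from context, since the hunk does not show the enclosing signature):

async def describe_room(store, room_id):
    room = await store.get_room_with_stats(room_id)  # method name assumed
    if room is None:
        raise ValueError("no such room: %s" % (room_id,))
    return {"public": room["public"], "federatable": room["federatable"]}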
index 44bab65eac..29765890ee 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py
@@ -17,8 +17,6 @@ import logging from typing import Iterable, List, Set -from canonicaljson import json - from twisted.internet import defer from synapse.api.constants import EventTypes, Membership @@ -27,6 +25,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import ( LoggingTransaction, SQLBaseStore, + db_to_json, make_in_list_sql_clause, ) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore @@ -938,7 +937,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): event_id = row["event_id"] room_id = row["room_id"] try: - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) content = event_json["content"] except Exception: continue diff --git a/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql b/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql new file mode 100644
index 0000000000..eb57203e46 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+The version of synapse 1.16.0 on pypi incorrectly contained a migration which
+added a table called 'local_rejections_stream'. This table is not used, and
+we drop it here for anyone who was affected.
+*/
+
+DROP TABLE IF EXISTS local_rejections_stream;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
new file mode 100644
index 0000000000..1cc2633aad --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We need to store the stream positions by instance in a sharded config world.
+--
+-- We default to master as we want the column to be NOT NULL and we correctly
+-- reset the instance name to match the config each time we start up.
+ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master';
+
+CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name);
diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
new file mode 100644
index 0000000000..2011f6bceb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
@@ -0,0 +1,34 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adds a postgres SEQUENCE for generating guest user IDs.
+"""
+
+from synapse.storage.data_stores.main.registration import (
+    find_max_generated_user_id_localpart,
+)
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if not isinstance(database_engine, PostgresEngine):
+        return
+
+    next_id = find_max_generated_user_id_localpart(cur) + 1
+    cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,))
+
+
+def run_upgrade(*args, **kwargs):
+    pass
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 9d6540d142..a79533dfad 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py
@@ -17,12 +17,10 @@ import logging import re from collections import namedtuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -157,7 +155,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): stream_ordering = row["stream_ordering"] origin_server_ts = row["origin_server_ts"] try: - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) content = event_json["content"] except Exception: continue diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 347cc50778..bb38a04ede 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py
@@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): last_room_id = progress.get("last_room_id", "") def _background_remove_left_rooms_txn(txn): + # get a batch of room ids to consider sql = """ SELECT DISTINCT room_id FROM current_state_events WHERE room_id > ? ORDER BY room_id LIMIT ? @@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): if not room_ids: return True, set() + ########################################################################### + # + # exclude rooms where we have active members + sql = """ SELECT room_id - FROM current_state_events + FROM local_current_membership WHERE room_id > ? AND room_id <= ? - AND type = 'm.room.member' AND membership = 'join' - AND state_key LIKE ? GROUP BY room_id """ - txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) - + txn.execute(sql, (last_room_id, room_ids[-1])) joined_room_ids = {row[0] for row in txn} + to_delete = set(room_ids) - joined_room_ids + + ########################################################################### + # + # exclude rooms which we are in the process of constructing; these otherwise + # qualify as "rooms with no local users", and would have their + # forward extremities cleaned up. + + # the following query will return a list of rooms which have forward + # extremities that are *not* also the create event in the room - ie + # those that are not being created currently. + + sql = """ + SELECT DISTINCT efe.room_id + FROM event_forward_extremities efe + LEFT JOIN current_state_events cse ON + cse.event_id = efe.event_id + AND cse.type = 'm.room.create' + AND cse.state_key = '' + WHERE + cse.event_id IS NULL + AND efe.room_id > ? AND efe.room_id <= ? + """ + + txn.execute(sql, (last_room_id, room_ids[-1])) + + # build a set of those rooms within `to_delete` that do not appear in + # the above, leaving us with the rooms in `to_delete` that *are* being + # created. + creating_rooms = to_delete.difference(row[0] for row in txn) + logger.info("skipping rooms which are being created: %s", creating_rooms) + + # now remove the rooms being created from the list of those to delete. + # + # (we could have just taken the intersection of `to_delete` with the result + # of the sql query, but it's useful to be able to log `creating_rooms`; and + # having done so, it's quicker to remove the (few) creating rooms from + # `to_delete` than it is to form the intersection with the (larger) list of + # not-creating-rooms) + + to_delete -= creating_rooms - left_rooms = set(room_ids) - joined_room_ids + ########################################################################### + # + # now clear the state for the rooms - logger.info("Deleting current state left rooms: %r", left_rooms) + logger.info("Deleting current state left rooms: %r", to_delete) # First we get all users that we still think were joined to the # room. 
This is so that we can mark those device lists as @@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="current_state_events", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, retcols=("state_key",), ) @@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="current_state_events", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={}, ) @@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="event_forward_extremities", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={}, ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
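Stripped of the SQL, the new `_background_remove_left_rooms_txn` logic is three set operations over the batch. A toy model with invented room IDs, matching the three queries above:

# Invented data standing in for the three queries above:
room_ids = ["!a", "!b", "!c", "!d"]   # the batch under consideration
joined_room_ids = {"!a"}              # rooms with a local 'join' membership
past_creation = {"!b", "!c"}          # rooms with a non-create forward extremity

to_delete = set(room_ids) - joined_room_ids   # {"!b", "!c", "!d"}
creating_rooms = to_delete - past_creation    # {"!d"} is still being created
to_delete -= creating_rooms                   # only {"!b", "!c"} get purged

assert to_delete == {"!b", "!c"}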
index 379d758b5d..5e32c7aa1e 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py
@@ -45,7 +45,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import Database, make_in_list_sql_clause from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -253,6 +253,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(StreamWorkerStore, self).__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._send_federation = hs.should_send_federation() + self._federation_shard_config = hs.config.federation.federation_shard_config + + # If we're a process that sends federation we may need to reset the + # `federation_stream_position` table to match the current sharding + # config. We don't do this now as otherwise two processes could conflict + # during startup which would cause one to die. + self._need_to_reset_federation_stream_positions = self._send_federation + events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, @@ -793,22 +803,95 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events - def get_federation_out_pos(self, typ): - return self.db.simple_select_one_onecol( + async def get_federation_out_pos(self, typ: str) -> int: + if self._need_to_reset_federation_stream_positions: + await self.db.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, desc="get_federation_out_pos", ) - def update_federation_out_pos(self, typ, stream_id): - return self.db.simple_update_one( + async def update_federation_out_pos(self, typ, stream_id): + if self._need_to_reset_federation_stream_positions: + await self.db.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db.simple_update_one( table="federation_stream_position", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) + def _reset_federation_positions_txn(self, txn): + """Fiddles with the `federation_stream_position` table to make it match + the configured federation sender instances during start up. + """ + + # The federation sender instances may have changed, so we need to + # massage the `federation_stream_position` table to have a row per type + # per instance sending federation. If there is a mismatch we update the + # table with the correct rows using the *minimum* stream ID seen. This + # may result in resending of events/EDUs to remote servers, but that is + # preferable to dropping them. + + if not self._send_federation: + return + + # Pull out the configured instances. If we don't have a shard config then + # we assume that we're the only instance sending. 
+ configured_instances = self._federation_shard_config.instances + if not configured_instances: + configured_instances = [self._instance_name] + elif self._instance_name not in configured_instances: + return + + instances_in_table = self.db.simple_select_onecol_txn( + txn, + table="federation_stream_position", + keyvalues={}, + retcol="instance_name", + ) + + if set(instances_in_table) == set(configured_instances): + # Nothing to do + return + + sql = """ + SELECT type, MIN(stream_id) FROM federation_stream_position + GROUP BY type + """ + txn.execute(sql) + min_positions = dict(txn) # Map from type -> min position + + # Ensure we do actually have some values here + assert set(min_positions) == {"federation", "events"} + + sql = """ + DELETE FROM federation_stream_position + WHERE NOT (%s) + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "instance_name", configured_instances + ) + txn.execute(sql % (clause,), args) + + for typ, stream_id in min_positions.items(): + self.db.simple_upsert_txn( + txn, + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + values={"stream_id": stream_id}, + ) + def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
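In miniature, the reset amounts to: when the rows in `federation_stream_position` no longer line up with the configured sender instances, rebuild them from the minimum stream ID per type, trading possible re-sends for never skipping events. A toy model of that computation, with all values invented:

# (type, instance_name) -> stream_id; values invented for illustration.
table = {
    ("federation", "master"): 100,
    ("events", "master"): 250,
    ("federation", "old_sender"): 90,
    ("events", "old_sender"): 240,
}
configured_instances = ["sender1"]  # the sharding config has changed

# The minimum position per type is the safe point to resume from.
min_positions = {}
for (typ, _instance), stream_id in table.items():
    min_positions[typ] = min(min_positions.get(typ, stream_id), stream_id)

# One row per type for the surviving instance(s), at the minimum position.
table = {
    (typ, inst): pos
    for typ, pos in min_positions.items()
    for inst in configured_instances
}
assert table == {("federation", "sender1"): 90, ("events", "sender1"): 240}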
index 290317fd94..bd7227773a 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py
@@ -21,6 +21,7 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import db_to_json from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.util.caches.descriptors import cached @@ -49,7 +50,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tags_by_room = {} for row in rows: room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = json.loads(row["content"]) + room_tags[row["tag"]] = db_to_json(row["content"]) return tags_by_room return deferred @@ -180,7 +181,7 @@ class TagsWorkerStore(AccountDataWorkerStore): retcols=("tag", "content"), desc="get_tags_for_room", ).addCallback( - lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows} + lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} ) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index 4c044b1a15..5f1b919748 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json from typing import Any, Dict, Optional, Union import attr +from canonicaljson import json from synapse.api.errors import StoreError -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.types import JsonDict from synapse.util import stringutils as stringutils @@ -118,7 +118,7 @@ class UIAuthWorkerStore(SQLBaseStore): desc="get_ui_auth_session", ) - result["clientdict"] = json.loads(result["clientdict"]) + result["clientdict"] = db_to_json(result["clientdict"]) return UIAuthSessionData(session_id, **result) @@ -168,7 +168,7 @@ class UIAuthWorkerStore(SQLBaseStore): retcols=("stage_type", "result"), desc="get_completed_ui_auth_stages", ): - results[row["stage_type"]] = json.loads(row["result"]) + results[row["stage_type"]] = db_to_json(row["result"]) return results @@ -224,7 +224,7 @@ class UIAuthWorkerStore(SQLBaseStore): ) # Update it and add it back to the database. - serverdict = json.loads(result["serverdict"]) + serverdict = db_to_json(result["serverdict"]) serverdict[key] = value self.db.simple_update_one_txn( @@ -254,7 +254,7 @@ class UIAuthWorkerStore(SQLBaseStore): desc="get_ui_auth_session_data", ) - serverdict = json.loads(result["serverdict"]) + serverdict = db_to_json(result["serverdict"]) return serverdict.get(key, default) diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index ec6b8a4ffd..d3038ff06d 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -70,11 +70,11 @@ class UserErasureWorkerStore(SQLBaseStore): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id): + def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: - user_id (str): full user_id to be erased + user_id: full user_id to be erased """ def f(txn): @@ -89,3 +89,25 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) return self.db.runInteraction("mark_user_erased", f) + + def mark_user_not_erased(self, user_id: str) -> None: + """Indicate that user_id is no longer erased. + + Args: + user_id: full user_id to be un-erased + """ + + def f(txn): + # first check if they are already in the list + txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) + if not txn.fetchone(): + return + + # They are there, delete them. + self.simple_delete_one_txn( + txn, "erased_users", keyvalues={"user_id": user_id} + ) + + self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) + + return self.db.runInteraction("mark_user_not_erased", f) diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
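`mark_user_not_erased` mirrors `mark_user_erased`: delete the `erased_users` row if present, then invalidate the `is_user_erased` cache so readers see the change. A hypothetical round-trip through the two methods (test-style, with an assumed store handle):

async def toggle_erasure(store, user_id="@alice:example.com"):
    # Both store methods resolve to Deferreds; awaiting them works in an
    # ensureDeferred/async context.
    await store.mark_user_erased(user_id)
    assert await store.is_user_erased(user_id)

    await store.mark_user_not_erased(user_id)
    assert not await store.is_user_erased(user_id)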
index 5db9f20135..128c09a2cf 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.database import Database from synapse.storage.state import StateFilter +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -92,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "*stateGroupMembersCache*", 500000, ) + def get_max_state_group_txn(txn: Cursor): + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + return txn.fetchone()[0] + + self._state_group_seq_gen = build_sequence_generator( + self.database_engine, get_max_state_group_txn, "state_group_id_seq" + ) + @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between @@ -386,7 +396,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # AFAIK, this can never happen raise Exception("current_state_ids cannot be None") - state_group = self.database_engine.get_next_state_group_id(txn) + state_group = self._state_group_seq_gen.get_next_id_txn(txn) self.db.simple_insert_txn( txn, diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ab0bbe4bd3..908cbc79e3 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py
@@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): def lock_table(self, txn, table: str) -> None: ... - @abc.abstractmethod - def get_next_state_group_id(self, txn) -> int: - """Returns an int that can be used as a new state_group ID - """ - ... - @property @abc.abstractmethod def server_version(self) -> str: diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a31588080d..ff39281f85 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py
@@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine): def lock_table(self, txn, table): txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - txn.execute("SELECT nextval('state_group_id_seq')") - return txn.fetchone()[0] - @property def server_version(self): """Returns a string giving the server version. For example: '8.1.5' diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 215a949442..8a0f8c89d1 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py
@@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): def lock_table(self, txn, table): return - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - # We do application locking here since if we're using sqlite then - # we are a single process synapse. - with self._current_state_group_id_lock: - if self._current_state_group_id is None: - txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - self._current_state_group_id = txn.fetchone()[0] - - self._current_state_group_id += 1 - return self._current_state_group_id - @property def server_version(self): """Gets a string giving the server version. For example: '3.22.0' diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0bed2..787cebfbec 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple from typing_extensions import Deque from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.util.sequence import PostgresSequenceGenerator class IdGenerator(object): @@ -247,7 +248,6 @@ class MultiWriterIdGenerator: ): self._db = db self._instance_name = instance_name - self._sequence_name = sequence_name # We lock as some functions may be called from DB threads. self._lock = threading.Lock() @@ -260,6 +260,8 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + self._sequence_gen = PostgresSequenceGenerator(sequence_name) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ) -> Dict[str, int]: @@ -283,9 +285,7 @@ class MultiWriterIdGenerator: return current_positions def _load_next_id_txn(self, txn): - txn.execute("SELECT nextval(?)", (self._sequence_name,)) - (next_id,) = txn.fetchone() - return next_id + return self._sequence_gen.get_next_id_txn(txn) async def get_next(self): """ diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py new file mode 100644
index 0000000000..63dfea4220 --- /dev/null +++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+    """A class which generates a unique sequence of integers"""
+
+    @abc.abstractmethod
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        """Gets the next ID in the sequence"""
+        ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+    def __init__(self, sequence_name: str):
+        self._sequence_name = sequence_name
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        return txn.fetchone()[0]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses local locking
+
+    This only works reliably if there are no other worker processes generating IDs at
+    the same time.
+    """
+
+    def __init__(self, get_first_callback: GetFirstCallbackType):
+        """
+        Args:
+            get_first_callback: a callback which is called on the first call to
+                get_next_id_txn; should return the current maximum id
+        """
+        # The callback. This is cleared after it is called, so that it can be GCed.
+        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]
+
+        # The current max value, or None if we haven't looked in the DB yet.
+        self._current_max_id = None  # type: Optional[int]
+        self._lock = threading.Lock()
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            self._current_max_id += 1
+            return self._current_max_id
+
+
+def build_sequence_generator(
+    database_engine: BaseDatabaseEngine,
+    get_first_callback: GetFirstCallbackType,
+    sequence_name: str,
+) -> SequenceGenerator:
+    """Get the best impl of SequenceGenerator available
+
+    This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+    sqlite.
+
+    Args:
+        database_engine: the database engine we are connected to
+        get_first_callback: a callback which returns the current maximum ID; used
+            only if we're on sqlite.
+        sequence_name: the name of a postgres sequence to use.
+    """
+    if isinstance(database_engine, PostgresEngine):
+        return PostgresSequenceGenerator(sequence_name)
+    else:
+        return LocalSequenceGenerator(get_first_callback)
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
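For reference, both call sites added in this commit (guest user IDs in registration.py and state group IDs in state/store.py) use the same pattern: the engine, a seed callback that is only consulted on sqlite, and a sequence name that is only consulted on postgres. A condensed sketch:

def get_max_state_group_txn(txn):
    txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
    return txn.fetchone()[0]

# 'database_engine' is assumed to be in scope, as at the real call sites.
seq = build_sequence_generator(
    database_engine,
    get_max_state_group_txn,   # seed: called at most once, and only on sqlite
    "state_group_id_seq",      # postgres sequence name
)

# Later, inside any database transaction:
#     next_id = seq.get_next_id_txn(txn)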
index cd56cd91ed..ca7c16ff65 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py
@@ -68,13 +68,13 @@ class PaginationConfig(object): elif from_tok: from_tok = StreamToken.from_string(from_tok) except Exception: - raise SynapseError(400, "'from' paramater is invalid") + raise SynapseError(400, "'from' parameter is invalid") try: if to_tok: to_tok = StreamToken.from_string(to_tok) except Exception: - raise SynapseError(400, "'to' paramater is invalid") + raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index fcd2aaa9c9..5d3eddcfdc 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py
@@ -68,7 +68,7 @@ class EventSources(object): The returned token does not have the current values for fields other than `room`, since they are not used during pagination. - Retuns: + Returns: Deferred[StreamToken] """ token = StreamToken( diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 60f0de70f7..c63256d3bd 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py
@@ -55,7 +55,7 @@ class Clock(object): return self._reactor.seconds() def time_msec(self): - """Returns the current system time in miliseconds since epoch.""" + """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) def looping_call(self, f, msec, *args, **kwargs): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 65abf0846e..f562770922 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -352,7 +352,7 @@ class ReadWriteLock(object): # resolved when they release the lock). # # Read: We know its safe to acquire a read lock when the latest writer has - # been resolved. The new reader is appeneded to the list of latest readers. + # been resolved. The new reader is appended to the list of latest readers. # # Write: We know its safe to acquire the write lock when both the latest # writers and readers have been resolved. The new writer replaces the latest diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 64f35fc288..9b09c08b89 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -516,7 +516,7 @@ class CacheListDescriptor(_CacheDescriptorBase): """ Args: orig (function) - cached_method_name (str): The name of the chached method. + cached_method_name (str): The name of the cached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int): number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 45af8d3eeb..22a857a306 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py
@@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect import logging from twisted.internet import defer +from twisted.internet.defer import Deferred, fail, succeed +from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -39,7 +41,7 @@ class Distributor(object): Signals are named simply by strings. TODO(paul): It would be nice to give signals stronger object identities, - so we can attach metadata, docstrings, detect typoes, etc... But this + so we can attach metadata, docstrings, detect typos, etc... But this model will do for today. """ @@ -79,6 +81,28 @@ class Distributor(object): run_as_background_process(name, self.signals[name].fire, *args, **kwargs) +def maybeAwaitableDeferred(f, *args, **kw): + """ + Invoke a function that may or may not return a Deferred or an Awaitable. + + This is a modified version of twisted.internet.defer.maybeDeferred. + """ + try: + result = f(*args, **kw) + except Exception: + return fail(failure.Failure(captureVars=Deferred.debug)) + + if isinstance(result, Deferred): + return result + # Handle the additional case of an awaitable being returned. + elif inspect.isawaitable(result): + return defer.ensureDeferred(result) + elif isinstance(result, failure.Failure): + return fail(result) + else: + return succeed(result) + + class Signal(object): """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -122,7 +146,7 @@ class Signal(object): ), ) - return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) + return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 2605f3c65b..54c046b6e1 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py
@@ -192,7 +192,7 @@ def _check_yield_points(f: Callable, changes: List[str]): result = yield d except Exception: # this will fish an earlier Failure out of the stack where possible, and - # thus is preferable to passing in an exeception to the Failure + # thus is preferable to passing in an exception to the Failure # constructor, since it results in less stack-mangling. result = Failure() diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index af69587196..8794317caa 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py
@@ -22,7 +22,7 @@ from synapse.api.errors import CodeMessageException logger = logging.getLogger(__name__) -# the intial backoff, after the first transaction fails +# the initial backoff, after the first transaction fails MIN_RETRY_INTERVAL = 10 * 60 * 1000 # how much we multiply the backoff by after each subsequent fail @@ -174,7 +174,7 @@ class RetryDestinationLimiter(object): # has been decommissioned. # If we get a 401, then we should probably back off since they # won't accept our requests for at least a while. - # 429 is us being aggresively rate limited, so lets rate limit + # 429 is us being aggressively rate limited, so lets rate limit # ourselves. if exc_val.code == 404 and self.backoff_on_404: valid_err_code = False diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 08c86e92b8..2e2b40a426 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py
@@ -17,7 +17,7 @@ import itertools import random import re import string -from collections import Iterable +from collections.abc import Iterable from synapse.api.errors import Codes, SynapseError diff --git a/synapse/visibility.py b/synapse/visibility.py
index 3dfd4af26c..0f042c5696 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py
@@ -319,7 +319,7 @@ def filter_events_for_server( return True # Lets check to see if all the events have a history visibility - # of "shared" or "world_readable". If thats the case then we don't + # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), @@ -335,7 +335,7 @@ def filter_events_for_server( visibility_ids.add(hist) # If we failed to find any history visibility events then the default - # is "shared" visiblity. + # is "shared" visibility. if not visibility_ids: all_open = True else: diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 70c8e72303..f9ce609923 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") self.failureResultOf(d, SynapseError) - # should suceed on a signed object + # should succeed on a signed object d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") # self.assertFalse(d.called) self.get_success(d) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 640f5f3bce..3a80626224 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py
@@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase): serialize/deserialize. """ - event, context = create_event( - self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + event, context = self.get_success( + create_event( + self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + ) ) self._check_serialize_deserialize(event, context) @@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.test", - sender=self.user_id, - state_key="", + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + state_key="", + ) ) self._check_serialize_deserialize(event, context) @@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. """ - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.room.member", - sender=self.user_id, - state_key=self.user_id, - content={"membership": "leave"}, + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.user_id, + state_key=self.user_id, + content={"membership": "leave"}, + ) ) self._check_serialize_deserialize(event, context) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 62b47f6574..6aa322bf3a 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py
@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - res = self.handler.get_device(user1, "abc") - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError ) # we'd like to check the access token was invalidated, but that's a @@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): def test_update_unknown_device(self): update = {"display_name": "new_display"} - res = self.handler.update_device("user_id", "unknown_device_id", update) - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.update_device("user_id", "unknown_device_id", update), + synapse.api.errors.NotFoundError, ) def _record_users(self): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 1acf287ca4..210ddcbb88 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py
@@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): """If the user has no devices, we expect an empty list. """ local_user = "@boris:" + self.hs.hostname - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) @@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + ) ) self.fail("No error when changing string key") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + ) ) self.fail("No error when replacing dict key with string") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + ) ) self.fail("No error when replacing string key with dict") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, - device_id, - { - "one_time_keys": { - "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} - } - }, + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + { + "one_time_keys": { + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } + }, + ) ) self.fail("No error when replacing dict key") except errors.SynapseError: @@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) - res2 = yield self.handler.claim_one_time_keys( - {"one_time_keys": 
{local_user: {device_id: "alg1"}}}, timeout=None + res2 = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) ) self.assertEqual( res2, @@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) keys2 = { "master_key": { @@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys2) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys2) + ) - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) @@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", ) - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) # upload two device keys, which will be signed later by the self-signing key device_key_1 = { @@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "signatures": {local_user: {"ed25519:def": "base64+signature"}}, } - yield self.handler.upload_keys_for_user( - local_user, "abc", {"device_keys": device_key_1} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) ) - yield self.handler.upload_keys_for_user( - local_user, "def", {"device_keys": device_key_2} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) ) # sign the first device key and upload it del device_key_1["signatures"] sign.sign_json(device_key_1, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) ) # sign the second device key and upload both device keys. 
The server @@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # signature for it del device_key_2["signatures"] sign.sign_json(device_key_2, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) ) device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] @@ -292,20 +328,26 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) res = None try: - yield self.hs.get_device_handler().check_device_registered( - user_id=local_user, - device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", - initial_device_display_name="new display name", + yield defer.ensureDeferred( + self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -331,8 +373,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) - yield self.handler.upload_keys_for_user( - local_user, device_id, {"device_keys": device_key} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) ) # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 @@ -372,7 +416,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_signing_key": usersigning_key, "self_signing_key": selfsigning_key, } - yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + ) # set up another user with a master key. 
This user will be signed by # the first user @@ -384,76 +430,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "usage": ["master"], "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield self.handler.upload_signing_keys_for_user( - other_user, {"master_key": other_master_key} + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) ) # test various signature failures (see below) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: { - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - device_id: { - "user_id": local_user, - "device_id": device_id, - "algorithms": [ - "m.olm.curve25519-aes-sha2", - RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, - ], - "keys": { - "curve25519:xyz": "curve25519+key", - # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA - "ed25519:xyz": device_pubkey, - }, - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey, + }, + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because device is unknown - # should fail with NOT_FOUND - "unknown": { - "user_id": local_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - master_pubkey: { - "user_id": local_user, - "usage": ["master"], - "keys": {"ed25519:" + master_pubkey: master_pubkey}, - "signatures": { - local_user: {"ed25519:" + device_pubkey: "something"} + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + "signatures": { + local_user: {"ed25519:" + device_pubkey: "something"} + }, }, }, - }, - other_user: { - # fails because the device is not the user's master-signing key - # should fail with NOT_FOUND - "unknown": { - "user_id": other_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, - }, - other_master_pubkey: { - # fails because the key doesn't match what the server has - # should fail with UNKNOWN - "user_id": other_user, - "usage": ["master"], - "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, - "something": "random", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + 
other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + }, + "something": "random", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, }, }, - }, + ) ) user_failures = ret["failures"][local_user] @@ -478,19 +538,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase): sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: {device_id: device_key, master_pubkey: master_key}, - other_user: {other_master_pubkey: other_master_key}, - }, + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, + ) ) self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there - ret = yield self.handler.query_devices( - {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ret = yield defer.ensureDeferred( + self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ) ) self.assertEqual( diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 822ea42dde..3362050ce0 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
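Both this file and the e2e-keys hunks above follow one mechanical recipe: the handler methods under test are now native `async def` coroutines, and an `@defer.inlineCallbacks` generator cannot `yield` a coroutine directly, so every call site gains a `defer.ensureDeferred(...)` wrapper. A minimal, self-contained sketch of the pattern (the handler here is a hypothetical stand-in, not Synapse code):

```python
from twisted.internet import defer
from twisted.trial import unittest


class FakeHandler:
    """Stand-in for a Synapse handler whose methods became coroutines."""

    async def create_version(self, user_id, body):
        return "1"


class EnsureDeferredExample(unittest.TestCase):
    @defer.inlineCallbacks
    def test_create_version(self):
        handler = FakeHandler()
        # ensureDeferred turns the coroutine into a Deferred that
        # inlineCallbacks can wait on; yielding the bare coroutine
        # would fail inside the generator.
        version = yield defer.ensureDeferred(
            handler.create_version("@user:test", {"algorithm": "m.megolm_backup.v1"})
        )
        self.assertEqual(version, "1")
```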
@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user) + yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) version_etag = res["etag"] self.assertIsInstance(version_etag, str) del res["etag"] @@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as a specific version - res = yield self.handler.get_version_info(self.local_user, "1") + res = yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) self.assertEqual(res["etag"], version_etag) del res["etag"] self.assertDictEqual( @@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # upload a new one... - res = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_version(self): """Check that we can update versions. 
""" - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version, - }, + res = yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version( - self.local_user, - "1", - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_omitted_version(self): """Check that the update succeeds if the version is missing from the body """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) ) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, @@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_bad_version(self): """Check that we get a 400 if the version in the body doesn't match """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect", - }, + yield defer.ensureDeferred( + 
self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.delete_version(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user) + yield defer.ensureDeferred(self.handler.delete_version(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can delete it - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) # check that it's gone res = None try: - yield self.handler.get_version_info(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_room_keys(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, @@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys( - self.local_user, "no_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "no_version", room_keys) ) except errors.SynapseError as e: res = e.code @@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield 
defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.upload_room_keys( - self.local_user, "bogus_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) ) except errors.SynapseError as e: res = e.code @@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - version = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(version, "2") res = None try: - yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "1", room_keys) + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 403) @@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given room - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, room_keys) @@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": 
"first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) # get the etag to compare to future versions - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) backup_etag = res["etag"] self.assertEqual(res["count"], 1) @@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that increasing the message_index doesn't replace the existing session new_room_key["first_message_index"] = 2 new_room_key["session_data"] = "new" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should NOT be equal now, since the key changed - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertNotEqual(res["etag"], backup_etag) backup_etag = res["etag"] @@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # with a lower forwarding count new_room_key["forwarded_count"] = 2 new_room_key["session_data"] = "other" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # TODO: check edge cases as well as the common variations here @@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = 
yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") # check for bulk-delete - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys(self.local_user, version) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys(self.local_user, version) + ) + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 29dd7d9c6e..4f1347cd25 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -72,7 +72,9 @@ class ProfileTestCase(unittest.TestCase): def test_get_my_name(self): yield self.store.set_profile_displayname(self.frank.localpart, "Frank") - displayname = yield self.handler.get_displayname(self.frank) + displayname = yield defer.ensureDeferred( + self.handler.get_displayname(self.frank) + ) self.assertEquals("Frank", displayname) @@ -140,7 +142,9 @@ class ProfileTestCase(unittest.TestCase): {"displayname": "Alice"} ) - displayname = yield self.handler.get_displayname(self.alice) + displayname = yield defer.ensureDeferred( + self.handler.get_displayname(self.alice) + ) self.assertEquals(displayname, "Alice") self.mock_federation.make_query.assert_called_with( @@ -155,8 +159,10 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile("caroline") yield self.store.set_profile_displayname("caroline", "Caroline") - response = yield self.query_handlers["profile"]( - {"user_id": "@caroline:test", "field": "displayname"} + response = yield defer.ensureDeferred( + self.query_handlers["profile"]( + {"user_id": "@caroline:test", "field": "displayname"} + ) ) self.assertEquals({"displayname": "Caroline"}, response) @@ -166,8 +172,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) - - avatar_url = yield self.handler.get_avatar_url(self.frank) + avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) self.assertEquals("http://my.server/me.png", avatar_url) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1e6a53bf7f..5878f74175 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
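The hunk below moves the membership stub from the state handler onto the mocked datastore, presumably because the typing handler now looks up room members via `get_users_in_room` on the store. The stub keeps the Deferred-returning shape that callers expect; in isolation (fixture values hypothetical):

```python
from mock import Mock

from twisted.internet import defer

datastore = Mock()
room_members = ["@alice:test", "@bob:test"]  # hypothetical fixture


def get_users_in_room(room_id):
    # Callers expect a Deferred firing with a set of user-ID strings.
    return defer.succeed({str(u) for u in room_members})


datastore.get_users_in_room = get_users_in_room
```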
@@ -138,10 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - def get_current_users_in_room(room_id): + def get_users_in_room(room_id): return defer.succeed({str(u) for u in self.room_members}) - hs.get_state_handler().get_current_users_in_room = get_current_users_in_room + self.datastore.get_users_in_room = get_users_in_room self.datastore.get_user_directory_stream_pos.return_value = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 954e059e76..69945a8f98 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
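The repeated change in this file swaps `lambda _: [...]` side effects for `generate_resolve_service(...)`. Since the agent now awaits `resolve_service`, a lambda handing back a plain list breaks; the new factory (added at the top of the file) builds an async function that closes over the canned result. Its effect on the mock, in isolation:

```python
from mock import Mock


def generate_resolve_service(result):
    # Returns a coroutine function suitable as a Mock side_effect; a plain
    # lambda would hand back a list where the caller expects an awaitable.
    async def resolve_service(_):
        return result

    return resolve_service


mock_resolver = Mock()
mock_resolver.resolve_service.side_effect = generate_resolve_service([])
# `await mock_resolver.resolve_service("matrix")` now evaluates to [].
```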
@@ -67,6 +67,14 @@ def get_connection_factory(): return test_server_connection_factory +# Once Async Mocks or lambdas are supported this can go away. +def generate_resolve_service(result): + async def resolve_service(_): + return result + + return resolve_service + + class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the certificate on the server doesn't match the hostname """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv1/foo/bar") @@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has no port, no SRV, and no well-known """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known delegates elsewhere """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) @@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when there is a single SRV record """ - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self.reactor.lookups["srvtarget"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + 
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self._handle_well_known_connection( client_factory, @@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" @@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"xn--trget-3qa.com", port=8443) # târget.com - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com + ) self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") @@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. """ - - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"target.com", port=8443), - Server(host=b"target.com", port=8444), - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + ) self.reactor.lookups["target.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc201643..fee2985d35 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
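Same conversion here: `SrvResolver.resolve_service` now returns a coroutine, so each call is wrapped in `defer.ensureDeferred` before being yielded or handed to trial's result helpers (as the new inline comments note, older Twisted versions do not perform that coercion inside `successResultOf`/`failureResultOf`). A self-contained illustration with a stand-in resolver:

```python
from twisted.internet import defer
from twisted.trial import unittest


class FakeResolver:
    """Stand-in for SrvResolver; the real lookup logic is irrelevant here."""

    async def resolve_service(self, service_name):
        return []


class WrapBeforeResultOf(unittest.TestCase):
    def test_resolve(self):
        d = defer.ensureDeferred(FakeResolver().resolve_service("_matrix._tcp.test"))
        # successResultOf wants a Deferred; wrapping first keeps this
        # working on Twisted versions that reject bare coroutines.
        self.assertEqual(self.successResultOf(d), [])
```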
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase): with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) - - self.assertNoResult(resolve_d) - - # should have reset to the sentinel context - self.assertIs(current_context(), SENTINEL_CONTEXT) - - result = yield resolve_d + result = yield defer.ensureDeferred(resolve_d) # should have restored our context self.assertIs(current_context(), ctx) @@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client=dns_client_mock, cache=cache, get_time=clock.time ) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertFalse(dns_client_mock.lookupService.called) @@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase): resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): - yield resolver.resolve_service(service_name) + yield defer.ensureDeferred(resolver.resolve_service(service_name)) @defer.inlineCallbacks def test_name_error(self): @@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertEquals(len(servers), 0) self.assertEquals(len(cache), 0) @@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in failureResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) # returning a single "." should make the lookup fail with a ConenctError lookup_deferred.callback( @@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in successResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) lookup_deferred.callback( ( diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe44..06575ba0a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import attr @@ -26,8 +26,9 @@ from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, GenericWorkerServer, ) +from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest -from synapse.replication.http import streams +from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -35,7 +36,7 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport +from tests.server import FakeTransport, render logger = logging.getLogger(__name__) @@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") +class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests running multiple workers. + + Automatically handle HTTP replication requests from workers to master, + unlike `BaseStreamTestCase`. + """ + + servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + + def setUp(self): + super().setUp() + + # build a replication server + self.server_factory = ReplicationStreamProtocolFactory(self.hs) + self.streamer = self.hs.get_replication_streamer() + + store = self.hs.get_datastore() + self.database = store.db + + self.reactor.lookups["testserv"] = "1.2.3.4" + + self._worker_hs_to_resource = {} + + # When we see a connection attempt to the master replication listener we + # automatically set up the connection. This is so that tests don't + # manually have to go and explicitly set it up each time (plus sometimes + # it is impossible to write the handling explicitly in the tests). + self.reactor.add_tcp_client_callback( + "1.2.3.4", 8765, self._handle_http_replication_attempt + ) + + def create_test_json_resource(self): + """Overrides `HomeserverTestCase.create_test_json_resource`. + """ + # We override this so that it automatically registers all the HTTP + # replication servlets, without having to explicitly do that in all + # subclassses. + + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(self.hs, resource) + + return resource + + def make_worker_hs( + self, worker_app: str, extra_config: dict = {}, **kwargs + ) -> HomeServer: + """Make a new worker HS instance, correctly connecting replcation + stream to the master HS. + + Args: + worker_app: Type of worker, e.g. `synapse.app.federation_sender`. + extra_config: Any extra config to use for this instances. + **kwargs: Options that get passed to `self.setup_test_homeserver`, + useful to e.g. pass some mocks for things like `http_client` + + Returns: + The new worker HomeServer instance. 
+ """ + + config = self._get_worker_hs_config() + config["worker_app"] = worker_app + config.update(extra_config) + + worker_hs = self.setup_test_homeserver( + homeserverToUse=GenericWorkerServer, + config=config, + reactor=self.reactor, + **kwargs + ) + + store = worker_hs.get_datastore() + store.db._db_pool = self.database._db_pool + + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) + + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) + + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) + + # Set up a resource for the worker + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(worker_hs, resource) + + self._worker_hs_to_resource[worker_hs] = resource + + return worker_hs + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): + render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump() + + def _handle_http_replication_attempt(self): + """Handles a connection attempt to the master replication HTTP + listener. + """ + + # We should have at least one outbound connection attempt, where the + # last is one to the HTTP repication IP/port. + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8765) + + # Set up client side protocol + client_protocol = client_factory.buildProtocol(None) + + request_factory = OneShotRequestFactory() + + # Set up the server side protocol + channel = _PushHTTPChannel(self.reactor) + channel.requestFactory = request_factory + channel.site = self.site + + # Connect client to server and vice versa. + client_to_server_transport = FakeTransport( + channel, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, channel + ) + channel.makeConnection(server_to_client_transport) + + # Note: at this point we've wired everything up, but we need to return + # before the data starts flowing over the connections as this is called + # inside `connecTCP` before the connection has been passed back to the + # code that requested the TCP connection. + + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel): # We need to manually stop the _PullToPushProducer. self._pull_to_push_producer.stop() + def checkPersistence(self, request, version): + """Check whether the connection can be re-used + """ + # We hijack this to always say no for ease of wiring stuff up in + # `handle_http_replication_attempt`. + request.responseHeaders.setRawHeaders(b"connection", [b"close"]) + return False + class _PullToPushProducer: """A push producer that wraps a pull producer. 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 097e1653b4..c9998e88e6 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
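This file shows the second flavour of the migration: these are `HomeserverTestCase`-style tests with no generator to `yield` from, so the now-async `inject_event`/`inject_member_event` helpers are wrapped in `self.get_success(...)`, which pumps the memory reactor until the awaitable finishes and returns its result. Inside such a test the idiom looks like this (body value illustrative):

```python
# Contrast with defer.ensureDeferred in the inlineCallbacks-style tests
# above: get_success drives the reactor itself, so no yield is needed.
event = self.get_success(
    inject_event(
        self.hs,
        room_id=self.room_id,
        sender=self.user_id,
        type="test_event",
        content={"body": "hello"},
    )
)
```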
@@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase): OTHER_USER = "@other_user:localhost" # have the user join - inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -157,14 +159,16 @@ class EventsStreamTestCase(BaseStreamTestCase): # roll back all the state by de-modding the user prev_events = fork_point pls["users"][OTHER_USER] = 0 - pl_event = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + pl_event = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) # one more bit of state that doesn't get rolled back @@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # have the users join for u in user_ids: - inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -306,14 +312,16 @@ class EventsStreamTestCase(BaseStreamTestCase): pl_events = [] for u in user_ids: pls["users"][u] = 0 - e = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + e = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) prev_events = [e.event_id] pl_events.append(e) @@ -434,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase): body = "event %i" % (self.event_count,) self.event_count += 1 - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_event", - content={"body": body}, - **kwargs + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) ) def _inject_state_event( @@ -459,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase): if body is None: body = "state event %s" % (state_key,) - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_state_event", - state_key=state_key, - content={"body": body}, + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) ) diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py new file mode 100644
index 0000000000..86c03fd89c
--- /dev/null
+++ b/tests/replication/test_client_reader_shard.py
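The new test drives user-interactive auth across client-reader workers: the first `register` POST is answered 401 with a UIA `session`, and a second POST completes the `m.login.dummy` stage. In the multi-worker variant the two requests deliberately hit different workers, so the test only passes if the in-flight auth session is visible to both. The exchange, reduced to its request bodies:

```python
# Stage 1: password registration -> 401 plus a UIA session token.
first_body = {"username": "user", "type": "m.login.password", "password": "bar"}

# Stage 2: complete the dummy auth stage, quoting the session from stage 1.
second_body = {
    "auth": {"session": "<session from the 401 body>", "type": "m.login.dummy"}
}

# Stage 2 succeeds with 200 and {"user_id": "@user:test"}.
```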
@@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.api.constants import LoginType +from synapse.http.site import SynapseRequest +from synapse.rest.client.v2_alpha import register + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker +from tests.server import FakeChannel + +logger = logging.getLogger(__name__) + + +class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): + """Base class for tests of the replication streams""" + + servlets = [register.register_servlets] + + def prepare(self, reactor, clock, hs): + self.recaptcha_checker = DummyRecaptchaChecker(hs) + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.client_reader" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def test_register_single_worker(self): + """Test that registration works when using a single client reader worker. + """ + worker_hs = self.make_worker_hs("synapse.app.client_reader") + + request_1, channel_1 = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) # type: SynapseRequest, FakeChannel + self.render_on_worker(worker_hs, request_1) + self.assertEqual(request_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + request_2, channel_2 = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) # type: SynapseRequest, FakeChannel + self.render_on_worker(worker_hs, request_2) + self.assertEqual(request_2.code, 200) + + # We're given a registered user. + self.assertEqual(channel_2.json_body["user_id"], "@user:test") + + def test_register_multi_worker(self): + """Test that registration works when using multiple client reader workers. + """ + worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") + worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") + + request_1, channel_1 = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) # type: SynapseRequest, FakeChannel + self.render_on_worker(worker_hs_1, request_1) + self.assertEqual(request_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + request_2, channel_2 = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) # type: SynapseRequest, FakeChannel + self.render_on_worker(worker_hs_2, request_2) + self.assertEqual(request_2.code, 200) + + # We're given a registered user. 
+ self.assertEqual(channel_2.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 5448d9f0dc..23be1167a3 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
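The only change here is making `make_homeserver` actually return the homeserver it builds; without the `return`, the harness presumably received `None`:

```python
def make_homeserver(self, reactor, clock):
    hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
    return hs  # the missing return this hunk adds
```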
@@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + return hs def test_federation_ack_sent(self): diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py new file mode 100644
index 0000000000..8d4dbf232e
--- /dev/null
+++ b/tests/replication/test_federation_sender_shard.py
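These tests pin down destination-based sharding: with `federation_sender_instances: ["sender1", "sender2"]`, each remote server name belongs to exactly one sender instance, so the loops create rooms on up to twenty distinct remote servers and assert that every PDU or EDU leaves through exactly one of the two mocked clients, stopping once both have fired. The guarantee the tests rely on is, in spirit, a stable hash of the destination; the toy below only illustrates the shape, as Synapse's real assignment lives in its sharding config code:

```python
import binascii

instances = ["sender1", "sender2"]


def instance_for(destination: str) -> str:
    # Stable across runs (unlike hash()), so a destination always maps to
    # the same instance -- the property the assertions depend on.
    return instances[binascii.crc32(destination.encode()) % len(instances)]


assert instance_for("other_server_3") == instance_for("other_server_3")
```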
@@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.events.builder import EventBuilderFactory +from synapse.rest.admin import register_servlets_for_client_rest_resource +from synapse.rest.client.v1 import login, room +from synapse.types import UserID + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + + +class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): + servlets = [ + login.register_servlets, + register_servlets_for_client_rest_resource, + room.register_servlets, + ] + + def default_config(self): + conf = super().default_config() + conf["send_federation"] = False + return conf + + def test_send_event_single_sender(self): + """Test that using a single federation sender worker correctly sends a + new event. + """ + mock_client = Mock(spec=["put_json"]) + mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + + self.make_worker_hs( + "synapse.app.federation_sender", + {"send_federation": True}, + http_client=mock_client, + ) + + user = self.register_user("user", "pass") + token = self.login("user", "pass") + + room = self.create_room_with_remote_server(user, token) + + mock_client.put_json.reset_mock() + + self.create_and_send_event(room, UserID.from_string(user)) + self.replicate() + + # Assert that the event was sent out over federation. + mock_client.put_json.assert_called() + self.assertEqual(mock_client.put_json.call_args[0][0], "other_server") + self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus")) + + def test_send_event_sharded(self): + """Test that using two federation sender workers correctly sends + new events. 
+ """ + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", + { + "send_federation": True, + "worker_name": "sender1", + "federation_sender_instances": ["sender1", "sender2"], + }, + http_client=mock_client1, + ) + + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", + { + "send_federation": True, + "worker_name": "sender2", + "federation_sender_instances": ["sender1", "sender2"], + }, + http_client=mock_client2, + ) + + user = self.register_user("user2", "pass") + token = self.login("user2", "pass") + + sent_on_1 = False + sent_on_2 = False + for i in range(20): + server_name = "other_server_%d" % (i,) + room = self.create_room_with_remote_server(user, token, server_name) + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] + + self.create_and_send_event(room, UserID.from_string(user)) + self.replicate() + + if mock_client1.put_json.called: + sent_on_1 = True + mock_client2.put_json.assert_not_called() + self.assertEqual(mock_client1.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus")) + elif mock_client2.put_json.called: + sent_on_2 = True + mock_client1.put_json.assert_not_called() + self.assertEqual(mock_client2.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus")) + else: + raise AssertionError( + "Expected send transaction from one or the other sender" + ) + + if sent_on_1 and sent_on_2: + break + + self.assertTrue(sent_on_1) + self.assertTrue(sent_on_2) + + def test_send_typing_sharded(self): + """Test that using two federation sender workers correctly sends + new typing EDUs. 
+ """ + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", + { + "send_federation": True, + "worker_name": "sender1", + "federation_sender_instances": ["sender1", "sender2"], + }, + http_client=mock_client1, + ) + + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", + { + "send_federation": True, + "worker_name": "sender2", + "federation_sender_instances": ["sender1", "sender2"], + }, + http_client=mock_client2, + ) + + user = self.register_user("user3", "pass") + token = self.login("user3", "pass") + + typing_handler = self.hs.get_typing_handler() + + sent_on_1 = False + sent_on_2 = False + for i in range(20): + server_name = "other_server_%d" % (i,) + room = self.create_room_with_remote_server(user, token, server_name) + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] + + self.get_success( + typing_handler.started_typing( + target_user=UserID.from_string(user), + auth_user=UserID.from_string(user), + room_id=room, + timeout=20000, + ) + ) + + self.replicate() + + if mock_client1.put_json.called: + sent_on_1 = True + mock_client2.put_json.assert_not_called() + self.assertEqual(mock_client1.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus")) + elif mock_client2.put_json.called: + sent_on_2 = True + mock_client1.put_json.assert_not_called() + self.assertEqual(mock_client2.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus")) + else: + raise AssertionError( + "Expected send transaction from one or the other sender" + ) + + if sent_on_1 and sent_on_2: + break + + self.assertTrue(sent_on_1) + self.assertTrue(sent_on_2) + + def create_room_with_remote_server(self, user, token, remote_server="other_server"): + room = self.helper.create_room_as(user, tok=token) + store = self.hs.get_datastore() + federation = self.hs.get_handlers().federation_handler + + prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) + room_version = self.get_success(store.get_room_version(room)) + + factory = EventBuilderFactory(self.hs) + factory.hostname = remote_server + + user_id = UserID("user", remote_server).to_string() + + event_dict = { + "type": EventTypes.Member, + "state_key": user_id, + "content": {"membership": Membership.JOIN}, + "sender": user_id, + "room_id": room, + } + + builder = factory.for_room_version(room_version, event_dict) + join_event = self.get_success(builder.build(prev_event_ids)) + + self.get_success(federation.on_send_join_request(remote_server, join_event)) + self.replicate() + + return room diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py new file mode 100644
index 0000000000..2bdc6edbb1
--- /dev/null
+++ b/tests/replication/test_pusher_shard.py
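The pusher tests apply the same sharding check, this time per pusher: each pusher is owned by one `pusher_instances` entry, and the tests pick usernames known to land on each instance. The assertions lean on `Mock.call_args`, which unpacks to `(args, kwargs)`, so `call_args[0][0]` is the push URL and `call_args[0][1]` the JSON body. On a bare mock:

```python
from mock import Mock

push = Mock()
push("https://push.example.com/push", {"notification": {"event_id": "$ev"}})

args, kwargs = push.call_args
assert args[0] == "https://push.example.com/push"
assert args[1]["notification"]["event_id"] == "$ev"
```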
@@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from mock import Mock + +from twisted.internet import defer + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + + +class PusherShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks pusher sharding works + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def default_config(self): + conf = super().default_config() + conf["start_pushers"] = False + return conf + + def _create_pusher_and_send_msg(self, localpart): + # Create a user that will get push notifications + user_id = self.register_user(localpart, "pass") + access_token = self.login(localpart, "pass") + + # Register a pusher + user_dict = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_dict["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "https://push.example.com/push"}, + ) + ) + + self.pump() + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room, user=self.other_user_id, tok=self.other_access_token + ) + + # The other user sends some messages + response = self.helper.send(room, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + return event_id + + def test_send_push_single_worker(self): + """Test that registration works when using a pusher worker. + """ + http_client_mock = Mock(spec_set=["post_json_get_json"]) + http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + {"start_pushers": True}, + proxied_http_client=http_client_mock, + ) + + event_id = self._create_pusher_and_send_msg("user") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + def test_send_push_multiple_workers(self): + """Test that registration works when using sharded pusher workers. 
+ """ + http_client_mock1 = Mock(spec_set=["post_json_get_json"]) + http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher1", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock1, + ) + + http_client_mock2 = Mock(spec_set=["post_json_get_json"]) + http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher2", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock2, + ) + + # We choose a user name that we know should go to pusher1. + event_id = self._create_pusher_and_send_msg("user2") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_called_once() + http_client_mock2.post_json_get_json.assert_not_called() + self.assertEqual( + http_client_mock1.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock1.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + http_client_mock1.post_json_get_json.reset_mock() + http_client_mock2.post_json_get_json.reset_mock() + + # Now we choose a user name that we know should go to pusher2. + event_id = self._create_pusher_and_send_msg("user4") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_not_called() + http_client_mock2.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock2.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock2.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ae6d05a043..946f06d151 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -151,6 +151,401 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         )
 
 
+class DeleteRoomTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        events.register_servlets,
+        room.register_servlets,
+        room.register_deprecated_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.event_creation_handler = hs.get_event_creation_handler()
+        hs.config.user_consent_version = "1"
+
+        consent_uri_builder = Mock()
+        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+        self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        # Mark the admin user as having consented
+        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, a 403 error is returned.
+        """
+
+        request, channel = self.make_request(
+            "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_room_does_not_exist(self):
+        """
+        Check that unknown rooms/servers return a 404 error.
+        """
+        url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
+
+        request, channel = self.make_request(
+            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_room_is_not_valid(self):
+        """
+        Check that invalid room names return a 400 error.
+        """
+        url = "/_synapse/admin/v1/rooms/invalidroom/delete"
+
+        request, channel = self.make_request(
+            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "invalidroom is not a legal room ID", channel.json_body["error"],
+        )
+
+    def test_new_room_user_does_not_exist(self):
+        """
+        Tests that the new room user ID must be from the local server, but does
+        not have to exist.
+        """
+        body = json.dumps({"new_room_user_id": "@unknown:test"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("kicked_users", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+    def test_new_room_user_is_not_local(self):
+        """
+        Check that only a local user can be used to create the new room that
+        members are moved to.
+        """
+        body = json.dumps({"new_room_user_id": "@not:exist.bla"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "User must be our own: @not:exist.bla", channel.json_body["error"],
+        )
+
+    def test_block_is_not_bool(self):
+        """
+        If the parameter `block` is not a boolean, an error is returned.
+        """
+        body = json.dumps({"block": "NotBool"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_purge_room_and_block(self):
+        """Test to purge a room and block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        body = json.dumps({"block": True})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url.encode("ascii"),
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(None, channel.json_body["new_room_id"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=True)
+        self._has_no_members(self.room_id)
+
+    def test_purge_room_and_not_block(self):
+        """Test to purge a room without blocking it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        body = json.dumps({"block": False})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url.encode("ascii"),
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(None, channel.json_body["new_room_id"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=False)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_consent(self):
+        """Test that we can shut down rooms with local users who have not
+        yet accepted the privacy policy. This used to fail when we tried to
+        force part the user from the old room.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Assert one user in room
+        users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+        self.assertEqual([self.other_user], users_in_room)
+
+        # Require consent to send events
+        self.event_creation_handler._block_events_without_consent_error = "Error"
+
+        # Assert that the user gets a consent error
+        self.helper.send(
+            self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+        )
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["new_room_id"], user_id=self.other_user
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_block_peek(self):
+        """Test that a world_readable room can no longer be peeked into after
+        it has been shut down.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Make the room world readable
+        url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+        request, channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            json.dumps({"history_visibility": "world_readable"}),
+            access_token=self.other_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["new_room_id"], user_id=self.other_user
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+        # Assert we can no longer peek into the room
+        self._assert_peek(self.room_id, expect_code=403)
+
+    def _is_blocked(self, room_id, expect=True):
+        """Assert that the room is blocked or not
+        """
+        d = self.store.is_room_blocked(room_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertIsNone(self.get_success(d))
+
+    def _has_no_members(self, room_id):
+        """Assert that there is no longer anyone in the room
+        """
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([], users_in_room)
+
+    def _is_member(self, room_id, user_id):
+        """Assert that the user is a member of the room
+        """
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertIn(user_id, users_in_room)
+
+    def _is_purged(self, room_id):
+        """Test that the following tables have been purged of all rows related to the room.
+        """
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "local_invites",
+            "room_account_data",
+            "room_tags",
+            # "state_groups",  # Current impl leaves orphaned state groups around.
+            "state_groups_state",
+        ):
+            count = self.get_success(
+                self.store.db.simple_select_one_onecol(
+                    table=table,
+                    keyvalues={"room_id": room_id},
+                    retcol="COUNT(*)",
+                    desc="test_purge_room",
+                )
+            )
+
+            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+    def _assert_peek(self, room_id, expect_code):
+        """Assert that the admin user can (or cannot) peek into the room.
+        """
+
+        url = "rooms/%s/initialSync" % (room_id,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+        url = "events?timeout=0&room_id=" + room_id
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+
 class PurgeRoomTestCase(unittest.HomeserverTestCase):
     """Test /purge_room admin API.
     """
@@ -741,6 +1136,52 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(room_id_1, channel.json_body["room_id"])
 
+    def test_room_members(self):
+        """Test that room members can be requested correctly"""
+        # Create two test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        # Have a second user join room 1
+        user_1 = self.register_user("foo", "pass")
+        user_tok_1 = self.login("foo", "pass")
+        self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+        # Have a third user join both rooms
+        user_2 = self.register_user("bar", "pass")
+        user_tok_2 = self.login("bar", "pass")
+        self.helper.join(room_id_1, user_2, tok=user_tok_2)
+        self.helper.join(room_id_2, user_2, tok=user_tok_2)
+
+        # Have a fourth user join room 2
+        user_3 = self.register_user("foobar", "pass")
+        user_tok_3 = self.login("foobar", "pass")
+        self.helper.join(room_id_2, user_3, tok=user_tok_3)
+
+        url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        self.assertCountEqual(
+            ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
+        )
+        self.assertEqual(channel.json_body["total"], 3)
+
+        url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        self.assertCountEqual(
+            ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
+        )
+        self.assertEqual(channel.json_body["total"], 3)
+
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cca5f548e6..f16eef15f7 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -857,6 +857,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
 
+    def test_reactivate_user(self):
+        """
+        Test reactivating another user.
+        """
+
+        # Deactivate the user.
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Attempt to reactivate the user (without a password).
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+        )
+        self.render(request)
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Reactivate the user.
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": False, "password": "foo"}).encode(
+                encoding="utf_8"
+            ),
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        request, channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+
     def test_set_user_as_admin(self):
         """
         Test setting the admin flag on a user.
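
The test above shows the API contract for reactivation: flipping `deactivated` back to false is rejected unless a new password is supplied, because the old password hash was wiped on deactivation. The three request bodies are small enough to show inline; only the password value is a placeholder:

    import json

    deactivate = json.dumps({"deactivated": True})
    reactivate_rejected = json.dumps({"deactivated": False})                # 400 in the test
    reactivate_ok = json.dumps({"deactivated": False, "password": "foo"})  # 200 in the test
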
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 95475bb651..e54ffea150 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -126,7 +126,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         events.append(self.get_success(store.get_event(valid_event_id)))
 
-        # Advance the time by anothe 2 days. After this, the first event should be
+        # Advance the time by another 2 days. After this, the first event should be
         # outdated but not the second one.
         self.reactor.advance(one_day_ms * 2 / 1000)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index fd97999956..db52725cfe 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -398,7 +398,7 @@ class CASTestCase(unittest.HomeserverTestCase):
             </cas:serviceResponse>
         """
             % cas_user_id
-        )
+        ).encode("utf-8")
 
         mocked_http_client = Mock(spec=["get_raw"])
         mocked_http_client.get_raw.side_effect = get_raw
@@ -514,16 +514,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
     ]
 
     jwt_secret = "secret"
+    jwt_algorithm = "HS256"
 
     def make_homeserver(self, reactor, clock):
         self.hs = self.setup_test_homeserver()
         self.hs.config.jwt_enabled = True
         self.hs.config.jwt_secret = self.jwt_secret
-        self.hs.config.jwt_algorithm = "HS256"
+        self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
     def jwt_encode(self, token, secret=jwt_secret):
-        return jwt.encode(token, secret, "HS256").decode("ascii")
+        return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
 
     def jwt_login(self, *args):
         params = json.dumps(
@@ -546,35 +547,126 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_invalid_signature(self):
         channel = self.jwt_login({"sub": "frog"}, "notsecret")
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: Signature verification failed",
+        )
 
     def test_login_jwt_expired(self):
         channel = self.jwt_login({"sub": "frog", "exp": 864000})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "JWT expired")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Signature has expired"
+        )
 
     def test_login_jwt_not_before(self):
         now = int(time.time())
         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: The token is not yet valid (nbf)",
+        )
 
     def test_login_no_sub(self):
         channel = self.jwt_login({"username": "root"})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
+    @override_config(
+        {
+            "jwt_config": {
+                "jwt_enabled": True,
+                "secret": jwt_secret,
+                "algorithm": jwt_algorithm,
+                "issuer": "test-issuer",
+            }
+        }
+    )
+    def test_login_iss(self):
+        """Test validating the issuer claim."""
+        # A valid issuer.
+        channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+        # An invalid issuer.
+        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid issuer"
+        )
+
+        # Not providing an issuer.
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            'JWT validation failed: Token is missing the "iss" claim',
+        )
+
+    def test_login_iss_no_config(self):
+        """Test providing an issuer claim without requiring it in the configuration."""
+        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+    @override_config(
+        {
+            "jwt_config": {
+                "jwt_enabled": True,
+                "secret": jwt_secret,
+                "algorithm": jwt_algorithm,
+                "audiences": ["test-audience"],
+            }
+        }
+    )
+    def test_login_aud(self):
+        """Test validating the audience claim."""
+        # A valid audience.
+        channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+        # An invalid audience.
+        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid audience"
+        )
+
+        # Not providing an audience.
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            'JWT validation failed: Token is missing the "aud" claim',
+        )
+
+    def test_login_aud_no_config(self):
+        """Test providing an audience without requiring it in the configuration."""
+        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid audience"
+        )
+
     def test_login_no_token(self):
         params = json.dumps({"type": "org.matrix.login.jwt"})
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
 
@@ -656,6 +748,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_invalid_signature(self):
         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: Signature verification failed",
+        )
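
The claims these tests exercise (exp, nbf, and the optional iss/aud) map directly onto standard PyJWT options, which the tests themselves use via jwt_encode. A sketch of both sides follows, using the same secret and algorithm as the test class; the decode call approximates, rather than reproduces, Synapse's validation:

    import time
    import jwt

    secret = "secret"

    token = jwt.encode(
        {
            "sub": "kermit",                 # becomes @kermit:test
            "iss": "test-issuer",            # checked only when issuer is configured
            "aud": "test-audience",          # checked only when audiences is configured
            "exp": int(time.time()) + 3600,  # expired tokens are rejected with 403
        },
        secret,
        algorithm="HS256",
    )

    # Roughly what the server-side check amounts to; a failure here is what
    # surfaces as "JWT validation failed: ..." in the tests above.
    claims = jwt.decode(
        token,
        secret,
        algorithms=["HS256"],
        issuer="test-issuer",
        audience="test-audience",
    )
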
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 0fdff79aa7..3c66255dac 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
     def test_put_presence_disabled(self):
         """
-        PUT to the status endpoint with use_presence disbled will NOT call
+        PUT to the status endpoint with use_presence disabled will NOT call
         set_state on the presence handler.
         """
         self.hs.config.use_presence = False
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index fd641a7c2f..99c9f4e928 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -99,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(400, channel.code, channel.json_body)
 
     def test_basic_paginate_relations(self):
-        """Tests that calling pagination API corectly the latest relations.
+        """Tests that calling the pagination API correctly returns the latest relations.
         """
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
         self.assertEquals(200, channel.code, channel.json_body)
diff --git a/tests/server.py b/tests/server.py
index a5e57c52fa..b6e0b14e78 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
     def __init__(self):
         self.threadpool = ThreadPool(self)
+        self._tcp_callbacks = {}
         self._udp = []
         lookups = self.lookups = {}
 
@@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def getThreadPool(self):
         return self.threadpool
 
+    def add_tcp_client_callback(self, host, port, callback):
+        """Add a callback that will be invoked when we receive a connection
+        attempt to the given IP/port using `connectTCP`.
+
+        Note that the callback gets run before we return the connection to the
+        client, which means callbacks cannot block while waiting for writes.
+        """
+        self._tcp_callbacks[(host, port)] = callback
+
+    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+        """Fake L{IReactorTCP.connectTCP}.
+        """
+
+        conn = super().connectTCP(
+            host, port, factory, timeout=timeout, bindAddress=None
+        )
+
+        callback = self._tcp_callbacks.get((host, port))
+        if callback:
+            callback()
+
+        return conn
+
 
 class ThreadPool:
     """
@@ -486,7 +510,7 @@ class FakeTransport(object):
         try:
             self.other.dataReceived(to_write)
         except Exception as e:
-            logger.warning("Exception writing to protocol: %s", e)
+            logger.exception("Exception writing to protocol: %s", e)
             return
 
         self.buffer = self.buffer[len(to_write) :]
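
Usage of the new hook looks like the sketch below; the host, port, and callback are placeholders. Anything in the homeserver under test that later calls reactor.connectTCP("1.2.3.4", 8000, ...) fires the callback before the fake connection is handed back:

    from tests.server import ThreadedMemoryReactorClock

    calls = []

    reactor = ThreadedMemoryReactorClock()
    # The callback must not block: it runs before the connection is returned,
    # so no bytes have been written yet when it fires.
    reactor.add_tcp_client_callback(
        "1.2.3.4", 8000, lambda: calls.append(("1.2.3.4", 8000))
    )

This is what lets the replication tests wire a freshly started worker up to the main process the moment the worker tries to open its replication connection.
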
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3b78d48896..b1dceb2918 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -56,6 +56,10 @@ class RoomStoreTestCase(unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_get_room_unknown_room(self):
+        self.assertIsNone((yield self.store.get_room("!unknown:test")))
+
+    @defer.inlineCallbacks
     def test_get_room_with_stats(self):
         self.assertDictContainsSubset(
             {
@@ -66,6 +70,10 @@ class RoomStoreTestCase(unittest.TestCase):
             (yield self.store.get_room_with_stats(self.room.to_string())),
         )
 
+    @defer.inlineCallbacks
+    def test_get_room_with_stats_unknown_room(self):
+        self.assertIsNone((yield self.store.get_room_with_stats("!unknown:test")))
+
 
 class RoomEventsStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5dd46005e6..f282921538 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -118,18 +118,22 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
     def test_get_joined_users_from_context(self):
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
-        bob_event = event_injection.inject_member_event(
-            self.hs, room, self.u_bob, Membership.JOIN
+        bob_event = self.get_success(
+            event_injection.inject_member_event(
+                self.hs, room, self.u_bob, Membership.JOIN
+            )
         )
 
         # first, create a regular event
-        event, context = event_injection.create_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_alice,
-            prev_event_ids=[bob_event.event_id],
-            type="m.test.1",
-            content={},
+        event, context = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_alice,
+                prev_event_ids=[bob_event.event_id],
+                type="m.test.1",
+                content={},
+            )
         )
 
         users = self.get_success(
@@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # Regression test for #7376: create a state event whose key matches bob's
         # user_id, but which is *not* a membership event, and persist that; then check
         # that `get_joined_users_from_context` returns the correct users for the next event.
-        non_member_event = event_injection.inject_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_bob,
-            prev_event_ids=[bob_event.event_id],
-            type="m.test.2",
-            state_key=self.u_bob,
-            content={},
+        non_member_event = self.get_success(
+            event_injection.inject_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_bob,
+                prev_event_ids=[bob_event.event_id],
+                type="m.test.2",
+                state_key=self.u_bob,
+                content={},
+            )
         )
-        event, context = event_injection.create_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_alice,
-            prev_event_ids=[non_member_event.event_id],
-            type="m.test.3",
-            content={},
+        event, context = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_alice,
+                prev_event_ids=[non_member_event.event_id],
+                type="m.test.3",
+                content={},
+            )
         )
         users = self.get_success(
             self.store.get_joined_users_from_context(event, context)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b88308ff4..a0e133cd4a 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
         yield self.storage.persistence.persist_event(event, context)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 89dcc58b99..87a16d7d7a 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
         store = self.homeserver.get_datastore()
-        store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+        store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
 
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
@@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # Register mock device list retrieval on the federation client.
         federation_client = self.homeserver.get_federation_client()
         federation_client.query_user_devices = Mock(
-            return_value={
-                "user_id": remote_user_id,
-                "stream_id": 1,
-                "devices": [],
-                "master_key": {
+            return_value=succeed(
+                {
                     "user_id": remote_user_id,
-                    "usage": ["master"],
-                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                },
-                "self_signing_key": {
-                    "user_id": remote_user_id,
-                    "usage": ["self_signing"],
-                    "keys": {
-                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
+                    "stream_id": 1,
+                    "devices": [],
+                    "master_key": {
+                        "user_id": remote_user_id,
+                        "usage": ["master"],
+                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
                     },
-                },
-            }
+                    "self_signing_key": {
+                        "user_id": remote_user_id,
+                        "usage": ["self_signing"],
+                        "keys": {
+                            "ed25519:"
+                            + remote_self_signing_key: remote_self_signing_key
+                        },
+                    },
+                }
+            )
         )
 
         # Resync the device list.
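
The pattern behind these changes: store and federation-client methods that used to return plain values are now awaited by the code under test, so their mocks must return something awaitable. Twisted's succeed() wraps a value in an already-fired Deferred, which satisfies both `yield` and `await` callers:

    from mock import Mock
    from twisted.internet.defer import succeed

    store = Mock()
    store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))

    # Awaiting the mocked method now yields the list instead of a bare Mock:
    #     rooms = await store.get_rooms_for_user(user_id)
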
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 49667ed7f4..654a6fa42d 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -166,7 +166,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.do_sync_for_user(token5)
         self.do_sync_for_user(token6)
 
-        # But old user cant
+        # But old user can't
         with self.assertRaises(SynapseError) as cm:
             self.do_sync_for_user(token1)
diff --git a/tests/test_server.py b/tests/test_server.py
index 030f58cbdc..42cada8964 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -12,26 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import re
-from io import StringIO
 
 from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol
 from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
 
 from synapse.api.errors import Codes, RedirectException, SynapseError
 from synapse.config.server import parse_listener_def
 from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite, logger
+from synapse.http.site import SynapseSite
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock
 
 from tests import unittest
 from tests.server import (
-    FakeTransport,
     ThreadedMemoryReactorClock,
     make_request,
     render,
@@ -318,54 +312,3 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         self.assertEqual(location_headers, [b"/no/over/there"])
         cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
         self.assertEqual(cookies_headers, [b"session=yespls"])
-
-
-class SiteTestCase(unittest.HomeserverTestCase):
-    def test_lose_connection(self):
-        """
-        We log the URI correctly redacted when we lose the connection.
-        """
-
-        class HangingResource(Resource):
-            """
-            A Resource that strategically hangs, as if it were processing an
-            answer.
-            """
-
-            def render(self, request):
-                return NOT_DONE_YET
-
-        # Set up a logging handler that we can inspect afterwards
-        output = StringIO()
-        handler = logging.StreamHandler(output)
-        logger.addHandler(handler)
-        old_level = logger.level
-        logger.setLevel(10)
-        self.addCleanup(logger.setLevel, old_level)
-        self.addCleanup(logger.removeHandler, handler)
-
-        # Make a resource and a Site, the resource will hang and allow us to
-        # time out the request while it's 'processing'
-        base_resource = Resource()
-        base_resource.putChild(b"", HangingResource())
-        site = SynapseSite(
-            "test", "site_tag", self.hs.config.listeners[0], base_resource, "1.0"
-        )
-
-        server = site.buildProtocol(None)
-        client = AccumulatingProtocol()
-        client.makeConnection(FakeTransport(server, self.reactor))
-        server.makeConnection(FakeTransport(client, self.reactor))
-
-        # Send a request with an access token that will get redacted
-        server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n")
-        self.pump()
-
-        # Lose the connection
-        e = Failure(Exception("Failed123"))
-        server.connectionLost(e)
-        handler.flush()
-
-        # Our access token is redacted and the failure reason is logged.
-        self.assertIn("/?access_token=<redacted>", output.getvalue())
-        self.assertIn("Failed123", output.getvalue())
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 43297b530c..8522c6fc09 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -22,14 +22,12 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import Collection
 
-from tests.test_utils import get_awaitable_result
-
 """
 Utility functions for poking events into the storage of the server under test.
 """
 
 
-def inject_member_event(
+async def inject_member_event(
     hs: synapse.server.HomeServer,
     room_id: str,
     sender: str,
@@ -46,7 +44,7 @@ def inject_member_event(
     if extra_content:
         content.update(extra_content)
 
-    return inject_event(
+    return await inject_event(
         hs,
         room_id=room_id,
         type=EventTypes.Member,
@@ -57,7 +55,7 @@ def inject_member_event(
     )
 
 
-def inject_event(
+async def inject_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[Collection[str]] = None,
@@ -72,37 +70,27 @@ def inject_event(
         prev_event_ids: prev_events for the event. If not specified, will be looked up
         kwargs: fields for the event to be created
     """
-    test_reactor = hs.get_reactor()
-
-    event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
+    event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
 
-    d = hs.get_storage().persistence.persist_event(event, context)
-    test_reactor.advance(0)
-    get_awaitable_result(d)
+    await hs.get_storage().persistence.persist_event(event, context)
 
     return event
 
 
-def create_event(
+async def create_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
    prev_event_ids: Optional[Collection[str]] = None,
     **kwargs
 ) -> Tuple[EventBase, EventContext]:
-    test_reactor = hs.get_reactor()
-
     if room_version is None:
-        d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
-        test_reactor.advance(0)
-        room_version = get_awaitable_result(d)
+        room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
 
     builder = hs.get_event_builder_factory().for_room_version(
        KNOWN_ROOM_VERSIONS[room_version], kwargs
     )
 
-    d = hs.get_event_creation_handler().create_new_client_event(
+    event, context = await hs.get_event_creation_handler().create_new_client_event(
         builder, prev_event_ids=prev_event_ids
     )
-    test_reactor.advance(0)
-    event, context = get_awaitable_result(d)
 
     return event, context
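
With the helpers now coroutines, call sites wrap them in self.get_success(...) (see the tests/unittest.py hunk below) instead of creating a Deferred, pumping the test reactor, and extracting the result by hand. A minimal self-contained sketch of the pattern, with FakeStore and inject_something as invented stand-ins for the real storage layer and helper:

    from twisted.internet import defer

    class FakeStore:
        async def persist(self, value):
            return value

    async def inject_something(store):
        # The helper can await storage directly rather than round-tripping
        # through the test reactor via get_awaitable_result.
        return await store.persist("event")

    d = defer.ensureDeferred(inject_something(FakeStore()))
    d.addCallback(print)  # fires immediately and prints "event"
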
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index f7381b2885..b371efc0df 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         #
 
         # before we do that, we persist some other events to act as state.
-        self.inject_visibility("@admin:hs", "joined")
+        yield self.inject_visibility("@admin:hs", "joined")
         for i in range(0, 10):
             yield self.inject_room_member("@resident%i:hs" % i)
 
@@ -137,8 +137,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
         yield self.storage.persistence.persist_event(event, context)
         return event
@@ -158,8 +158,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
         yield self.storage.persistence.persist_event(event, context)
 
@@ -179,8 +179,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
         yield self.storage.persistence.persist_event(event, context)
diff --git a/tests/unittest.py b/tests/unittest.py
index 3175a3fa02..68d2586efd 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
             user: MXID of the user to inject the membership for.
             membership: The membership type.
         """
-        event_injection.inject_member_event(self.hs, room, user, membership)
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room, user, membership)
+        )
 
 
 class FederatingHomeserverTestCase(HomeserverTestCase):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 95301c013c..58ee918f65 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_make_deferred_yieldable(self):
-        # a function which retuns an incomplete deferred, but doesn't follow
+        # a function which returns an incomplete deferred, but doesn't follow
         # the synapse rules.
         def blocking_function():
             d = defer.Deferred()
@@ -183,7 +183,7 @@ class LoggingContextTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_make_deferred_yieldable_with_await(self):
-        # an async function which retuns an incomplete coroutine, but doesn't
+        # an async function which returns an incomplete coroutine, but doesn't
         # follow the synapse rules.
         async def blocking_function():
diff --git a/tests/utils.py b/tests/utils.py
index 4d17355a5c..ac643679aa 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -671,6 +671,8 @@ def create_room(hs, room_id, creator_id):
         },
     )
 
-    event, context = yield event_creation_handler.create_new_client_event(builder)
+    event, context = yield defer.ensureDeferred(
+        event_creation_handler.create_new_client_event(builder)
+    )
 
     yield persistence_store.persist_event(event, context)
diff --git a/tox.ini b/tox.ini
index 1c042cb227..834d68aea5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -126,7 +126,7 @@ deps =
     black==19.10b0
 commands =
     python -m black --check --diff .
-    /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
+    /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
     {toxinidir}/scripts-dev/config-lint.sh
 
 [testenv:check_isort]